Unverified Commit 1c52cb7b authored by eigen2017's avatar eigen2017 Committed by GitHub
Browse files

mlp_only_layers is more flexible than decoder_sparse_step (#30552)



* force back to commit ba40a21 and fix workflow errors

* match the review suggestions

* fix ci errors

* fix CI

* fix ci, format code

* fix ci, ruff format

* fix ci, ruff format again

* Update src/transformers/models/qwen2_moe/configuration_qwen2_moe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/qwen2_moe/configuration_qwen2_moe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/qwen2_moe/configuration_qwen2_moe.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* solve this warning: Default Argument Value is mutable

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 73fcfb28
...@@ -91,6 +91,10 @@ class Qwen2MoeConfig(PretrainedConfig): ...@@ -91,6 +91,10 @@ class Qwen2MoeConfig(PretrainedConfig):
allow the model to output the auxiliary loss, including load balancing loss and router z-loss. allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
router_aux_loss_coef (`float`, *optional*, defaults to 0.001): router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss. The aux loss factor for the total loss.
mlp_only_layers (`List[int]`, *optional*, defaults to `[]`):
Indicates which layers use Qwen2MoeMLP rather than Qwen2MoeSparseMoeBlock.
The list contains layer indices, from 0 to num_layers - 1 for a model with num_layers layers.
If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
```python ```python
>>> from transformers import Qwen2MoeModel, Qwen2MoeConfig >>> from transformers import Qwen2MoeModel, Qwen2MoeConfig
...@@ -135,6 +139,7 @@ class Qwen2MoeConfig(PretrainedConfig): ...@@ -135,6 +139,7 @@ class Qwen2MoeConfig(PretrainedConfig):
norm_topk_prob=False, norm_topk_prob=False,
output_router_logits=False, output_router_logits=False,
router_aux_loss_coef=0.001, router_aux_loss_coef=0.001,
mlp_only_layers=None,
**kwargs, **kwargs,
): ):
self.vocab_size = vocab_size self.vocab_size = vocab_size
...@@ -164,6 +169,7 @@ class Qwen2MoeConfig(PretrainedConfig): ...@@ -164,6 +169,7 @@ class Qwen2MoeConfig(PretrainedConfig):
self.norm_topk_prob = norm_topk_prob self.norm_topk_prob = norm_topk_prob
self.output_router_logits = output_router_logits self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef self.router_aux_loss_coef = router_aux_loss_coef
self.mlp_only_layers = [] if mlp_only_layers is None else mlp_only_layers
super().__init__( super().__init__(
tie_word_embeddings=tie_word_embeddings, tie_word_embeddings=tie_word_embeddings,
......
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" PyTorch Qwen2MoE model.""" """PyTorch Qwen2MoE model."""
import inspect import inspect
import math import math
import warnings import warnings
...@@ -861,7 +862,9 @@ class Qwen2MoeDecoderLayer(nn.Module): ...@@ -861,7 +862,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) self.self_attn = QWEN2MOE_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
if config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0: if (layer_idx not in config.mlp_only_layers) and (
config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
):
self.mlp = Qwen2MoeSparseMoeBlock(config) self.mlp = Qwen2MoeSparseMoeBlock(config)
else: else:
self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size) self.mlp = Qwen2MoeMLP(config, intermediate_size=config.intermediate_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment