Unverified Commit a25037be authored by Lorenzo Verardo, committed by GitHub

MixtralSparseMoeBlock: add gate jitter (#29865)

This commit adds gate jitter to MixtralSparseMoeBlock's input data before passing it
through the MoE layer. The jitter is applied only when it is enabled (`router_jitter_noise > 0`) and only during training.
parent 75769744
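For context, the jitter multiplies the block's input by noise drawn uniformly from [1 - jitter_noise, 1 + jitter_noise], and only while the module is in training mode. A minimal standalone sketch of that idea (the helper name, tensor shape, and the 0.1 amplitude are illustrative, not part of this commit):

```python
import torch

def apply_gate_jitter(hidden_states: torch.Tensor, jitter_noise: float, training: bool) -> torch.Tensor:
    """Scale each element by a factor drawn uniformly from [1 - jitter_noise, 1 + jitter_noise]."""
    if training and jitter_noise > 0:
        noise = torch.empty_like(hidden_states).uniform_(1.0 - jitter_noise, 1.0 + jitter_noise)
        return hidden_states * noise
    return hidden_states  # no-op at inference time or when jitter is disabled

x = torch.randn(2, 4, 8)  # (batch_size, sequence_length, hidden_dim)
print(torch.equal(x, apply_gate_jitter(x, 0.1, training=False)))  # True: eval path leaves inputs untouched
print(torch.equal(x, apply_gate_jitter(x, 0.1, training=True)))   # False (almost surely): inputs are perturbed
```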
@@ -92,6 +92,8 @@ class MixtralConfig(PretrainedConfig):
             allow the model to output the auxiliary loss. See [here]() for more details
         router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
             The aux loss factor for the total loss.
+        router_jitter_noise (`float`, *optional*, defaults to 0.0):
+            Amount of noise to add to the router.

     ```python
     >>> from transformers import MixtralModel, MixtralConfig
@@ -133,6 +135,7 @@ class MixtralConfig(PretrainedConfig):
         num_local_experts=8,
         output_router_logits=False,
         router_aux_loss_coef=0.001,
+        router_jitter_noise=0.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -159,6 +162,7 @@ class MixtralConfig(PretrainedConfig):
         self.num_local_experts = num_local_experts
         self.output_router_logits = output_router_logits
         self.router_aux_loss_coef = router_aux_loss_coef
+        self.router_jitter_noise = router_jitter_noise
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
...
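After this change, the new field can be set like any other `MixtralConfig` argument; a small usage sketch (the 0.1 value is illustrative):

```python
from transformers import MixtralConfig

# A non-zero value enables router jitter; the default of 0.0 keeps it disabled.
config = MixtralConfig(router_jitter_noise=0.1)
print(config.router_jitter_noise)  # 0.1
```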
@@ -837,9 +837,14 @@ class MixtralSparseMoeBlock(nn.Module):
         self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])

+        # Jitter parameters
+        self.jitter_noise = config.router_jitter_noise
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """ """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
+        if self.training and self.jitter_noise > 0:
+            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
...
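To see the effect end to end, here is a hedged sketch that drives the block directly; the tiny config sizes are illustrative, and it assumes `forward` returns a `(hidden_states, router_logits)` pair as elsewhere in this file:

```python
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

# Deliberately tiny config so the block is cheap to instantiate (real Mixtral is far larger).
config = MixtralConfig(
    hidden_size=16,
    intermediate_size=32,
    num_local_experts=4,
    num_experts_per_tok=2,
    router_jitter_noise=0.1,
)
block = MixtralSparseMoeBlock(config)
x = torch.randn(1, 5, config.hidden_size)

# Training mode: the multiplicative jitter makes repeated passes stochastic.
# `.clone()` is used because the in-place `*=` above mutates the input tensor.
block.train()
out_a, _ = block(x.clone())
out_b, _ = block(x.clone())
print(torch.allclose(out_a, out_b))  # almost surely False

# Eval mode: the jitter branch is skipped, so the block is deterministic.
block.eval()
out_c, _ = block(x.clone())
out_d, _ = block(x.clone())
print(torch.allclose(out_c, out_d))  # True
```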
...@@ -42,7 +42,6 @@ if is_torch_available(): ...@@ -42,7 +42,6 @@ if is_torch_available():
class MixtralModelTester: class MixtralModelTester:
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.__init__
def __init__( def __init__(
self, self,
parent, parent,
...@@ -69,6 +68,7 @@ class MixtralModelTester: ...@@ -69,6 +68,7 @@ class MixtralModelTester:
num_choices=4, num_choices=4,
pad_token_id=0, pad_token_id=0,
scope=None, scope=None,
router_jitter_noise=0.1,
): ):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
...@@ -94,6 +94,7 @@ class MixtralModelTester: ...@@ -94,6 +94,7 @@ class MixtralModelTester:
self.num_choices = num_choices self.num_choices = num_choices
self.pad_token_id = pad_token_id self.pad_token_id = pad_token_id
self.scope = scope self.scope = scope
self.router_jitter_noise = router_jitter_noise
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
...@@ -137,6 +138,7 @@ class MixtralModelTester: ...@@ -137,6 +138,7 @@ class MixtralModelTester:
pad_token_id=self.pad_token_id, pad_token_id=self.pad_token_id,
num_experts_per_tok=2, num_experts_per_tok=2,
num_local_experts=2, num_local_experts=2,
router_jitter_noise=self.router_jitter_noise,
) )
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mixtral # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mixtral
......