"examples/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "6d9f545a67a3c96270c8aa8a8fa1538f329aa44c"
Unverified Commit 181d778f authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[`NllbMoe`] Update code to properly support loss computation (#25429)

* update nllb_moe

* fix

* doc nits

* nits

* add a small test

* fixup

* remove adapted from
parent 9264fc91
...@@ -126,7 +126,6 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l ...@@ -126,7 +126,6 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
return incremental_indices.long() + padding_idx return incremental_indices.long() + padding_idx
# Copied from transformers.models.switch_transformers.modeling_switch_transformers.load_balancing_loss_func with SwitchTransformers->NllbMoeModel
def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float: def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float:
r""" r"""
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
...@@ -144,6 +143,9 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T ...@@ -144,6 +143,9 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T
Returns: Returns:
The auxiliary loss. The auxiliary loss.
""" """
if router_probs is None:
return 0
num_experts = router_probs.shape[-1] num_experts = router_probs.shape[-1]
# cast the expert indices to int64, otherwise one-hot encoding will fail # cast the expert indices to int64, otherwise one-hot encoding will fail
...@@ -699,7 +701,9 @@ class NllbMoeEncoderLayer(nn.Module): ...@@ -699,7 +701,9 @@ class NllbMoeEncoderLayer(nn.Module):
if self.is_sparse: if self.is_sparse:
hidden_states, router_states = self.ffn(hidden_states, attention_mask) hidden_states, router_states = self.ffn(hidden_states, attention_mask)
else: else:
hidden_states = self.ffn(hidden_states) # router_states set to None to track which layers have None gradients.
hidden_states, router_states = self.ffn(hidden_states), None
hidden_states = self.ff_dropout(hidden_states) hidden_states = self.ff_dropout(hidden_states)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -830,7 +834,8 @@ class NllbMoeDecoderLayer(nn.Module): ...@@ -830,7 +834,8 @@ class NllbMoeDecoderLayer(nn.Module):
if self.is_sparse: if self.is_sparse:
hidden_states, router_states = self.ffn(hidden_states, attention_mask) hidden_states, router_states = self.ffn(hidden_states, attention_mask)
else: else:
hidden_states = self.ffn(hidden_states) hidden_states, router_states = self.ffn(hidden_states), None
hidden_states = self.ff_dropout(hidden_states) hidden_states = self.ff_dropout(hidden_states)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -1730,7 +1735,7 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel): ...@@ -1730,7 +1735,7 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
if output_router_logits: if output_router_logits:
encoder_router_logits = outputs[-1] encoder_router_logits = outputs[-1]
decoder_router_logits = outputs[5 if output_attentions else 3] decoder_router_logits = outputs[3 if output_attentions else 4]
# Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_router_logits) encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_router_logits)
...@@ -1775,7 +1780,6 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel): ...@@ -1775,7 +1780,6 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
decoder_router_logits=outputs.decoder_router_logits, decoder_router_logits=outputs.decoder_router_logits,
) )
# Copied from transfomers.models.switch_transformers.SwitchTransformersForConditionalGeneration._unpack_router_logits
def _unpack_router_logits(self, router_outputs): def _unpack_router_logits(self, router_outputs):
total_router_logits = [] total_router_logits = []
total_expert_indexes = [] total_expert_indexes = []
...@@ -1784,11 +1788,10 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel): ...@@ -1784,11 +1788,10 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
router_logits, expert_indexes = router_output router_logits, expert_indexes = router_output
total_router_logits.append(router_logits) total_router_logits.append(router_logits)
total_expert_indexes.append(expert_indexes) total_expert_indexes.append(expert_indexes)
if len(total_expert_indexes) > 0:
total_router_logits = torch.cat(total_router_logits, dim=1) total_router_logits = torch.cat(total_router_logits, dim=1) if len(total_router_logits) > 0 else None
if len(total_expert_indexes) > 0: total_expert_indexes = torch.stack(total_expert_indexes, dim=1) if len(total_expert_indexes) > 0 else None
torch.cat(total_expert_indexes, dim=1) return total_router_logits, total_expert_indexes
return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1)
# Copied from transfomers.models.switch_transformers.SwitchTransformersForConditionalGeneration.prepare_inputs_for_generation # Copied from transfomers.models.switch_transformers.SwitchTransformersForConditionalGeneration.prepare_inputs_for_generation
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
......
...@@ -337,6 +337,16 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ...@@ -337,6 +337,16 @@ class NllbMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
model.generate(input_ids, attention_mask=attention_mask) model.generate(input_ids, attention_mask=attention_mask)
model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
def test_get_loss(self):
    """Check that a labelled forward pass returns a loss and router logits.

    The model is run once with ``output_router_logits=True`` and the input
    ids reused as labels; the single output is then inspected for a loss and
    for non-``None`` router logits entries.
    """
    config, input_dict = self.model_tester.prepare_config_and_inputs()
    input_dict["output_router_logits"] = True
    input_dict["labels"] = input_dict["input_ids"]
    model = NllbMoeForConditionalGeneration(config).eval().to(torch_device)

    # One forward pass is enough: the model is in eval mode and the inputs
    # do not change, so every assertion can inspect the same output.
    out = model(**input_dict)

    self.assertIsNotNone(out.loss)
    # NOTE(review): indices 1 (encoder) and 0 (decoder) presumably point at
    # sparse layers in the tester config -- confirm against the model tester.
    self.assertIsNotNone(out["encoder_router_logits"][1])
    self.assertIsNotNone(out["decoder_router_logits"][0])
@require_torch @require_torch
@require_sentencepiece @require_sentencepiece
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment