Unverified Commit a65da83d authored by Zhan Lu's avatar Zhan Lu Committed by GitHub
Browse files

fix:missing `output_router_logits` in SwitchTransformers (#30573)

* fix:missing `output_router_logits` in SwitchTransformers

* fix whitespace in blank line
parent 4ad5adaf
...@@ -1721,6 +1721,8 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod ...@@ -1721,6 +1721,8 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
input_ids = input_ids[:, remove_prefix_length:] input_ids = input_ids[:, remove_prefix_length:]
output_router_logits = kwargs.get("output_router_logits", True)
return { return {
"decoder_input_ids": input_ids, "decoder_input_ids": input_ids,
"past_key_values": past_key_values, "past_key_values": past_key_values,
...@@ -1730,6 +1732,7 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod ...@@ -1730,6 +1732,7 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask, "cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, "use_cache": use_cache,
"output_router_logits": output_router_logits,
} }
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment