Unverified commit a65da83d, authored by Zhan Lu, committed by GitHub
Browse files

fix: missing `output_router_logits` in SwitchTransformers (#30573)

* fix: missing `output_router_logits` in SwitchTransformers

* fix whitespace in blank line
parent 4ad5adaf
......@@ -1721,6 +1721,8 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
input_ids = input_ids[:, remove_prefix_length:]
output_router_logits = kwargs.get("output_router_logits", True)
return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
......@@ -1730,6 +1732,7 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
"output_router_logits": output_router_logits,
}
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment