Unverified commit 2aef4cfe, authored by Yih-Dar, committed by GitHub
Browse files

Fix TFTransfoXLLMHeadModel outputs (#16590)


Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 8d57c424
@@ -998,12 +998,13 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
 pred_hid = last_hidden[:, -tgt_len:]
 softmax_output = self.crit(pred_hid, labels, training=training)
+prediction_scores = softmax_output if labels is None else ()
 if not return_dict:
-return (softmax_output,) + transformer_outputs[1:]
+return (prediction_scores,) + transformer_outputs[1:]
 return TFTransfoXLLMHeadModelOutput(
-prediction_scores=softmax_output,
+prediction_scores=prediction_scores,
 mems=transformer_outputs.mems,
 hidden_states=transformer_outputs.hidden_states,
 attentions=transformer_outputs.attentions,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment