Unverified commit 6e161955, authored by Sylvain Gugger, committed by GitHub
Browse files

Fix #5974 (#5999)

parent e168488a
......@@ -1045,14 +1045,13 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:]
outputs = transformer_outputs[1:]
softmax_output = self.crit(pred_hid, labels)
prediction_scores = softmax_output.view(bsz, tgt_len, -1) if labels is None else ()
loss = softmax_output.view(bsz, tgt_len - 1) if labels is not None else None
if return_tuple:
- output = (prediction_scores,) + outputs[1:]
+ output = (prediction_scores,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return TransfoXLLMHeadModelOutput(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment