"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "81d6841b4be25a164235975e5ebdcf99d7a26633"
Unverified commit 24e2fa15, authored by Yih-Dar, committed by GitHub
Browse files

Fix encoder-decoder models when labels is passed (#15172)


Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent e79a0fae
@@ -529,7 +529,7 @@ class EncoderDecoderModel(PreTrainedModel):
         loss = None
         if labels is not None:
             warnings.warn(DEPRECATION_WARNING, FutureWarning)
-            logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
+            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
......
@@ -549,7 +549,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
         # Compute loss independent from decoder (as some shift the logits inside them)
         loss = None
         if labels is not None:
-            logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
+            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
......
@@ -503,7 +503,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
         # Compute loss independent from decoder (as some shift the logits inside them)
         loss = None
         if labels is not None:
-            logits = decoder_outputs.logits if return_dict else decoder_outputs[1]
+            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment