Unverified Commit de8b06f9 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[SpeechEncoderDecoderModel] Fix bug in reshaping labels (#16748)

parent 048443db
...@@ -557,7 +557,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel): ...@@ -557,7 +557,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
if labels is not None: if labels is not None:
logits = decoder_outputs.logits if return_dict else decoder_outputs[0] logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1)) loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))
if not return_dict: if not return_dict:
if loss is not None: if loss is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment