Unverified Commit 96d4fa46 authored by Jonatas Grosman's avatar Jonatas Grosman Committed by GitHub
Browse files

[WhisperModel] fix bug in reshaping labels (#21653)

fix bug in reshaping labels
parent fcfd4ec7
...@@ -1211,7 +1211,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): ...@@ -1211,7 +1211,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
loss = None loss = None
if labels is not None: if labels is not None:
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
if not return_dict: if not return_dict:
output = (lm_logits,) + outputs[1:] output = (lm_logits,) + outputs[1:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment