Unverified Commit fb3aa06c authored by oscar-garzon's avatar oscar-garzon Committed by GitHub
Browse files

Move labels to the same device as logits for Whisper (#22779)

parent 20e54e49
......@@ -1432,6 +1432,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# move labels to correct device to enable PP
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
if not return_dict:
......@@ -1760,6 +1762,8 @@ class WhisperForAudioClassification(WhisperPreTrainedModel):
if labels is not None:
loss_fct = CrossEntropyLoss()
# move labels to correct device to enable PP
labels = labels.to(logits.device)
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
if not return_dict:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment