"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "52c85f847aba62ef2018162472007a167dc25622"
Unverified Commit fb3aa06c authored by oscar-garzon's avatar oscar-garzon Committed by GitHub
Browse files

Move labels to the same device as logits for Whisper (#22779)

parent 20e54e49
...@@ -1432,6 +1432,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): ...@@ -1432,6 +1432,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
loss = None loss = None
if labels is not None: if labels is not None:
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
# move labels to correct device to enable PP
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1)) loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
if not return_dict: if not return_dict:
...@@ -1760,6 +1762,8 @@ class WhisperForAudioClassification(WhisperPreTrainedModel): ...@@ -1760,6 +1762,8 @@ class WhisperForAudioClassification(WhisperPreTrainedModel):
if labels is not None: if labels is not None:
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
# move labels to correct device to enable PP
labels = labels.to(logits.device)
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
if not return_dict: if not return_dict:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment