Unverified Commit 011cc17a authored by Morgan McGuire's avatar Morgan McGuire Committed by GitHub
Browse files

Fix for non-contiguous label tensors in VisonEncoderDecoder (#21582)

* add prints

* add shape

* add reshape

* clean up
parent 2840272c
......@@ -625,7 +625,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
if labels is not None:
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))
if not return_dict:
if loss is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment