Unverified Commit 1de8ce9e authored by Shikhar Chauhan's avatar Shikhar Chauhan Committed by GitHub
Browse files

Move labels to the same device as logits for LlamaForSequenceClassification and Blip2 (#22596)

* (feat): Move labels to the same device as logits

* Trigger CI

* Trigger CI

* Trigger CI

* (feat): Making changes for Blip2
parent d59034ff
...@@ -1522,6 +1522,7 @@ class Blip2Model(Blip2PreTrainedModel): ...@@ -1522,6 +1522,7 @@ class Blip2Model(Blip2PreTrainedModel):
loss = None loss = None
# we compute the loss here since we need to take into account the sequence length of the query embeds # we compute the loss here since we need to take into account the sequence length of the query embeds
if labels is not None: if labels is not None:
labels = labels.to(logits.device)
logits = logits[:, -labels.size(1) :, :] logits = logits[:, -labels.size(1) :, :]
# Shift so that tokens < n predict n # Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous() shift_logits = logits[..., :-1, :].contiguous()
...@@ -1757,6 +1758,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel): ...@@ -1757,6 +1758,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
loss = None loss = None
# we compute the loss here since we need to take into account the sequence length of the query embeds # we compute the loss here since we need to take into account the sequence length of the query embeds
if labels is not None: if labels is not None:
labels = labels.to(logits.device)
logits = logits[:, -labels.size(1) :, :] logits = logits[:, -labels.size(1) :, :]
# Shift so that tokens < n predict n # Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous() shift_logits = logits[..., :-1, :].contiguous()
......
...@@ -850,6 +850,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel): ...@@ -850,6 +850,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
loss = None loss = None
if labels is not None: if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None: if self.config.problem_type is None:
if self.num_labels == 1: if self.num_labels == 1:
self.config.problem_type = "regression" self.config.problem_type = "regression"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment