Unverified Commit 1de8ce9e authored by Shikhar Chauhan's avatar Shikhar Chauhan Committed by GitHub
Browse files

Move labels to the same device as logits for LlamaForSequenceClassification and Blip2 (#22596)

* (feat): Move labels to the same device as logits

* Trigger CI

* Trigger CI

* Trigger CI

* (feat): Making changes for Blip2
parent d59034ff
......@@ -1522,6 +1522,7 @@ class Blip2Model(Blip2PreTrainedModel):
loss = None
# we compute the loss here since we need to take into account the sequence length of the query embeds
if labels is not None:
labels = labels.to(logits.device)
logits = logits[:, -labels.size(1) :, :]
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
......@@ -1757,6 +1758,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
loss = None
# we compute the loss here since we need to take into account the sequence length of the query embeds
if labels is not None:
labels = labels.to(logits.device)
logits = logits[:, -labels.size(1) :, :]
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
......
......@@ -850,6 +850,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment