"scripts/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "61f42b5732a0740ed9a416a098b96e7e6e14f277"
Unverified Commit 1de8ce9e authored by Shikhar Chauhan's avatar Shikhar Chauhan Committed by GitHub
Browse files

Move labels to the same device as logits for LlamaForSequenceClassification and Blip2 (#22596)

* (feat): Move labels to the same device as logits

* Trigger CI

* Trigger CI

* Trigger CI

* (feat): Making changes for Blip2
parent d59034ff
......@@ -1522,6 +1522,7 @@ class Blip2Model(Blip2PreTrainedModel):
loss = None
# we compute the loss here since we need to take into account the sequence length of the query embeds
if labels is not None:
labels = labels.to(logits.device)
logits = logits[:, -labels.size(1) :, :]
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
......@@ -1757,6 +1758,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
loss = None
# we compute the loss here since we need to take into account the sequence length of the query embeds
if labels is not None:
labels = labels.to(logits.device)
logits = logits[:, -labels.size(1) :, :]
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
......
......@@ -850,6 +850,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment