Unverified Commit 84a6570e authored by Youssef Adarrab's avatar Youssef Adarrab Committed by GitHub
Browse files

Make ClipSeg compatible with model parallelism (#22844)

parent 5bb4ec62
...@@ -1480,6 +1480,8 @@ class CLIPSegForImageSegmentation(CLIPSegPreTrainedModel): ...@@ -1480,6 +1480,8 @@ class CLIPSegForImageSegmentation(CLIPSegPreTrainedModel):
loss = None loss = None
if labels is not None: if labels is not None:
# move labels to the correct device to enable PP
labels = labels.to(logits.device)
loss_fn = nn.BCEWithLogitsLoss() loss_fn = nn.BCEWithLogitsLoss()
loss = loss_fn(logits, labels) loss = loss_fn(logits, labels)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment