Unverified Commit 870d91fb authored by Shahad Mahmud, committed by GitHub
Browse files

Model parallelism: Moving labels to the same device as logits for BridgeTower models (#22676)

BridgeTower Model parallelism logits device for loss calculation
parent e0921c6b
...@@ -1630,6 +1630,8 @@ class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel): ...@@ -1630,6 +1630,8 @@ class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
masked_lm_loss = None masked_lm_loss = None
if labels is not None: if labels is not None:
loss_fct = CrossEntropyLoss() # -100 index = padding token loss_fct = CrossEntropyLoss() # -100 index = padding token
labels = labels.to(mlm_logits.device)
masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.text_config.vocab_size), labels.view(-1)) masked_lm_loss = loss_fct(mlm_logits.view(-1, self.config.text_config.vocab_size), labels.view(-1))
if not return_dict: if not return_dict:
...@@ -1730,6 +1732,8 @@ class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel): ...@@ -1730,6 +1732,8 @@ class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel):
itm_loss = None itm_loss = None
if labels is not None: if labels is not None:
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
labels = labels.to(logits.device)
itm_loss = loss_fct(logits, labels) itm_loss = loss_fct(logits, labels)
if not return_dict: if not return_dict:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment