Unverified commit b6865b9b authored by Yih-Dar, committed by GitHub
Browse files

Fix model parallelism for `BridgeTower` (#23039)



* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent d337631b
...@@ -981,7 +981,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel): ...@@ -981,7 +981,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
config_class = BridgeTowerConfig config_class = BridgeTowerConfig
base_model_prefix = "bridgetower" base_model_prefix = "bridgetower"
supports_gradient_checkpointing = False supports_gradient_checkpointing = False
_no_split_modules = ["BridgeTowerSelfAttention"] _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
def _init_weights(self, module): def _init_weights(self, module):
if isinstance(module, BridgeTowerVisionModel): if isinstance(module, BridgeTowerVisionModel):
...@@ -1863,12 +1863,16 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel): ...@@ -1863,12 +1863,16 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
# normalized features # normalized features
text_embeds = nn.functional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2) text_embeds = nn.functional.normalize(self.itc_text_head(text_embeds[:, 0, :]), dim=-1, p=2)
image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2) image_embeds = nn.functional.normalize(self.itc_image_head(image_embeds[:, 0, :]), dim=-1, p=2).to(
cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2) device=text_embeds.device
)
cross_embeds = nn.functional.normalize(self.itc_cross_modal_head(pooler_output), dim=-1, p=2).to(
device=text_embeds.device
)
logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2) logits = torch.stack([text_embeds, image_embeds, cross_embeds], dim=-2)
logit_scale = self.logit_scale.exp() logit_scale = self.logit_scale.exp().to(device=text_embeds.device)
logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale logits_text_to_image = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale logits_text_to_cross = torch.matmul(text_embeds, cross_embeds.t()) * logit_scale
logits_image_to_cross = torch.matmul(image_embeds, cross_embeds.t()) * logit_scale logits_image_to_cross = torch.matmul(image_embeds, cross_embeds.t()) * logit_scale
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment