Unverified Commit e901914d authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix for LXMERT (#20986)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 8f09dd89
...@@ -739,7 +739,7 @@ class LxmertVisualObjHead(nn.Module): ...@@ -739,7 +739,7 @@ class LxmertVisualObjHead(nn.Module):
visual_losses["obj"] = {"shape": (-1,), "num": config.num_object_labels} visual_losses["obj"] = {"shape": (-1,), "num": config.num_object_labels}
if config.visual_attr_loss: if config.visual_attr_loss:
visual_losses["attr"] = {"shape": (-1,), "num": config.num_attr_labels} visual_losses["attr"] = {"shape": (-1,), "num": config.num_attr_labels}
if config.visual_obj_loss: if config.visual_feat_loss:
visual_losses["feat"] = { visual_losses["feat"] = {
"shape": (-1, config.visual_feat_dim), "shape": (-1, config.visual_feat_dim),
"num": config.visual_feat_dim, "num": config.visual_feat_dim,
...@@ -1072,7 +1072,7 @@ class LxmertForPreTraining(LxmertPreTrainedModel): ...@@ -1072,7 +1072,7 @@ class LxmertForPreTraining(LxmertPreTrainedModel):
"num": config.num_attr_labels, "num": config.num_attr_labels,
"loss": "visual_ce", "loss": "visual_ce",
} }
if config.visual_obj_loss: if config.visual_feat_loss:
visual_losses["feat"] = { visual_losses["feat"] = {
"shape": (-1, config.visual_feat_dim), "shape": (-1, config.visual_feat_dim),
"num": config.visual_feat_dim, "num": config.visual_feat_dim,
......
...@@ -1160,7 +1160,7 @@ class TFLxmertVisualObjHead(tf.keras.layers.Layer): ...@@ -1160,7 +1160,7 @@ class TFLxmertVisualObjHead(tf.keras.layers.Layer):
visual_losses["obj"] = {"shape": (-1,), "num": config.num_object_labels} visual_losses["obj"] = {"shape": (-1,), "num": config.num_object_labels}
if config.visual_attr_loss: if config.visual_attr_loss:
visual_losses["attr"] = {"shape": (-1,), "num": config.num_attr_labels} visual_losses["attr"] = {"shape": (-1,), "num": config.num_attr_labels}
if config.visual_obj_loss: if config.visual_feat_loss:
visual_losses["feat"] = {"shape": (-1, 2048), "num": config.visual_feat_dim} visual_losses["feat"] = {"shape": (-1, 2048), "num": config.visual_feat_dim}
self.visual_losses = visual_losses self.visual_losses = visual_losses
...@@ -1228,7 +1228,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel): ...@@ -1228,7 +1228,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
"num": config.num_attr_labels, "num": config.num_attr_labels,
"loss": "visn_ce", "loss": "visn_ce",
} }
if config.visual_obj_loss: if config.visual_feat_loss:
visual_losses["feat"] = { visual_losses["feat"] = {
"shape": (-1, config.visual_feat_dim), "shape": (-1, config.visual_feat_dim),
"num": config.visual_feat_dim, "num": config.visual_feat_dim,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment