"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "0a064dc0fcba31092868f911772df087901d90fb"
Unverified Commit e901914d authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix for LXMERT (#20986)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 8f09dd89
......@@ -739,7 +739,7 @@ class LxmertVisualObjHead(nn.Module):
visual_losses["obj"] = {"shape": (-1,), "num": config.num_object_labels}
if config.visual_attr_loss:
visual_losses["attr"] = {"shape": (-1,), "num": config.num_attr_labels}
if config.visual_obj_loss:
if config.visual_feat_loss:
visual_losses["feat"] = {
"shape": (-1, config.visual_feat_dim),
"num": config.visual_feat_dim,
......@@ -1072,7 +1072,7 @@ class LxmertForPreTraining(LxmertPreTrainedModel):
"num": config.num_attr_labels,
"loss": "visual_ce",
}
if config.visual_obj_loss:
if config.visual_feat_loss:
visual_losses["feat"] = {
"shape": (-1, config.visual_feat_dim),
"num": config.visual_feat_dim,
......
......@@ -1160,7 +1160,7 @@ class TFLxmertVisualObjHead(tf.keras.layers.Layer):
visual_losses["obj"] = {"shape": (-1,), "num": config.num_object_labels}
if config.visual_attr_loss:
visual_losses["attr"] = {"shape": (-1,), "num": config.num_attr_labels}
if config.visual_obj_loss:
if config.visual_feat_loss:
visual_losses["feat"] = {"shape": (-1, 2048), "num": config.visual_feat_dim}
self.visual_losses = visual_losses
......@@ -1228,7 +1228,7 @@ class TFLxmertForPreTraining(TFLxmertPreTrainedModel):
"num": config.num_attr_labels,
"loss": "visn_ce",
}
if config.visual_obj_loss:
if config.visual_feat_loss:
visual_losses["feat"] = {
"shape": (-1, config.visual_feat_dim),
"num": config.visual_feat_dim,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment