"src/vscode:/vscode.git/clone" did not exist on "5e323017a4f62c628f1146cb86362ca7d8bf32c4"
Unverified Commit 7b3d4df4 authored by PaulLerner's avatar PaulLerner Committed by GitHub
Browse files

fix: #14486 do not use BertPooler in DPR (#15068)



* fix: #14486 do not use BertPooler in DPR

* fix tf dpr as well

* finish
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 74bec986
......@@ -175,7 +175,7 @@ class DPREncoder(DPRPreTrainedModel):
def __init__(self, config: DPRConfig):
super().__init__(config)
self.bert_model = BertModel(config)
self.bert_model = BertModel(config, add_pooling_layer=False)
assert self.bert_model.config.hidden_size > 0, "Encoder hidden_size can't be zero"
self.projection_dim = config.projection_dim
if self.projection_dim > 0:
......@@ -202,8 +202,9 @@ class DPREncoder(DPRPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output, pooled_output = outputs[:2]
sequence_output = outputs[0]
pooled_output = sequence_output[:, 0, :]
if self.projection_dim > 0:
pooled_output = self.encode_proj(pooled_output)
......
......@@ -152,7 +152,7 @@ class TFDPREncoderLayer(tf.keras.layers.Layer):
super().__init__(**kwargs)
# resolve name conflict with TFBertMainLayer instead of TFBertModel
self.bert_model = TFBertMainLayer(config, name="bert_model")
self.bert_model = TFBertMainLayer(config, add_pooling_layer=False, name="bert_model")
self.config = config
assert self.config.hidden_size > 0, "Encoder hidden_size can't be zero"
......@@ -198,13 +198,13 @@ class TFDPREncoderLayer(tf.keras.layers.Layer):
training=inputs["training"],
)
sequence_output, pooled_output = outputs[:2]
sequence_output = outputs[0]
pooled_output = sequence_output[:, 0, :]
if self.projection_dim > 0:
pooled_output = self.encode_proj(pooled_output)
if not inputs["return_dict"]:
return (sequence_output, pooled_output) + outputs[2:]
return (sequence_output, pooled_output) + outputs[1:]
return TFBaseModelOutputWithPooling(
last_hidden_state=sequence_output,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment