Unverified commit ec54d70e, authored by Julien Plu, committed by GitHub

Fix TF DPR (#9283)

* Fix DPR

* Keep usual models

* Apply style

* Address Sylvain's comments
parent de29ff9b
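
The pattern behind this fix, as visible in the diff below: the classes that actually own the weights become plain `tf.keras.layers.Layer` subclasses (`TFDPREncoderLayer`, `TFDPRSpanPredictorLayer`), while thin `TFPreTrainedModel` wrappers keep the old public names and simply delegate. A minimal sketch of that split in plain TensorFlow (`EncoderLayer` and `Encoder` are hypothetical names, and `tf.keras.Model` stands in for `TFPreTrainedModel`):

```python
import tensorflow as tf


class EncoderLayer(tf.keras.layers.Layer):
    """Owns the weights, analogous to TFDPREncoderLayer in the diff."""

    def __init__(self, hidden_size: int, **kwargs):
        # A Layer does not take the config positionally, hence **kwargs only.
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(hidden_size, name="dense")

    def call(self, inputs, training=False):
        return self.dense(inputs)


class Encoder(tf.keras.Model):
    """Thin public wrapper, analogous to the new TFDPREncoder."""

    def __init__(self, hidden_size: int, **kwargs):
        super().__init__(**kwargs)
        self.encoder = EncoderLayer(hidden_size, name="encoder")

    def call(self, inputs, training=False):
        # Delegate everything to the inner layer, as the wrappers below do.
        return self.encoder(inputs, training=training)


model = Encoder(hidden_size=4)
print(model(tf.ones((1, 2, 4))).shape)  # (1, 2, 4)
```

The likely motivation: a layer can be nested inside other TF models (for example, the reader holding a span predictor) without stacking two `TFPreTrainedModel` instances, while the wrapper keeps `from_pretrained`-style usage working under the old class names.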
@@ -144,18 +144,18 @@ class TFDPRReaderOutput(ModelOutput):
     attentions: Optional[Tuple[tf.Tensor]] = None


-class TFDPREncoder(TFPreTrainedModel):
+class TFDPREncoderLayer(tf.keras.layers.Layer):

     base_model_prefix = "bert_model"

-    def __init__(self, config: DPRConfig, *args, **kwargs):
-        super().__init__(config, *args, **kwargs)
+    def __init__(self, config: DPRConfig, **kwargs):
+        super().__init__(**kwargs)

         # resolve name conflict with TFBertMainLayer instead of TFBertModel
         self.bert_model = TFBertMainLayer(config, name="bert_model")
-        self.bert_model.config = config
+        self.config = config

-        assert self.bert_model.config.hidden_size > 0, "Encoder hidden_size can't be zero"
+        assert self.config.hidden_size > 0, "Encoder hidden_size can't be zero"
         self.projection_dim = config.projection_dim
         if self.projection_dim > 0:
             self.encode_proj = tf.keras.layers.Dense(
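
One detail worth noting above: in Keras the explicit `name="bert_model"` (not the Python attribute name) is what ends up in variable paths, so keeping it stable preserves checkpoint weight names across this refactor. A quick check in plain TensorFlow (`Outer` is a hypothetical class for illustration):

```python
import tensorflow as tf


class Outer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The explicit name, not the attribute name, appears in variable paths.
        self.bert_model = tf.keras.layers.Dense(3, name="bert_model")

    def call(self, x):
        return self.bert_model(x)


layer = Outer(name="ctx_encoder")
layer(tf.ones((1, 3)))  # first call builds the weights
for v in layer.weights:
    print(v.name)  # e.g. ctx_encoder/bert_model/kernel:0
```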
@@ -220,13 +220,14 @@ class TFDPREncoder(TFPreTrainedModel):
         return self.bert_model.config.hidden_size


-class TFDPRSpanPredictor(TFPreTrainedModel):
+class TFDPRSpanPredictorLayer(tf.keras.layers.Layer):

     base_model_prefix = "encoder"

-    def __init__(self, config: DPRConfig, *args, **kwargs):
-        super().__init__(config, *args, **kwargs)
-        self.encoder = TFDPREncoder(config, name="encoder")
+    def __init__(self, config: DPRConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.config = config
+        self.encoder = TFDPREncoderLayer(config, name="encoder")

         self.qa_outputs = tf.keras.layers.Dense(
             2, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
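
For context on `qa_outputs`: a 2-unit dense head is the standard BERT-style reader setup, with the two units split into start and end logits. The body of `TFDPRSpanPredictorLayer.call` is not shown in this diff, so the split below is the usual convention rather than a quote of it:

```python
import tensorflow as tf

hidden_size = 8  # toy value; DPR's BERT base uses 768
qa_outputs = tf.keras.layers.Dense(2, name="qa_outputs")

sequence_output = tf.random.normal((1, 5, hidden_size))  # (batch, seq_len, hidden)
logits = qa_outputs(sequence_output)                     # (batch, seq_len, 2)

start_logits, end_logits = tf.split(logits, num_or_size_splits=2, axis=-1)
start_logits = tf.squeeze(start_logits, axis=-1)         # (batch, seq_len)
end_logits = tf.squeeze(end_logits, axis=-1)
print(start_logits.shape, end_logits.shape)
```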
@@ -299,6 +300,97 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
         )


+class TFDPRSpanPredictor(TFPreTrainedModel):
+
+    base_model_prefix = "encoder"
+
+    def __init__(self, config: DPRConfig, **kwargs):
+        super().__init__(config, **kwargs)
+        self.encoder = TFDPRSpanPredictorLayer(config)
+
+    def call(
+        self,
+        input_ids: tf.Tensor = None,
+        attention_mask: Optional[tf.Tensor] = None,
+        token_type_ids: Optional[tf.Tensor] = None,
+        inputs_embeds: Optional[tf.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = False,
+        training: bool = False,
+        **kwargs,
+    ) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]:
+        inputs = input_processing(
+            func=self.call,
+            config=self.config,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+            kwargs_call=kwargs,
+        )
+        outputs = self.encoder(
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            inputs_embeds=inputs["inputs_embeds"],
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
+            return_dict=inputs["return_dict"],
+            training=inputs["training"],
+        )
+
+        return outputs
+
+
+class TFDPREncoder(TFPreTrainedModel):
+
+    base_model_prefix = "encoder"
+
+    def __init__(self, config: DPRConfig, **kwargs):
+        super().__init__(config, **kwargs)
+        self.encoder = TFDPREncoderLayer(config)
+
+    def call(
+        self,
+        input_ids: tf.Tensor = None,
+        attention_mask: Optional[tf.Tensor] = None,
+        token_type_ids: Optional[tf.Tensor] = None,
+        inputs_embeds: Optional[tf.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = False,
+        training: bool = False,
+        **kwargs,
+    ) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]:
+        inputs = input_processing(
+            func=self.call,
+            config=self.config,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+            kwargs_call=kwargs,
+        )
+        outputs = self.encoder(
+            input_ids=inputs["input_ids"],
+            attention_mask=inputs["attention_mask"],
+            inputs_embeds=inputs["inputs_embeds"],
+            output_attentions=inputs["output_attentions"],
+            output_hidden_states=inputs["output_hidden_states"],
+            return_dict=inputs["return_dict"],
+            training=inputs["training"],
+        )
+
+        return outputs
+
+
 ##################
 # PreTrainedModel
 ##################
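
Both wrappers above funnel their arguments through `input_processing` before delegating, which normalizes the explicit keyword arguments into a single dict and rejects anything unexpected. A simplified, hypothetical stand-in to illustrate the idea (`normalize_inputs` is not the real helper):

```python
from typing import Any, Dict


def normalize_inputs(kwargs_call: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical stand-in for transformers' internal input_processing:
    gather the call's keyword arguments into one dict and fail fast on
    leftovers, so the inner layer always sees a uniform interface."""
    if kwargs_call:
        raise ValueError(f"Unexpected arguments: {sorted(kwargs_call)}")
    return kwargs


inputs = normalize_inputs(kwargs_call={}, input_ids=[[101, 102]], attention_mask=None)
print(inputs["input_ids"])  # [[101, 102]]
```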
@@ -465,8 +557,7 @@ TF_DPR_READER_INPUTS_DOCSTRING = r"""
 class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
     def __init__(self, config: DPRConfig, *args, **kwargs):
         super().__init__(config, *args, **kwargs)
-        self.config = config
-        self.ctx_encoder = TFDPREncoder(config, name="ctx_encoder")
+        self.ctx_encoder = TFDPREncoderLayer(config, name="ctx_encoder")

     def get_input_embeddings(self):
         return self.ctx_encoder.bert_model.get_input_embeddings()
@@ -541,6 +632,7 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
         if not inputs["return_dict"]:
             return outputs[1:]
+
         return TFDPRContextEncoderOutput(
             pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
         )

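
The `return_dict` branch above chooses between a plain tuple (`outputs[1:]`) and a typed output object with named fields. A minimal sketch of the two return styles, with a simplified stand-in for `TFDPRContextEncoderOutput`:

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class ContextEncoderOutput:
    """Simplified stand-in for TFDPRContextEncoderOutput."""

    pooler_output: object
    hidden_states: Optional[Tuple] = None
    attentions: Optional[Tuple] = None


def encode(return_dict: bool):
    pooled = "pooled"  # placeholder for the pooled tensor
    if not return_dict:
        return (pooled,)  # tuple style: positional access
    return ContextEncoderOutput(pooler_output=pooled)  # attribute access


print(encode(False)[0])            # pooled
print(encode(True).pooler_output)  # pooled
```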
@@ -553,8 +645,7 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
 class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
     def __init__(self, config: DPRConfig, *args, **kwargs):
         super().__init__(config, *args, **kwargs)
-        self.config = config
-        self.question_encoder = TFDPREncoder(config, name="question_encoder")
+        self.question_encoder = TFDPREncoderLayer(config, name="question_encoder")

     def get_input_embeddings(self):
         return self.question_encoder.bert_model.get_input_embeddings()
@@ -641,8 +732,7 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
 class TFDPRReader(TFDPRPretrainedReader):
     def __init__(self, config: DPRConfig, *args, **kwargs):
         super().__init__(config, *args, **kwargs)
-        self.config = config
-        self.span_predictor = TFDPRSpanPredictor(config, name="span_predictor")
+        self.span_predictor = TFDPRSpanPredictorLayer(config, name="span_predictor")

     def get_input_embeddings(self):
         return self.span_predictor.encoder.bert_model.get_input_embeddings()
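
Because the public class names (`TFDPRContextEncoder`, `TFDPRQuestionEncoder`, `TFDPRReader`) are unchanged, existing user code should load as before. A typical usage sketch, assuming the standard DPR reader checkpoint on the Hub and mirroring the documented `DPRReader` example (`from_pt=True` converts the PyTorch weights):

```python
from transformers import DPRReaderTokenizer, TFDPRReader

tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
model = TFDPRReader.from_pretrained("facebook/dpr-reader-single-nq-base", from_pt=True)

encoded = tokenizer(
    questions=["What is love?"],
    titles=["Haddaway"],
    texts=["'What Is Love' is a song recorded by the artist Haddaway"],
    return_tensors="tf",
)
outputs = model(encoded)
print(outputs.start_logits.shape)  # (n_passages, sequence_length)
```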