Unverified commit abd503d9, authored by Rahul, committed by GitHub

TF - Adding Unpack Decorator For DPR model (#16212)

* Adding Unpack Decorator

* Adding Unpack Decorator - moved it to the top

parent d9b8d1a9
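For context on the change itself: the `@unpack_inputs` decorator replaces the per-method `input_processing(...)` boilerplate, so each `call` body can use `input_ids`, `return_dict`, etc. directly instead of reading them from an `inputs` dict. The snippet below is only a minimal, hypothetical sketch of the idea; the helper name `unpack_inputs_sketch` and the `ToyLayer` class are invented for illustration, and the real decorator in `transformers.modeling_tf_utils` does more work (for example, resolving config-driven defaults).

```python
import functools


def unpack_inputs_sketch(func):
    """Illustrative only: fold a dict passed as the first positional argument
    into plain keyword arguments before the wrapped ``call`` runs."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Keras-style calls such as layer({"input_ids": ..., "attention_mask": ...})
        # arrive with a single dict; unpack it so `call` sees named arguments.
        if args and isinstance(args[0], dict):
            kwargs = {**args[0], **kwargs}
            args = args[1:]
        return func(self, *args, **kwargs)

    return wrapper


class ToyLayer:
    @unpack_inputs_sketch
    def call(self, input_ids=None, attention_mask=None, return_dict=True):
        # The body can reference the arguments directly, mirroring the refactor below.
        return input_ids, attention_mask, return_dict


layer = ToyLayer()
# Both call styles reach `call` with the same keyword arguments.
assert layer.call({"input_ids": [1, 2], "attention_mask": [1, 1]}) == layer.call(
    input_ids=[1, 2], attention_mask=[1, 1]
)
```

Because both call styles end up identical inside `call`, the `inputs["..."]` lookups in the diff below can collapse into direct argument references.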
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ TensorFlow DPR model for Open Domain Question Answering."""

 from dataclasses import dataclass
@@ -26,7 +27,7 @@ from ...file_utils import (
     replace_return_docstrings,
 )
 from ...modeling_tf_outputs import TFBaseModelOutputWithPooling
-from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, input_processing, shape_list
+from ...modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list, unpack_inputs
 from ...utils import logging
 from ..bert.modeling_tf_bert import TFBertMainLayer
 from .configuration_dpr import DPRConfig
@@ -162,6 +163,7 @@ class TFDPREncoderLayer(tf.keras.layers.Layer):
             config.projection_dim, kernel_initializer=get_initializer(config.initializer_range), name="encode_proj"
         )

+    @unpack_inputs
     def call(
         self,
         input_ids: tf.Tensor = None,
@@ -174,9 +176,7 @@ class TFDPREncoderLayer(tf.keras.layers.Layer):
         training: bool = False,
         **kwargs,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor, ...]]:
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.bert_model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -185,17 +185,6 @@ class TFDPREncoderLayer(tf.keras.layers.Layer):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.bert_model(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         sequence_output = outputs[0]
@@ -203,7 +192,7 @@ class TFDPREncoderLayer(tf.keras.layers.Layer):
         if self.projection_dim > 0:
             pooled_output = self.encode_proj(pooled_output)

-        if not inputs["return_dict"]:
+        if not return_dict:
             return (sequence_output, pooled_output) + outputs[1:]

         return TFBaseModelOutputWithPooling(
@@ -236,6 +225,7 @@ class TFDPRSpanPredictorLayer(tf.keras.layers.Layer):
             1, kernel_initializer=get_initializer(config.initializer_range), name="qa_classifier"
         )

+    @unpack_inputs
     def call(
         self,
         input_ids: tf.Tensor = None,
@@ -250,10 +240,7 @@ class TFDPRSpanPredictorLayer(tf.keras.layers.Layer):
         # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length
         n_passages, sequence_length = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)[:2]
         # feed encoder
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.encoder(
             input_ids=input_ids,
             attention_mask=attention_mask,
             inputs_embeds=inputs_embeds,
@@ -261,16 +248,6 @@ class TFDPRSpanPredictorLayer(tf.keras.layers.Layer):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.encoder(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         sequence_output = outputs[0]
@@ -286,7 +263,7 @@ class TFDPRSpanPredictorLayer(tf.keras.layers.Layer):
         end_logits = tf.reshape(end_logits, [n_passages, sequence_length])
         relevance_logits = tf.reshape(relevance_logits, [n_passages])

-        if not inputs["return_dict"]:
+        if not return_dict:
             return (start_logits, end_logits, relevance_logits) + outputs[2:]

         return TFDPRReaderOutput(
@@ -306,6 +283,7 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
         super().__init__(config, **kwargs)
         self.encoder = TFDPRSpanPredictorLayer(config)

+    @unpack_inputs
     def call(
         self,
         input_ids: tf.Tensor = None,
@@ -318,27 +296,14 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
         training: bool = False,
         **kwargs,
     ) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]:
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.encoder(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.encoder(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         return outputs
@@ -352,6 +317,7 @@ class TFDPREncoder(TFPreTrainedModel):
         self.encoder = TFDPREncoderLayer(config)

+    @unpack_inputs
     def call(
         self,
         input_ids: tf.Tensor = None,
@@ -364,27 +330,14 @@ class TFDPREncoder(TFPreTrainedModel):
         training: bool = False,
         **kwargs,
     ) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]:
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        outputs = self.encoder(
             input_ids=input_ids,
             attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.encoder(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         return outputs
@@ -594,6 +547,7 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
             self(self.dummy_inputs)
             return self.ctx_encoder.bert_model.get_input_embeddings()

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFDPRContextEncoderOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -622,50 +576,36 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
         >>> embeddings = model(input_ids).pooler_output
         ```
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if inputs["attention_mask"] is None:
-            inputs["attention_mask"] = (
+        if attention_mask is None:
+            attention_mask = (
                 tf.ones(input_shape, dtype=tf.dtypes.int32)
-                if inputs["input_ids"] is None
-                else (inputs["input_ids"] != self.config.pad_token_id)
+                if input_ids is None
+                else (input_ids != self.config.pad_token_id)
             )

-        if inputs["token_type_ids"] is None:
-            inputs["token_type_ids"] = tf.zeros(input_shape, dtype=tf.dtypes.int32)
+        if token_type_ids is None:
+            token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)

         outputs = self.ctx_encoder(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )

-        if not inputs["return_dict"]:
+        if not return_dict:
             return outputs[1:]

         return TFDPRContextEncoderOutput(
@@ -695,6 +635,7 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
             self(self.dummy_inputs)
             return self.question_encoder.bert_model.get_input_embeddings()

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFDPRQuestionEncoderOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -723,50 +664,36 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
         >>> embeddings = model(input_ids).pooler_output
         ```
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if inputs["attention_mask"] is None:
-            inputs["attention_mask"] = (
+        if attention_mask is None:
+            attention_mask = (
                 tf.ones(input_shape, dtype=tf.dtypes.int32)
-                if inputs["input_ids"] is None
-                else (inputs["input_ids"] != self.config.pad_token_id)
+                if input_ids is None
+                else (input_ids != self.config.pad_token_id)
             )

-        if inputs["token_type_ids"] is None:
-            inputs["token_type_ids"] = tf.zeros(input_shape, dtype=tf.dtypes.int32)
+        if token_type_ids is None:
+            token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)

         outputs = self.question_encoder(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
        )

-        if not inputs["return_dict"]:
+        if not return_dict:
             return outputs[1:]

         return TFDPRQuestionEncoderOutput(
             pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
@@ -795,6 +722,7 @@ class TFDPRReader(TFDPRPretrainedReader):
             self(self.dummy_inputs)
             return self.span_predictor.encoder.bert_model.get_input_embeddings()

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(TF_DPR_READER_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFDPRReaderOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -830,9 +758,19 @@ class TFDPRReader(TFDPRPretrainedReader):
         >>> relevance_logits = outputs.relevance_logits
         ```
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if attention_mask is None:
+            attention_mask = tf.ones(input_shape, dtype=tf.dtypes.int32)
+
+        return self.span_predictor(
             input_ids=input_ids,
             attention_mask=attention_mask,
             inputs_embeds=inputs_embeds,
@@ -840,29 +778,6 @@
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
-        else:
-            raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-        if inputs["attention_mask"] is None:
-            inputs["attention_mask"] = tf.ones(input_shape, dtype=tf.dtypes.int32)
-
-        return self.span_predictor(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

     def serving_output(self, output):
...