Unverified commit df32b5d8, authored by Johnny Greco, committed by GitHub

TFLongformer: Add missing type hints and unpack inputs decorator (#16228)



* Add type annotations for TF Longformer

* Update docstring data types to include numpy array

* Implement unpack_inputs decorator

* fixup after decorator updates

* Numpy array -> np.ndarray in docstring
Co-authored-by: Johnny Greco <johnny.greco@radpartners.com>
parent 0aac9ba2
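The diff below replaces the hand-rolled `input_processing(...)` call at the top of every `call` method with the `@unpack_inputs` decorator from `modeling_tf_utils`, so each method body can use its named arguments directly and carry explicit type hints. As a rough illustration of the idea only (a toy, self-contained stand-in, not the real transformers implementation), a decorator of this shape binds the incoming arguments to the signature, fills in defaults, and forwards everything as keywords:

import functools
import inspect
from typing import Optional

import numpy as np


def unpack_inputs_sketch(call_fn):
    """Toy stand-in for `unpack_inputs`: bind arguments to the signature and pass them on as keywords."""

    @functools.wraps(call_fn)
    def wrapper(self, *args, **kwargs):
        bound = inspect.signature(call_fn).bind(self, *args, **kwargs)
        bound.apply_defaults()
        return call_fn(**bound.arguments)

    return wrapper


class ToyLayer:
    @unpack_inputs_sketch
    def call(
        self,
        input_ids: Optional[np.ndarray] = None,
        attention_mask: Optional[np.ndarray] = None,
        training: Optional[bool] = False,
    ) -> np.ndarray:
        # The body reads plain named arguments; no `inputs = input_processing(...)` dict lookups.
        if attention_mask is None:
            attention_mask = np.ones_like(input_ids)
        return input_ids * attention_mask


print(ToyLayer().call(input_ids=np.array([[1, 2, 3]])))  # -> [[1 2 3]]

The real decorator additionally handles dict/tuple-style inputs and config defaults; the point here is only the shape of the refactor applied throughout the hunks that follow.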
@@ -16,8 +16,9 @@
 import warnings
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

+import numpy as np
 import tensorflow as tf

 from ...activations_tf import get_tf_activation
@@ -30,14 +31,15 @@ from ...file_utils import (
 )
 from ...modeling_tf_utils import (
     TFMaskedLanguageModelingLoss,
+    TFModelInputType,
     TFMultipleChoiceLoss,
     TFPreTrainedModel,
     TFQuestionAnsweringLoss,
     TFSequenceClassificationLoss,
     TFTokenClassificationLoss,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -1660,6 +1662,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError

+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -1675,63 +1678,45 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            global_attention_mask=global_attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if inputs["attention_mask"] is None:
-            inputs["attention_mask"] = tf.fill(input_shape, 1)
+        if attention_mask is None:
+            attention_mask = tf.fill(input_shape, 1)

-        if inputs["token_type_ids"] is None:
-            inputs["token_type_ids"] = tf.fill(input_shape, 0)
+        if token_type_ids is None:
+            token_type_ids = tf.fill(input_shape, 0)

         # merge `global_attention_mask` and `attention_mask`
-        if inputs["global_attention_mask"] is not None:
-            inputs["attention_mask"] = self._merge_to_attention_mask(
-                inputs["attention_mask"], inputs["global_attention_mask"]
-            )
+        if global_attention_mask is not None:
+            attention_mask = self._merge_to_attention_mask(attention_mask, global_attention_mask)

         (
             padding_len,
-            inputs["input_ids"],
-            inputs["attention_mask"],
-            inputs["token_type_ids"],
-            inputs["position_ids"],
-            inputs["inputs_embeds"],
+            input_ids,
+            attention_mask,
+            token_type_ids,
+            position_ids,
+            inputs_embeds,
         ) = self._pad_to_window_size(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            inputs_embeds=inputs["inputs_embeds"],
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
             pad_token_id=self.pad_token_id,
         )

         # is index masked or global attention
-        is_index_masked = tf.math.less(inputs["attention_mask"], 1)
-        is_index_global_attn = tf.math.greater(inputs["attention_mask"], 1)
+        is_index_masked = tf.math.less(attention_mask, 1)
+        is_index_global_attn = tf.math.greater(attention_mask, 1)
         is_global_attn = tf.math.reduce_any(is_index_global_attn)

         # We create a 3D attention mask from a 2D tensor mask.
@@ -1739,10 +1724,8 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # this attention mask is more simple than the triangular masking of causal attention
         # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-        attention_mask_shape = shape_list(inputs["attention_mask"])
-        extended_attention_mask = tf.reshape(
-            inputs["attention_mask"], (attention_mask_shape[0], attention_mask_shape[1], 1, 1)
-        )
+        attention_mask_shape = shape_list(attention_mask)
+        extended_attention_mask = tf.reshape(attention_mask, (attention_mask_shape[0], attention_mask_shape[1], 1, 1))

         # Since attention_mask is 1.0 for positions we want to attend locally and 0.0 for
         # masked and global attn positions, this operation will create a tensor which is 0.0 for
@@ -1751,11 +1734,11 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
         # effectively the same as removing these entirely.
         extended_attention_mask = tf.cast(tf.math.abs(1 - extended_attention_mask), tf.dtypes.float32) * -10000.0

         embedding_output = self.embeddings(
-            inputs["input_ids"],
-            inputs["position_ids"],
-            inputs["token_type_ids"],
-            inputs["inputs_embeds"],
-            training=inputs["training"],
+            input_ids,
+            position_ids,
+            token_type_ids,
+            inputs_embeds,
+            training=training,
         )
         encoder_outputs = self.encoder(
             embedding_output,
@@ -1765,15 +1748,15 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
             is_index_masked=is_index_masked,
             is_index_global_attn=is_index_global_attn,
             is_global_attn=is_global_attn,
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

-        if not inputs["return_dict"]:
+        if not return_dict:
             return (
                 sequence_output,
                 pooled_output,
@@ -1934,27 +1917,27 @@ LONGFORMER_START_DOCSTRING = r"""
 LONGFORMER_INPUTS_DOCSTRING = r"""
     Args:
-        input_ids (`tf.Tensor` of shape `({0})`):
+        input_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`):
             Indices of input sequence tokens in the vocabulary.

             Indices can be obtained using [`LongformerTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
             [`PreTrainedTokenizer.encode`] for details.

             [What are input IDs?](../glossary#input-ids)
-        attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
+        attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.

             [What are attention masks?](../glossary#attention-mask)
-        head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+        head_mask (`np.ndarray` or `tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
             Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

             - 1 indicates the head is **not masked**,
             - 0 indicates the head is **masked**.

-        global_attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
+        global_attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
             Mask to decide the attention given on each token, local attention or global attention. Tokens with global
             attention attends to all other tokens, and all other tokens attend to them. This is important for
             task-specific finetuning because it makes the model more flexible at representing the task. For example,
@@ -1965,7 +1948,7 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
             - 0 for local attention (a sliding window attention),
             - 1 for global attention (tokens that attend to all other tokens, and all other tokens attend to them).

-        token_type_ids (`tf.Tensor` of shape `({0})`, *optional*):
+        token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
             Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
             1]`:
@@ -1973,12 +1956,12 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
             - 1 corresponds to a *sentence B* token.

             [What are token type IDs?](../glossary#token-type-ids)
-        position_ids (`tf.Tensor` of shape `({0})`, *optional*):
+        position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
             config.max_position_embeddings - 1]`.

             [What are position IDs?](../glossary#position-ids)
-        inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
+        inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
             model's internal embedding lookup matrix.
@@ -2025,25 +2008,25 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
         self.longformer = TFLongformerMainLayer(config, name="longformer")

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        global_attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        global_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+    ) -> Union[TFLongformerBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
+
+        outputs = self.longformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             head_mask=head_mask,
@@ -2055,20 +2038,6 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.longformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            head_mask=inputs["head_mask"],
-            global_attention_mask=inputs["global_attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )

         return outputs
@@ -2108,6 +2077,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
         warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
         return self.name + "/" + self.lm_head.name

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -2118,29 +2088,28 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
     )
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        global_attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        labels=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        global_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
+    ) -> Union[TFLongformerMaskedLMOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
             config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
             loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+
+        outputs = self.longformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             head_mask=head_mask,
@@ -2151,28 +2120,13 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.longformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            head_mask=inputs["head_mask"],
-            global_attention_mask=inputs["global_attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = outputs[0]
-        prediction_scores = self.lm_head(sequence_output, training=inputs["training"])
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], prediction_scores)
+        prediction_scores = self.lm_head(sequence_output, training=training)
+        loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (prediction_scores,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
@@ -2217,6 +2171,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
             name="qa_outputs",
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -2226,21 +2181,21 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
     )
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        global_attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        start_positions=None,
-        end_positions=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        global_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        start_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        end_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
+    ) -> Union[TFLongformerQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
         r"""
         start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
             Labels for position (index) of the start of the labelled span for computing the token classification loss.
@@ -2251,9 +2206,22 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
             Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
             are not taken into account for computing the loss.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+
+        # set global attention on question tokens
+        if global_attention_mask is None and input_ids is not None:
+            if shape_list(tf.where(input_ids == self.config.sep_token_id))[0] != 3 * shape_list(input_ids)[0]:
+                logger.warning(
+                    f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error. The global attention is disabled for this forward pass."
+                )
+                global_attention_mask = tf.fill(shape_list(input_ids), value=0)
+            else:
+                logger.info("Initializing global attention on question tokens...")
+                # put global attention on all tokens until `config.sep_token_id` is reached
+                sep_token_indices = tf.where(input_ids == self.config.sep_token_id)
+                sep_token_indices = tf.cast(sep_token_indices, dtype=input_ids.dtype)
+                global_attention_mask = _compute_global_attention_mask(shape_list(input_ids), sep_token_indices)
+
+        outputs = self.longformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             head_mask=head_mask,
@@ -2264,43 +2232,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            start_positions=start_positions,
-            end_positions=end_positions,
             training=training,
-            kwargs_call=kwargs,
-        )
-
-        # set global attention on question tokens
-        if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None:
-            if (
-                shape_list(tf.where(inputs["input_ids"] == self.config.sep_token_id))[0]
-                != 3 * shape_list(inputs["input_ids"])[0]
-            ):
-                logger.warning(
-                    f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this. This is most likely an error. The global attention is disabled for this forward pass."
-                )
-                inputs["global_attention_mask"] = tf.fill(shape_list(inputs["input_ids"]), value=0)
-            else:
-                logger.info("Initializing global attention on question tokens...")
-                # put global attention on all tokens until `config.sep_token_id` is reached
-                sep_token_indices = tf.where(inputs["input_ids"] == self.config.sep_token_id)
-                sep_token_indices = tf.cast(sep_token_indices, dtype=inputs["input_ids"].dtype)
-                inputs["global_attention_mask"] = _compute_global_attention_mask(
-                    shape_list(inputs["input_ids"]), sep_token_indices
-                )
-
-        outputs = self.longformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            head_mask=inputs["head_mask"],
-            global_attention_mask=inputs["global_attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = outputs[0]
         logits = self.qa_outputs(sequence_output)
@@ -2309,12 +2241,12 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
         end_logits = tf.squeeze(end_logits, axis=-1)
         loss = None

-        if inputs["start_positions"] is not None and inputs["end_positions"] is not None:
-            labels = {"start_position": inputs["start_positions"]}
-            labels["end_position"] = inputs["end_positions"]
+        if start_positions is not None and end_positions is not None:
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
             loss = self.hf_compute_loss(labels, (start_logits, end_logits))

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (start_logits, end_logits) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
@@ -2386,6 +2318,7 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss):
         self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer")
         self.classifier = TFLongformerClassificationHead(config, name="classifier")

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -2395,73 +2328,56 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss):
     )
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        global_attention_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        labels=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        global_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            global_attention_mask=global_attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            labels=labels,
-            training=training,
-            kwargs_call=kwargs,
-        )
+    ) -> Union[TFLongformerSequenceClassifierOutput, Tuple[tf.Tensor]]:

-        if inputs["global_attention_mask"] is None and inputs["input_ids"] is not None:
+        if global_attention_mask is None and input_ids is not None:
             logger.info("Initializing global attention on CLS token...")
             # global attention on cls token
-            inputs["global_attention_mask"] = tf.zeros_like(inputs["input_ids"])
-            updates = tf.ones(shape_list(inputs["input_ids"])[0], dtype=tf.int32)
+            global_attention_mask = tf.zeros_like(input_ids)
+            updates = tf.ones(shape_list(input_ids)[0], dtype=tf.int32)
             indices = tf.pad(
-                tensor=tf.expand_dims(tf.range(shape_list(inputs["input_ids"])[0]), axis=1),
+                tensor=tf.expand_dims(tf.range(shape_list(input_ids)[0]), axis=1),
                 paddings=[[0, 0], [0, 1]],
                 constant_values=0,
             )
-            inputs["global_attention_mask"] = tf.tensor_scatter_nd_update(
-                inputs["global_attention_mask"],
+            global_attention_mask = tf.tensor_scatter_nd_update(
+                global_attention_mask,
                 indices,
                 updates,
             )

         outputs = self.longformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            head_mask=inputs["head_mask"],
-            global_attention_mask=inputs["global_attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            global_attention_mask=global_attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )
         sequence_output = outputs[0]
         logits = self.classifier(sequence_output)
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
@@ -2510,6 +2426,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoiceLoss):
         global_attention_mask = tf.convert_to_tensor([[[0, 0, 0, 1], [0, 0, 0, 1]]] * 2)
         return {"input_ids": input_ids, "global_attention_mask": global_attention_mask}

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(
         LONGFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
     )
@@ -2521,68 +2438,45 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoiceLoss):
     )
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        global_attention_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        labels=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        global_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
+    ) -> Union[TFLongformerMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
             where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            global_attention_mask=global_attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            labels=labels,
-            training=training,
-            kwargs_call=kwargs,
-        )

-        if inputs["input_ids"] is not None:
-            num_choices = shape_list(inputs["input_ids"])[1]
-            seq_length = shape_list(inputs["input_ids"])[2]
+        if input_ids is not None:
+            num_choices = shape_list(input_ids)[1]
+            seq_length = shape_list(input_ids)[2]
         else:
-            num_choices = shape_list(inputs["inputs_embeds"])[1]
-            seq_length = shape_list(inputs["inputs_embeds"])[2]
+            num_choices = shape_list(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[2]

-        flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None
-        flat_attention_mask = (
-            tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None
-        )
-        flat_token_type_ids = (
-            tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None
-        )
-        flat_position_ids = (
-            tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None
-        )
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
         flat_global_attention_mask = (
-            tf.reshape(inputs["global_attention_mask"], (-1, shape_list(inputs["global_attention_mask"])[-1]))
-            if inputs["global_attention_mask"] is not None
+            tf.reshape(global_attention_mask, (-1, shape_list(global_attention_mask)[-1]))
+            if global_attention_mask is not None
             else None
         )
         flat_inputs_embeds = (
-            tf.reshape(inputs["inputs_embeds"], (-1, seq_length, shape_list(inputs["inputs_embeds"])[3]))
-            if inputs["inputs_embeds"] is not None
+            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+            if inputs_embeds is not None
             else None
         )
@@ -2596,8 +2490,8 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoiceLoss):
             inputs_embeds=flat_inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            return_dict=return_dict,
+            training=training,
         )
         pooled_output = outputs[1]
@@ -2605,9 +2499,9 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoiceLoss):
         logits = self.classifier(pooled_output)
         reshaped_logits = tf.reshape(logits, (-1, num_choices))
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], reshaped_logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (reshaped_logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
@@ -2664,6 +2558,7 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -2673,27 +2568,26 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):
     )
     def call(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        global_attention_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        labels=None,
-        training=False,
+        input_ids: Optional[TFModelInputType] = None,
+        attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        global_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        training: Optional[bool] = False,
         **kwargs,
-    ):
+    ) -> Union[TFLongformerTokenClassifierOutput, Tuple[tf.Tensor]]:
         r"""
         labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+
+        outputs = self.longformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             head_mask=head_mask,
@@ -2704,29 +2598,14 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.longformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            head_mask=inputs["head_mask"],
-            global_attention_mask=inputs["global_attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         sequence_output = outputs[0]
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
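As a closing note, the widened docstrings (`np.ndarray` or `tf.Tensor`) reflect how the TF models are commonly called. A minimal usage sketch under the standard transformers API follows; the `allenai/longformer-base-4096` checkpoint name is an assumption for illustration and is not part of this commit.

from transformers import LongformerTokenizer, TFLongformerModel

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")

# `return_tensors="np"` yields np.ndarray inputs, which the updated type hints and docstrings now cover.
encoded = tokenizer("Hello, Longformer!", return_tensors="np")
outputs = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])
print(outputs.last_hidden_state.shape)  # (batch size, sequence length, hidden size)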