Unverified Commit 5362bb8a authored by elk-cloner, committed by GitHub

TF Longformer for sequence classification (#8231)



* working on LongformerForSequenceClassification

* add TFLongformerForMultipleChoice

* add TFLongformerForTokenClassification

* use add_start_docstrings_to_model_forward

* test TFLongformerForSequenceClassification

* test TFLongformerForMultipleChoice

* test TFLongformerForTokenClassification

* remove test from repo

* add test and doc for TFLongformerForSequenceClassification, TFLongformerForTokenClassification, TFLongformerForMultipleChoice

* add requested classes to modeling_tf_auto.py
update dummy_tf_objects
fix tests
fix bugs in requested classes

* pass all tests except test_inputs_embeds

* sync with master

* pass all tests except test_inputs_embeds

* pass all tests

* pass all tests

* work on test_inputs_embeds

* fix style and quality

* make multi choice work

* fix TFLongformerForTokenClassification signature

* fix TFLongformerForMultipleChoice, TFLongformerForSequenceClassification signature

* fix mult choice

* fix mc hint

* fix input embeds

* fix input embeds

* refactor input embeds

* fix copy issue

* apply Sylvain's changes and clean up more
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 62cd9ce9
...@@ -99,21 +99,41 @@ Longformer specific outputs
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerBaseModelOutputWithPooling
:members:
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerMaskedLMOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerQuestionAnsweringModelOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerSequenceClassifierOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerMultipleChoiceModelOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerTokenClassifierOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerBaseModelOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerBaseModelOutputWithPooling
:members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerMaskedLMOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerQuestionAnsweringModelOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerSequenceClassifierOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerMultipleChoiceModelOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerTokenClassifierOutput
:members:
LongformerModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...@@ -177,3 +197,24 @@ TFLongformerForQuestionAnswering
.. autoclass:: transformers.TFLongformerForQuestionAnswering
:members: call
TFLongformerForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLongformerForSequenceClassification
:members: call
TFLongformerForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLongformerForTokenClassification
:members: call
TFLongformerForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLongformerForMultipleChoice
:members: call
...@@ -766,7 +766,10 @@ if is_tf_available():
from .models.longformer import (
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLongformerForMaskedLM,
TFLongformerForMultipleChoice,
TFLongformerForQuestionAnswering,
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
TFLongformerSelfAttention,
)
......
...@@ -92,7 +92,10 @@ from ..funnel.modeling_tf_funnel import (
from ..gpt2.modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
from ..longformer.modeling_tf_longformer import (
TFLongformerForMaskedLM,
TFLongformerForMultipleChoice,
TFLongformerForQuestionAnswering,
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
)
from ..lxmert.modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
...@@ -314,6 +317,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
(AlbertConfig, TFAlbertForSequenceClassification),
(CamembertConfig, TFCamembertForSequenceClassification),
(XLMRobertaConfig, TFXLMRobertaForSequenceClassification),
(LongformerConfig, TFLongformerForSequenceClassification),
(RobertaConfig, TFRobertaForSequenceClassification),
(BertConfig, TFBertForSequenceClassification),
(XLNetConfig, TFXLNetForSequenceClassification),
...@@ -353,6 +357,7 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
(FlaubertConfig, TFFlaubertForTokenClassification),
(XLMConfig, TFXLMForTokenClassification),
(XLMRobertaConfig, TFXLMRobertaForTokenClassification),
(LongformerConfig, TFLongformerForTokenClassification),
(RobertaConfig, TFRobertaForTokenClassification),
(BertConfig, TFBertForTokenClassification),
(MobileBertConfig, TFMobileBertForTokenClassification),
...@@ -368,6 +373,7 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
(CamembertConfig, TFCamembertForMultipleChoice),
(XLMConfig, TFXLMForMultipleChoice),
(XLMRobertaConfig, TFXLMRobertaForMultipleChoice),
(LongformerConfig, TFLongformerForMultipleChoice),
(RobertaConfig, TFRobertaForMultipleChoice),
(BertConfig, TFBertForMultipleChoice),
(DistilBertConfig, TFDistilBertForMultipleChoice),
......
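With the three mappings above in place, the TF auto classes can dispatch a Longformer checkpoint to the new heads. A minimal sketch of the intended effect (editor's illustration, not part of the diff; it assumes the public TFAutoModel* classes and the allenai/longformer-base-4096 checkpoint):

from transformers import TFAutoModelForSequenceClassification, TFAutoModelForTokenClassification

# The auto classes look up the checkpoint's config type (LongformerConfig) in the
# OrderedDicts patched above and instantiate the matching TF head class.
model = TFAutoModelForSequenceClassification.from_pretrained("allenai/longformer-base-4096")
print(type(model).__name__)  # TFLongformerForSequenceClassification

ner_model = TFAutoModelForTokenClassification.from_pretrained("allenai/longformer-base-4096")
print(type(ner_model).__name__)  # TFLongformerForTokenClassification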
...@@ -26,7 +26,10 @@ if is_tf_available():
from .modeling_tf_longformer import (
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLongformerForMaskedLM,
TFLongformerForMultipleChoice,
TFLongformerForQuestionAnswering,
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
TFLongformerSelfAttention,
)
...@@ -31,7 +31,6 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
...@@ -151,17 +150,15 @@ class LongformerBaseModelOutputWithPooling(ModelOutput):
@dataclass
class LongformerMaskedLMOutput(ModelOutput):
"""
Base class for masked language models outputs.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Masked language modeling (MLM) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
...@@ -249,6 +246,149 @@ class LongformerQuestionAnsweringModelOutput(ModelOutput):
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LongformerSequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sentence classification models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
mask.
Local attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token in the sequence to every token with
global attention (first ``x`` values) and to every token in the attention window (remaining
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
attention weights. If a token has global attention, the attention weights to all other tokens in
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
Global attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
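A small worked example of the indexing convention described in the docstring above (editor's illustration, not part of the diff; the window size is the model default of 512, assumed here):

# For a Longformer with attention_window = 512 and x = 1 globally-attending token
# (e.g. only the <s>/CLS token has global attention), each layer entry of `attentions`
# has shape (batch_size, num_heads, sequence_length, x + attention_window + 1) = (..., 514).
attention_window = 512
x = 1  # number of tokens with global attention
# indices [0:x]  -> weights towards the x global tokens (fixed positions)
# indices [x:]   -> weights inside the local window (relative positions)
self_attention_index = x + attention_window // 2  # a token's weight to itself sits at index 257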
@dataclass
class LongformerMultipleChoiceModelOutput(ModelOutput):
"""
Base class for outputs of multiple choice Longformer models.
Args:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
mask.
Local attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token in the sequence to every token with
global attention (first ``x`` values) and to every token in the attention window (remaining
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
attention weights. If a token has global attention, the attention weights to all other tokens in
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
Global attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LongformerTokenClassifierOutput(ModelOutput):
"""
Base class for outputs of token classification models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
mask.
Local attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token in the sequence to every token with
global attention (first ``x`` values) and to every token in the attention window (remaining
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
attention weights. If a token has global attention, the attention weights to all other tokens in
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
Global attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
def _get_question_end_index(input_ids, sep_token_id):
"""
Computes the index of the first occurrence of `sep_token_id`.
...@@ -1495,7 +1635,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
return self.lm_head.decoder
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=LongformerMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
...@@ -1561,7 +1701,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return LongformerMaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
...@@ -1593,7 +1733,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="allenai/longformer-base-4096",
output_type=LongformerSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
...@@ -1651,7 +1791,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return LongformerSequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
...@@ -1837,7 +1977,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="allenai/longformer-base-4096",
output_type=LongformerTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
...@@ -1895,7 +2035,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return LongformerTokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
......
...@@ -19,19 +19,21 @@ from typing import Optional, Tuple
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...file_utils import (
MULTIPLE_CHOICE_DUMMY_INPUTS,
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from ...modeling_tf_outputs import TFMaskedLMOutput, TFQuestionAnsweringModelOutput
from ...modeling_tf_utils import (
TFMaskedLanguageModelingLoss,
TFMultipleChoiceLoss,
TFPreTrainedModel,
TFQuestionAnsweringLoss,
TFSequenceClassificationLoss,
TFTokenClassificationLoss,
get_initializer,
keras_serializable,
shape_list,
...@@ -147,6 +149,52 @@ class TFLongformerBaseModelOutputWithPooling(ModelOutput):
global_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFLongformerMaskedLMOutput(ModelOutput):
"""
Base class for masked language models outputs.
Args:
loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Masked language modeling (MLM) loss.
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, x +
attention_window + 1)`, where ``x`` is the number of tokens with global attention mask.
Local attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token in the sequence to every token with
global attention (first ``x`` values) and to every token in the attention window (remaining
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
attention weights. If a token has global attention, the attention weights to all other tokens in
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
global_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, x)`,
where ``x`` is the number of tokens with global attention mask.
Global attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
global_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFLongformerQuestionAnsweringModelOutput(ModelOutput):
"""
...@@ -196,6 +244,146 @@ class TFLongformerQuestionAnsweringModelOutput(ModelOutput):
global_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFLongformerSequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sentence classification models.
Args:
loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, x +
attention_window + 1)`, where ``x`` is the number of tokens with global attention mask.
Local attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token in the sequence to every token with
global attention (first ``x`` values) and to every token in the attention window (remaining
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
attention weights. If a token has global attention, the attention weights to all other tokens in
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
global_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, x)`,
where ``x`` is the number of tokens with global attention mask.
Global attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
global_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFLongformerMultipleChoiceModelOutput(ModelOutput):
"""
Base class for outputs of multiple choice models.
Args:
loss (:obj:`tf.Tensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, x +
attention_window + 1)`, where ``x`` is the number of tokens with global attention mask.
Local attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token in the sequence to every token with
global attention (first ``x`` values) and to every token in the attention window (remaining
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
attention weights. If a token has global attention, the attention weights to all other tokens in
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
global_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, x)`,
where ``x`` is the number of tokens with global attention mask.
Global attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
global_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFLongformerTokenClassifierOutput(ModelOutput):
"""
Base class for outputs of token classification models.
Args:
loss (:obj:`tf.Tensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
Classification loss.
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, x +
attention_window + 1)`, where ``x`` is the number of tokens with global attention mask.
Local attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token in the sequence to every token with
global attention (first ``x`` values) and to every token in the attention window (remaining
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
attention weights. If a token has global attention, the attention weights to all other tokens in
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
global_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length, x)`,
where ``x`` is the number of tokens with global attention mask.
Global attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
global_attentions: Optional[Tuple[tf.Tensor]] = None
def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_sep_token=True):
"""
Computes global attention mask by putting attention on all tokens before `sep_token_id` if `before_sep_token is
...@@ -249,18 +437,17 @@ class TFLongformerLMHead(tf.keras.layers.Layer):
super().build(input_shape)
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.layer_norm(hidden_states)
# project back to size of vocabulary with bias
hidden_states = self.decoder(hidden_states, mode="linear") + self.bias
return hidden_states
# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaEmbeddings
class TFLongformerEmbeddings(tf.keras.layers.Layer):
"""
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
...@@ -304,17 +491,23 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
super().build(input_shape)
def create_position_ids_from_input_ids(self, input_ids):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
symbols are ignored. This is modified from fairseq's `utils.make_positions`.
Args:
input_ids: tf.Tensor
Returns: tf.Tensor
"""
input_ids_shape = shape_list(input_ids)
# multiple choice has 3 dimensions
if len(input_ids_shape) == 3:
input_ids = tf.reshape(input_ids, (input_ids_shape[0] * input_ids_shape[1], input_ids_shape[2]))
mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=tf.int32)
incremental_indices = tf.math.cumsum(mask, axis=1) * mask
return incremental_indices + self.padding_idx
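A quick numerical illustration of what `create_position_ids_from_input_ids` computes (editor's sketch, not part of the diff; it assumes the RoBERTa-style `padding_idx` of 1 used by Longformer embeddings):

import tensorflow as tf

padding_idx = 1  # assumed RoBERTa-style padding index
input_ids = tf.constant([[0, 31414, 232, 2, 1, 1]])  # last two tokens are padding

mask = tf.cast(tf.math.not_equal(input_ids, padding_idx), dtype=tf.int32)
# mask                -> [[1, 1, 1, 1, 0, 0]]
incremental_indices = tf.math.cumsum(mask, axis=1) * mask
# incremental_indices -> [[1, 2, 3, 4, 0, 0]]
position_ids = incremental_indices + padding_idx
# position_ids        -> [[2, 3, 4, 5, 1, 1]]  (padding tokens keep position padding_idx)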
...@@ -1783,7 +1976,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="allenai/longformer-base-4096",
output_type=TFLongformerMaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
...@@ -1837,11 +2030,12 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
return ((loss,) + output) if loss is not None else output
return TFLongformerMaskedLMOutput(
loss=loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
global_attentions=outputs.global_attentions,
)
...@@ -1871,7 +2065,7 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="allenai/longformer-large-4096-finetuned-triviaqa",
output_type=TFLongformerQuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
...@@ -1969,3 +2163,357 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
attentions=outputs.attentions,
global_attentions=outputs.global_attentions,
)
class TFLongformerClassificationHead(tf.keras.layers.Layer):
"""Head for sentence-level classification tasks."""
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
name="dense",
)
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.out_proj = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
)
def call(self, hidden_states, training=False):
hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
output = self.out_proj(hidden_states)
return output
@add_start_docstrings(
"""
Longformer Model transformer with a sequence classification/regression head on top (a linear layer on top of the
pooled output) e.g. for GLUE tasks.
""",
LONGFORMER_START_DOCSTRING,
)
class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss):
authorized_missing_keys = [r"pooler"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.longformer = TFLongformerMainLayer(config, name="longformer")
self.classifier = TFLongformerClassificationHead(config, name="classifier")
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="allenai/longformer-base-4096",
output_type=TFLongformerSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
global_attention_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
training=False,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
global_attention_mask = inputs[4] if len(inputs) > 4 else global_attention_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
labels = inputs[9] if len(inputs) > 9 else labels
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
global_attention_mask = inputs.get("global_attention_mask", global_attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
labels = inputs.get("labels", labels)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
if global_attention_mask is None and input_ids is not None:
logger.info("Initializing global attention on CLS token...")
# global attention on cls token
global_attention_mask = tf.zeros_like(input_ids)
global_attention_mask = tf.tensor_scatter_nd_update(
global_attention_mask,
[[i, 0] for i in range(input_ids.shape[0])],
[1 for _ in range(input_ids.shape[0])],
)
outputs = self.longformer(
input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFLongformerSequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
global_attentions=outputs.global_attentions,
)
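A minimal end-to-end sketch of the class above (editor's example, not part of the diff; text and label are arbitrary). As in the `call` body, no `global_attention_mask` is passed, so global attention is put on the first (`<s>`/CLS) token only:

import tensorflow as tf
from transformers import LongformerTokenizer, TFLongformerForSequenceClassification

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerForSequenceClassification.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("A very long document ...", return_tensors="tf")
# The classification head is newly initialized from the base checkpoint.
outputs = model(inputs, labels=tf.constant([1]))
print(outputs.logits.shape)  # (1, config.num_labels); outputs.loss holds the classification loss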
@add_start_docstrings(
"""
Longformer Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
a softmax) e.g. for RocStories/SWAG tasks.
""",
LONGFORMER_START_DOCSTRING,
)
class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoiceLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.longformer = TFLongformerMainLayer(config, name="longformer")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(
1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@property
def dummy_inputs(self):
input_ids = tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)
# make sure global layers are initialized
global_attention_mask = tf.constant([[[0, 0, 0, 1], [0, 0, 0, 1]]] * 2)
return {"input_ids": input_ids, "global_attention_mask": global_attention_mask}
@add_start_docstrings_to_model_forward(
LONGFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="allenai/longformer-base-4096",
output_type=TFLongformerMultipleChoiceModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
inputs,
attention_mask=None,
token_type_ids=None,
position_ids=None,
global_attention_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
num_choices]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
:obj:`input_ids` above)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
global_attention_mask = inputs[4] if len(inputs) > 4 else global_attention_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
labels = inputs[9] if len(inputs) > 9 else labels
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
global_attention_mask = inputs.get("global_attention_mask", global_attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
labels = inputs.get("labels", labels)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
return_dict = return_dict if return_dict is not None else self.config.return_dict
if input_ids is not None:
num_choices = shape_list(input_ids)[1]
seq_length = shape_list(input_ids)[2]
else:
num_choices = shape_list(inputs_embeds)[1]
seq_length = shape_list(inputs_embeds)[2]
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
flat_global_attention_mask = (
tf.reshape(global_attention_mask, (-1, global_attention_mask.shape[-1]))
if global_attention_mask is not None
else None
)
flat_inputs_embeds = (
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
if inputs_embeds is not None
else None
)
outputs = self.longformer(
flat_input_ids,
position_ids=flat_position_ids,
token_type_ids=flat_token_type_ids,
attention_mask=flat_attention_mask,
global_attention_mask=flat_global_attention_mask,
inputs_embeds=flat_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = tf.reshape(logits, (-1, num_choices))
loss = None if labels is None else self.compute_loss(labels, reshaped_logits)
if not return_dict:
output = (reshaped_logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFLongformerMultipleChoiceModelOutput(
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
global_attentions=outputs.global_attentions,
)
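The flattening above expects inputs of shape (batch_size, num_choices, sequence_length); a rough sketch of how such a batch might be built (editor's example, not part of the diff):

import tensorflow as tf
from transformers import LongformerTokenizer, TFLongformerForMultipleChoice

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerForMultipleChoice.from_pretrained("allenai/longformer-base-4096")

prompt = "The sky is"
choices = ["blue.", "a fruit."]
# Encode one (prompt, choice) pair per candidate, then add a batch dimension so the
# tensors end up with shape (batch_size=1, num_choices=2, sequence_length).
encoding = tokenizer([prompt, prompt], choices, padding=True, return_tensors="tf")
inputs = {k: tf.expand_dims(v, 0) for k, v in encoding.items()}

outputs = model(inputs)
print(outputs.logits.shape)  # (1, 2): one logit per choice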
@add_start_docstrings(
"""
Longformer Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
for Named-Entity-Recognition (NER) tasks.
""",
LONGFORMER_START_DOCSTRING,
)
class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):
authorized_missing_keys = [r"pooler"]
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.longformer = TFLongformerMainLayer(config=config, name="longformer")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="allenai/longformer-base-4096",
output_type=TFLongformerTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
inputs=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
global_attention_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
training=False,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.config.return_dict
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.longformer(
inputs,
attention_mask=attention_mask,
global_attention_mask=global_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFLongformerTokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
global_attentions=outputs.global_attentions,
)
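And for the token-classification head, each token gets one logit vector; a brief sketch (editor's example, not part of the diff; the label count is arbitrary and the head is randomly initialized from the base checkpoint):

import tensorflow as tf
from transformers import LongformerTokenizer, TFLongformerForTokenClassification

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerForTokenClassification.from_pretrained("allenai/longformer-base-4096", num_labels=9)

inputs = tokenizer("Hugging Face is based in New York City", return_tensors="tf")
outputs = model(inputs)
predicted_label_ids = tf.argmax(outputs.logits, axis=-1)
print(outputs.logits.shape)  # (1, sequence_length, 9)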
...@@ -751,15 +751,15 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
super().build(input_shape)
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.layer_norm(hidden_states)
# project back to size of vocabulary with bias
hidden_states = self.decoder(hidden_states, mode="linear") + self.bias
return hidden_states
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
......
@@ -812,6 +812,15 @@ class TFLongformerForMaskedLM:
         requires_tf(self)


+class TFLongformerForMultipleChoice:
+    def __init__(self, *args, **kwargs):
+        requires_tf(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_tf(self)
+
+
 class TFLongformerForQuestionAnswering:
     def __init__(self, *args, **kwargs):
         requires_tf(self)
@@ -821,6 +830,24 @@ class TFLongformerForQuestionAnswering:
         requires_tf(self)


+class TFLongformerForSequenceClassification:
+    def __init__(self, *args, **kwargs):
+        requires_tf(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_tf(self)
+
+
+class TFLongformerForTokenClassification:
+    def __init__(self, *args, **kwargs):
+        requires_tf(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_tf(self)
+
+
 class TFLongformerModel:
     def __init__(self, *args, **kwargs):
         requires_tf(self)
...
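The dummy classes above keep `from transformers import TFLongformerForSequenceClassification` (and friends) importable in an environment without TensorFlow; only actually using the classes fails. A minimal sketch of the expected behaviour, assuming TensorFlow is not installed (the exact error wording produced by requires_tf is an assumption):

from transformers import TFLongformerForSequenceClassification

try:
    model = TFLongformerForSequenceClassification()  # dummy __init__ calls requires_tf(self)
except ImportError as err:
    print(err)  # points the user at installing TensorFlow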
@@ -129,7 +129,7 @@ class LongformerModelTester:
         output_without_mask = model(input_ids)["last_hidden_state"]
         self.parent.assertTrue(torch.allclose(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], atol=1e-4))

-    def create_and_check_longformer_model(
+    def create_and_check_model(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         model = LongformerModel(config=config)
@@ -141,7 +141,7 @@ class LongformerModelTester:
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
         self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

-    def create_and_check_longformer_model_with_global_attention_mask(
+    def create_and_check_model_with_global_attention_mask(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         model = LongformerModel(config=config)
@@ -163,7 +163,7 @@ class LongformerModelTester:
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
         self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

-    def create_and_check_longformer_for_masked_lm(
+    def create_and_check_for_masked_lm(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         model = LongformerForMaskedLM(config=config)
@@ -172,7 +172,7 @@ class LongformerModelTester:
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

-    def create_and_check_longformer_for_question_answering(
+    def create_and_check_for_question_answering(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         model = LongformerForQuestionAnswering(config=config)
@@ -189,7 +189,7 @@ class LongformerModelTester:
         self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
         self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

-    def create_and_check_longformer_for_sequence_classification(
+    def create_and_check_for_sequence_classification(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         config.num_labels = self.num_labels
@@ -199,7 +199,7 @@ class LongformerModelTester:
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

-    def create_and_check_longformer_for_token_classification(
+    def create_and_check_for_token_classification(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         config.num_labels = self.num_labels
@@ -209,7 +209,7 @@ class LongformerModelTester:
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

-    def create_and_check_longformer_for_multiple_choice(
+    def create_and_check_for_multiple_choice(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         config.num_choices = self.num_choices
@@ -296,37 +296,37 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
     def test_config(self):
         self.config_tester.run_common_tests()

-    def test_longformer_model(self):
+    def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_model(*config_and_inputs)
+        self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_longformer_model_attention_mask_determinism(self):
+    def test_model_attention_mask_determinism(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)

-    def test_longformer_model_global_attention_mask(self):
+    def test_model_global_attention_mask(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
+        self.model_tester.create_and_check_model_with_global_attention_mask(*config_and_inputs)

-    def test_longformer_for_masked_lm(self):
+    def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
+        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

-    def test_longformer_for_question_answering(self):
+    def test_for_question_answering(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
-        self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
+        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

     def test_for_sequence_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_for_sequence_classification(*config_and_inputs)
+        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

     def test_for_token_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_for_token_classification(*config_and_inputs)
+        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

     def test_for_multiple_choice(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_for_multiple_choice(*config_and_inputs)
+        self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)

 @require_torch
@@ -691,7 +691,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
         )  # long input
         input_ids = input_ids.to(torch_device)

-        loss, prediction_scores = model(input_ids, labels=input_ids)
+        loss, prediction_scores = model(input_ids, labels=input_ids).to_tuple()

         expected_loss = torch.tensor(0.0074, device=torch_device)
         expected_prediction_scores_sum = torch.tensor(-6.1048e08, device=torch_device)
...
@@ -29,7 +29,10 @@ if is_tf_available():
     from transformers import (
         LongformerConfig,
         TFLongformerForMaskedLM,
+        TFLongformerForMultipleChoice,
         TFLongformerForQuestionAnswering,
+        TFLongformerForSequenceClassification,
+        TFLongformerForTokenClassification,
         TFLongformerModel,
         TFLongformerSelfAttention,
     )
@@ -130,7 +133,7 @@ class TFLongformerModelTester:
         output_without_mask = model(input_ids)[0]
         tf.debugging.assert_near(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], rtol=1e-4)

-    def create_and_check_longformer_model(
+    def create_and_check_model(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         config.return_dict = True
@@ -144,7 +147,7 @@ class TFLongformerModelTester:
         )
         self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])

-    def create_and_check_longformer_model_with_global_attention_mask(
+    def create_and_check_model_with_global_attention_mask(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         config.return_dict = True
@@ -172,7 +175,7 @@ class TFLongformerModelTester:
         )
         self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])

-    def create_and_check_longformer_for_masked_lm(
+    def create_and_check_for_masked_lm(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         config.return_dict = True
@@ -180,7 +183,7 @@ class TFLongformerModelTester:
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
         self.parent.assertListEqual(shape_list(result.logits), [self.batch_size, self.seq_length, self.vocab_size])

-    def create_and_check_longformer_for_question_answering(
+    def create_and_check_for_question_answering(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         config.return_dict = True
@@ -196,6 +199,41 @@ class TFLongformerModelTester:
         self.parent.assertListEqual(shape_list(result.start_logits), [self.batch_size, self.seq_length])
         self.parent.assertListEqual(shape_list(result.end_logits), [self.batch_size, self.seq_length])

+    def create_and_check_for_sequence_classification(
+        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        config.num_labels = self.num_labels
+        model = TFLongformerForSequenceClassification(config=config)
+        output = model(
+            input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
+        ).logits
+        self.parent.assertListEqual(shape_list(output), [self.batch_size, self.num_labels])
+
+    def create_and_check_for_token_classification(
+        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        config.num_labels = self.num_labels
+        model = TFLongformerForTokenClassification(config=config)
+        output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels).logits
+        self.parent.assertListEqual(shape_list(output), [self.batch_size, self.seq_length, self.num_labels])
+
+    def create_and_check_for_multiple_choice(
+        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+    ):
+        config.num_choices = self.num_choices
+        model = TFLongformerForMultipleChoice(config=config)
+        multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
+        multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
+        multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
+        output = model(
+            multiple_choice_inputs_ids,
+            attention_mask=multiple_choice_input_mask,
+            global_attention_mask=multiple_choice_input_mask,
+            token_type_ids=multiple_choice_token_type_ids,
+            labels=choice_labels,
+        ).logits
+        self.parent.assertListEqual(list(output.shape), [self.batch_size, self.num_choices])
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (
@@ -252,6 +290,9 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
             TFLongformerModel,
             TFLongformerForMaskedLM,
             TFLongformerForQuestionAnswering,
+            TFLongformerForSequenceClassification,
+            TFLongformerForMultipleChoice,
+            TFLongformerForTokenClassification,
         )
         if is_tf_available()
         else ()
@@ -264,25 +305,37 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
     def test_config(self):
         self.config_tester.run_common_tests()

-    def test_longformer_model_attention_mask_determinism(self):
+    def test_model_attention_mask_determinism(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)

-    def test_longformer_model(self):
+    def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_model(*config_and_inputs)
+        self.model_tester.create_and_check_model(*config_and_inputs)

-    def test_longformer_model_global_attention_mask(self):
+    def test_model_global_attention_mask(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
+        self.model_tester.create_and_check_model_with_global_attention_mask(*config_and_inputs)

-    def test_longformer_for_masked_lm(self):
+    def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
+        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

-    def test_longformer_for_question_answering(self):
+    def test_for_question_answering(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
-        self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
+        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

+    def test_for_sequence_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
+
+    def test_for_token_classification(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
+
+    def test_for_multiple_choice(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
+
     @slow
     def test_saved_model_with_attentions_output(self):
...