Unverified commit 5362bb8a authored by elk-cloner, committed by GitHub

TF Longformer for sequence classification (#8231)



* working on LongformerForSequenceClassification

* add TFLongformerForMultipleChoice

* add TFLongformerForTokenClassification

* use add_start_docstrings_to_model_forward

* test TFLongformerForSequenceClassification

* test TFLongformerForMultipleChoice

* test TFLongformerForTokenClassification

* remove test from repo

* add test and doc for TFLongformerForSequenceClassification, TFLongformerForTokenClassification, TFLongformerForMultipleChoice

* add requested classes to modeling_tf_auto.py
update dummy_tf_objects
fix tests
fix bugs in requested classes

* pass all tests except test_inputs_embeds

* sync with master

* pass all tests except test_inputs_embeds

* pass all tests

* pass all tests

* work on test_inputs_embeds

* fix style and quality

* make multi choice work

* fix TFLongformerForTokenClassification signature

* fix TFLongformerForMultipleChoice, TFLongformerForSequenceClassification signature

* fix multiple choice

* fix multiple choice hint

* fix input embeds

* fix input embeds

* refactor input embeds

* fix copy issue

* apply Sylvain's changes and clean up more
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 62cd9ce9
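For orientation before the diff: a minimal, hedged sketch of how the new TF sequence-classification head is meant to be used. The checkpoint name is the one referenced in the docstrings below; the tokenizer call and `num_labels` argument are standard library usage, the classification head weights are newly initialized, and loading may need `from_pt=True` if only PyTorch weights are published for that checkpoint.

```python
# Hedged usage sketch (not part of the commit): the new TF sequence-classification head.
import tensorflow as tf
from transformers import LongformerTokenizer, TFLongformerForSequenceClassification

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerForSequenceClassification.from_pretrained(
    "allenai/longformer-base-4096", num_labels=2  # head weights are newly initialized
)

inputs = tokenizer("A very long document ...", return_tensors="tf")
outputs = model(inputs, return_dict=True)        # TFLongformerSequenceClassifierOutput
probabilities = tf.nn.softmax(outputs.logits, axis=-1)
```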
@@ -99,21 +99,41 @@ Longformer specific outputs
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerBaseModelOutputWithPooling
    :members:
-.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerMultipleChoiceModelOutput
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerMaskedLMOutput
    :members:
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerQuestionAnsweringModelOutput
    :members:
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerSequenceClassifierOutput
    :members:
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerMultipleChoiceModelOutput
    :members:
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerTokenClassifierOutput
    :members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerBaseModelOutput
    :members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerBaseModelOutputWithPooling
    :members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerMaskedLMOutput
    :members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerQuestionAnsweringModelOutput
    :members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerSequenceClassifierOutput
    :members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerMultipleChoiceModelOutput
    :members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerTokenClassifierOutput
    :members:
LongformerModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -177,3 +197,24 @@ TFLongformerForQuestionAnswering
.. autoclass:: transformers.TFLongformerForQuestionAnswering
    :members: call
TFLongformerForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLongformerForSequenceClassification
    :members: call
TFLongformerForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLongformerForTokenClassification
    :members: call
TFLongformerForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLongformerForMultipleChoice
    :members: call
@@ -766,7 +766,10 @@ if is_tf_available():
    from .models.longformer import (
        TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFLongformerForMaskedLM,
        TFLongformerForMultipleChoice,
        TFLongformerForQuestionAnswering,
        TFLongformerForSequenceClassification,
        TFLongformerForTokenClassification,
        TFLongformerModel,
        TFLongformerSelfAttention,
    )
...
@@ -92,7 +92,10 @@ from ..funnel.modeling_tf_funnel import (
from ..gpt2.modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
from ..longformer.modeling_tf_longformer import (
    TFLongformerForMaskedLM,
    TFLongformerForMultipleChoice,
    TFLongformerForQuestionAnswering,
    TFLongformerForSequenceClassification,
    TFLongformerForTokenClassification,
    TFLongformerModel,
)
from ..lxmert.modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
@@ -314,6 +317,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
        (AlbertConfig, TFAlbertForSequenceClassification),
        (CamembertConfig, TFCamembertForSequenceClassification),
        (XLMRobertaConfig, TFXLMRobertaForSequenceClassification),
        (LongformerConfig, TFLongformerForSequenceClassification),
        (RobertaConfig, TFRobertaForSequenceClassification),
        (BertConfig, TFBertForSequenceClassification),
        (XLNetConfig, TFXLNetForSequenceClassification),
@@ -353,6 +357,7 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
        (FlaubertConfig, TFFlaubertForTokenClassification),
        (XLMConfig, TFXLMForTokenClassification),
        (XLMRobertaConfig, TFXLMRobertaForTokenClassification),
        (LongformerConfig, TFLongformerForTokenClassification),
        (RobertaConfig, TFRobertaForTokenClassification),
        (BertConfig, TFBertForTokenClassification),
        (MobileBertConfig, TFMobileBertForTokenClassification),
@@ -368,6 +373,7 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
        (CamembertConfig, TFCamembertForMultipleChoice),
        (XLMConfig, TFXLMForMultipleChoice),
        (XLMRobertaConfig, TFXLMRobertaForMultipleChoice),
        (LongformerConfig, TFLongformerForMultipleChoice),
        (RobertaConfig, TFRobertaForMultipleChoice),
        (BertConfig, TFBertForMultipleChoice),
        (DistilBertConfig, TFDistilBertForMultipleChoice),
...
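Registering the `(LongformerConfig, TFLongformer...)` pairs above is what lets the TF auto classes dispatch Longformer checkpoints to the new heads. A hedged sketch, assuming the usual `TFAutoModelFor*` lookup through these ordered mappings:

```python
# Hedged sketch: the auto classes pick the TF head by config type via the mappings above.
from transformers import AutoConfig, TFAutoModelForSequenceClassification

config = AutoConfig.from_pretrained("allenai/longformer-base-4096", num_labels=3)
model = TFAutoModelForSequenceClassification.from_config(config)
print(type(model).__name__)  # expected: TFLongformerForSequenceClassification
```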
@@ -26,7 +26,10 @@ if is_tf_available():
    from .modeling_tf_longformer import (
        TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFLongformerForMaskedLM,
        TFLongformerForMultipleChoice,
        TFLongformerForQuestionAnswering,
        TFLongformerForSequenceClassification,
        TFLongformerForTokenClassification,
        TFLongformerModel,
        TFLongformerSelfAttention,
    )
@@ -31,7 +31,6 @@ from ...file_utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
-from ...modeling_outputs import MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
from ...modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
@@ -151,17 +150,15 @@ class LongformerBaseModelOutputWithPooling(ModelOutput):
@dataclass
-class LongformerMultipleChoiceModelOutput(ModelOutput):
class LongformerMaskedLMOutput(ModelOutput):
    """
-    Base class for outputs of multiple choice Longformer models.
    Base class for masked language models outputs.
    Args:
-        loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
-            Classification loss.
-        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
-            `num_choices` is the second dimension of the input tensors. (see `input_ids` above).
-            Classification scores (before SoftMax).
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Masked language modeling (MLM) loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -249,6 +246,149 @@ class LongformerQuestionAnsweringModelOutput(ModelOutput):
    global_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LongformerSequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sentence classification models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
mask.
Local attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token in the sequence to every token with
global attention (first ``x`` values) and to every token in the attention window (remaining
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
attention weights. If a token has global attention, the attention weights to all other tokens in
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LongformerMultipleChoiceModelOutput(ModelOutput):
"""
Base class for outputs of multiple choice Longformer models.
Args:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
mask.
Local attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token in the sequence to every token with
global attention (first ``x`` values) and to every token in the attention window (remaining
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
attention weights. If a token has global attention, the attention weights to all other tokens in
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LongformerTokenClassifierOutput(ModelOutput):
"""
Base class for outputs of token classification models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
mask.
Local attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token in the sequence to every token with
global attention (first ``x`` values) and to every token in the attention window (remaining
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
attention weights. If a token has global attention, the attention weights to all other tokens in
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
def _get_question_end_index(input_ids, sep_token_id):
    """
    Computes the index of the first occurance of `sep_token_id`.
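The docstrings of the new output classes above describe how the last dimension of `attentions` is laid out. A small worked example with illustrative numbers (not taken from the diff):

```python
# Illustrative numbers only: layout of one row of `attentions` as described above.
attention_window = 512   # local attention window size
x = 1                    # number of tokens with global attention (e.g. only the <s> token)

row_length = x + attention_window + 1              # entries per (layer, head, token) row
self_attention_index = x + attention_window // 2   # where a token's attention to itself sits

print(row_length)            # 514
print(self_attention_index)  # 257
```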
@@ -1495,7 +1635,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
        return self.lm_head.decoder
    @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
    @replace_return_docstrings(output_type=LongformerMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
@@ -1561,7 +1701,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
-        return MaskedLMOutput(
        return LongformerMaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
@@ -1593,7 +1733,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="allenai/longformer-base-4096",
-        output_type=SequenceClassifierOutput,
        output_type=LongformerSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
@@ -1651,7 +1791,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
-        return SequenceClassifierOutput(
        return LongformerSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
@@ -1837,7 +1977,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="allenai/longformer-base-4096",
-        output_type=TokenClassifierOutput,
        output_type=LongformerTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
@@ -1895,7 +2035,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
-        return TokenClassifierOutput(
        return LongformerTokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
...
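With the return types switched above, the PyTorch classification heads now report the Longformer-specific output classes, which also carry `global_attentions`. A hedged sketch of inspecting them (checkpoint as in the docstrings; the classification head weights are newly initialized):

```python
# Hedged sketch: the PyTorch head now returns LongformerSequenceClassifierOutput.
from transformers import LongformerTokenizer, LongformerForSequenceClassification

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerForSequenceClassification.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("A very long document ...", return_tensors="pt")
outputs = model(**inputs, output_attentions=True, return_dict=True)

print(type(outputs).__name__)   # LongformerSequenceClassifierOutput
print(outputs.logits.shape)     # (1, config.num_labels)
print(len(outputs.attentions))  # one local-attention tensor per layer
```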
@@ -751,15 +751,15 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
        super().build(input_shape)
-    def call(self, features):
-        x = self.dense(features)
-        x = self.act(x)
-        x = self.layer_norm(x)
    def call(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        # project back to size of vocabulary with bias
-        x = self.decoder(x, mode="linear") + self.bias
-        return x
        hidden_states = self.decoder(hidden_states, mode="linear") + self.bias
        return hidden_states
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
...
@@ -812,6 +812,15 @@ class TFLongformerForMaskedLM:
        requires_tf(self)
class TFLongformerForMultipleChoice:
    def __init__(self, *args, **kwargs):
        requires_tf(self)
    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_tf(self)
class TFLongformerForQuestionAnswering:
    def __init__(self, *args, **kwargs):
        requires_tf(self)
@@ -821,6 +830,24 @@ class TFLongformerForQuestionAnswering:
        requires_tf(self)
class TFLongformerForSequenceClassification:
    def __init__(self, *args, **kwargs):
        requires_tf(self)
    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_tf(self)
class TFLongformerForTokenClassification:
    def __init__(self, *args, **kwargs):
        requires_tf(self)
    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_tf(self)
class TFLongformerModel:
    def __init__(self, *args, **kwargs):
        requires_tf(self)
...
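The placeholder classes above keep `from transformers import TFLongformerForSequenceClassification` (and friends) importable in environments without TensorFlow; any attempt to use them raises an informative error instead. A simplified, hedged sketch of that pattern, with hypothetical names, not the repository's exact helper:

```python
# Simplified sketch of the dummy-object pattern; the real helper is the library's requires_tf.
def requires_tf(obj):
    # Hypothetical stand-in for the library helper: always reports TensorFlow as missing.
    raise ImportError(
        f"{type(obj).__name__} requires TensorFlow, which was not found in this environment."
    )

class TFLongformerForSequenceClassificationPlaceholder:  # hypothetical name for illustration
    def __init__(self, *args, **kwargs):
        requires_tf(self)
```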
@@ -129,7 +129,7 @@ class LongformerModelTester:
        output_without_mask = model(input_ids)["last_hidden_state"]
        self.parent.assertTrue(torch.allclose(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], atol=1e-4))
-    def create_and_check_longformer_model(
    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = LongformerModel(config=config)
@@ -141,7 +141,7 @@ class LongformerModelTester:
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
-    def create_and_check_longformer_model_with_global_attention_mask(
    def create_and_check_model_with_global_attention_mask(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = LongformerModel(config=config)
@@ -163,7 +163,7 @@ class LongformerModelTester:
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
-    def create_and_check_longformer_for_masked_lm(
    def create_and_check_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = LongformerForMaskedLM(config=config)
@@ -172,7 +172,7 @@ class LongformerModelTester:
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
-    def create_and_check_longformer_for_question_answering(
    def create_and_check_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = LongformerForQuestionAnswering(config=config)
@@ -189,7 +189,7 @@ class LongformerModelTester:
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
-    def create_and_check_longformer_for_sequence_classification(
    def create_and_check_for_sequence_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
@@ -199,7 +199,7 @@ class LongformerModelTester:
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
-    def create_and_check_longformer_for_token_classification(
    def create_and_check_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
@@ -209,7 +209,7 @@ class LongformerModelTester:
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
-    def create_and_check_longformer_for_multiple_choice(
    def create_and_check_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_choices = self.num_choices
@@ -296,37 +296,37 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
    def test_config(self):
        self.config_tester.run_common_tests()
-    def test_longformer_model(self):
    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_model(*config_and_inputs)
        self.model_tester.create_and_check_model(*config_and_inputs)
-    def test_longformer_model_attention_mask_determinism(self):
    def test_model_attention_mask_determinism(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)
-    def test_longformer_model_global_attention_mask(self):
    def test_model_global_attention_mask(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
        self.model_tester.create_and_check_model_with_global_attention_mask(*config_and_inputs)
-    def test_longformer_for_masked_lm(self):
    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
-    def test_longformer_for_question_answering(self):
    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
-        self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_for_sequence_classification(*config_and_inputs)
        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_for_token_classification(*config_and_inputs)
        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
    def test_for_multiple_choice(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_for_multiple_choice(*config_and_inputs)
        self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
@require_torch
@@ -691,7 +691,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
        )  # long input
        input_ids = input_ids.to(torch_device)
-        loss, prediction_scores = model(input_ids, labels=input_ids)
        loss, prediction_scores = model(input_ids, labels=input_ids).to_tuple()
        expected_loss = torch.tensor(0.0074, device=torch_device)
        expected_prediction_scores_sum = torch.tensor(-6.1048e08, device=torch_device)
...
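The one functional change in this test file is the `.to_tuple()` call: the forward pass returns a `ModelOutput` object here, so positional unpacking is made explicit. A hedged, self-contained sketch of the same pattern:

```python
# Hedged sketch of the .to_tuple() pattern used in the updated integration test.
from transformers import LongformerForMaskedLM, LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
outputs = model(input_ids, labels=input_ids)             # e.g. a LongformerMaskedLMOutput
loss, prediction_scores = outputs.to_tuple()[:2]         # explicit tuple unpacking
loss, prediction_scores = outputs.loss, outputs.logits   # equivalent attribute access
```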
@@ -29,7 +29,10 @@ if is_tf_available():
    from transformers import (
        LongformerConfig,
        TFLongformerForMaskedLM,
        TFLongformerForMultipleChoice,
        TFLongformerForQuestionAnswering,
        TFLongformerForSequenceClassification,
        TFLongformerForTokenClassification,
        TFLongformerModel,
        TFLongformerSelfAttention,
    )
@@ -130,7 +133,7 @@ class TFLongformerModelTester:
        output_without_mask = model(input_ids)[0]
        tf.debugging.assert_near(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], rtol=1e-4)
-    def create_and_check_longformer_model(
    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.return_dict = True
@@ -144,7 +147,7 @@ class TFLongformerModelTester:
        )
        self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])
-    def create_and_check_longformer_model_with_global_attention_mask(
    def create_and_check_model_with_global_attention_mask(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.return_dict = True
@@ -172,7 +175,7 @@ class TFLongformerModelTester:
        )
        self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])
-    def create_and_check_longformer_for_masked_lm(
    def create_and_check_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.return_dict = True
@@ -180,7 +183,7 @@ class TFLongformerModelTester:
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertListEqual(shape_list(result.logits), [self.batch_size, self.seq_length, self.vocab_size])
-    def create_and_check_longformer_for_question_answering(
    def create_and_check_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.return_dict = True
@@ -196,6 +199,41 @@ class TFLongformerModelTester:
        self.parent.assertListEqual(shape_list(result.start_logits), [self.batch_size, self.seq_length])
        self.parent.assertListEqual(shape_list(result.end_logits), [self.batch_size, self.seq_length])
    def create_and_check_for_sequence_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = TFLongformerForSequenceClassification(config=config)
        output = model(
            input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
        ).logits
        self.parent.assertListEqual(shape_list(output), [self.batch_size, self.num_labels])
    def create_and_check_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = TFLongformerForTokenClassification(config=config)
        output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels).logits
        self.parent.assertListEqual(shape_list(output), [self.batch_size, self.seq_length, self.num_labels])
    def create_and_check_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_choices = self.num_choices
        model = TFLongformerForMultipleChoice(config=config)
        multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
        multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
        multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
        output = model(
            multiple_choice_inputs_ids,
            attention_mask=multiple_choice_input_mask,
            global_attention_mask=multiple_choice_input_mask,
            token_type_ids=multiple_choice_token_type_ids,
            labels=choice_labels,
        ).logits
        self.parent.assertListEqual(list(output.shape), [self.batch_size, self.num_choices])
    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
@@ -252,6 +290,9 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
            TFLongformerModel,
            TFLongformerForMaskedLM,
            TFLongformerForQuestionAnswering,
            TFLongformerForSequenceClassification,
            TFLongformerForMultipleChoice,
            TFLongformerForTokenClassification,
        )
        if is_tf_available()
        else ()
@@ -264,25 +305,37 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
    def test_config(self):
        self.config_tester.run_common_tests()
-    def test_longformer_model_attention_mask_determinism(self):
    def test_model_attention_mask_determinism(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)
-    def test_longformer_model(self):
    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_model(*config_and_inputs)
        self.model_tester.create_and_check_model(*config_and_inputs)
-    def test_longformer_model_global_attention_mask(self):
    def test_model_global_attention_mask(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
        self.model_tester.create_and_check_model_with_global_attention_mask(*config_and_inputs)
-    def test_longformer_for_masked_lm(self):
    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
-    def test_longformer_for_question_answering(self):
    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
-        self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
    def test_for_multiple_choice(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
    @slow
    def test_saved_model_with_attentions_output(self):
...
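The new multiple-choice test above shows the expected input layout: every tensor gets an extra `num_choices` dimension, and the returned logits have shape `(batch_size, num_choices)`. A hedged end-to-end sketch of preparing such inputs (checkpoint as before; the multiple-choice head weights are newly initialized, and loading may need `from_pt=True` if only PyTorch weights are available):

```python
# Hedged sketch: shaping multiple-choice inputs to (batch_size, num_choices, seq_len).
import tensorflow as tf
from transformers import LongformerTokenizer, TFLongformerForMultipleChoice

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerForMultipleChoice.from_pretrained("allenai/longformer-base-4096")

prompt = "The committee met to discuss the budget."
choices = ["They argued about spending.", "The weather was sunny that day."]

# Encode the prompt against each choice, then add the batch dimension in front.
encoded = tokenizer([prompt] * len(choices), choices, return_tensors="tf", padding=True)
inputs = {name: tf.expand_dims(tensor, 0) for name, tensor in encoded.items()}

logits = model(inputs).logits
print(logits.shape)  # (1, 2): one score per choice
```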