"ml/backend/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "3d0b1734c006798960a56acb0ea23ea57e0dd1d9"
Unverified Commit 5362bb8a authored by elk-cloner's avatar elk-cloner Committed by GitHub
Browse files

Tf longformer for sequence classification (#8231)



* working on LongformerForSequenceClassification

* add TFLongformerForMultipleChoice

* add TFLongformerForTokenClassification

* use add_start_docstrings_to_model_forward

* test TFLongformerForSequenceClassification

* test TFLongformerForMultipleChoice

* test TFLongformerForTokenClassification

* remove test from repo

* add test and doc for TFLongformerForSequenceClassification, TFLongformerForTokenClassification, TFLongformerForMultipleChoice

* add requested classes to modeling_tf_auto.py
update dummy_tf_objects
fix tests
fix bugs in requested classes

* pass all tests except test_inputs_embeds

* sync with master

* pass all tests except test_inputs_embeds

* pass all tests

* pass all tests

* work on test_inputs_embeds

* fix style and quality

* make multi choice work

* fix TFLongformerForTokenClassification signature

* fix TFLongformerForMultipleChoice, TFLongformerForSequenceClassification signature

* fix mult choice

* fix mc hint

* fix input embeds

* fix input embeds

* refactor input embeds

* fix copy issue

* apply sylvains changes and clean more
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 62cd9ce9
......@@ -99,21 +99,41 @@ Longformer specific outputs
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerBaseModelOutputWithPooling
:members:
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerMultipleChoiceModelOutput
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerMaskedLMOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerQuestionAnsweringModelOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerSequenceClassifierOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerMultipleChoiceModelOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_longformer.LongformerTokenClassifierOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerBaseModelOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerBaseModelOutputWithPooling
:members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerMaskedLMOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerQuestionAnsweringModelOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerSequenceClassifierOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerMultipleChoiceModelOutput
:members:
.. autoclass:: transformers.models.longformer.modeling_tf_longformer.TFLongformerTokenClassifierOutput
:members:
LongformerModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......@@ -177,3 +197,24 @@ TFLongformerForQuestionAnswering
.. autoclass:: transformers.TFLongformerForQuestionAnswering
:members: call
TFLongformerForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLongformerForSequenceClassification
:members: call
TFLongformerForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLongformerForTokenClassification
:members: call
TFLongformerForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFLongformerForMultipleChoice
:members: call
......@@ -766,7 +766,10 @@ if is_tf_available():
from .models.longformer import (
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLongformerForMaskedLM,
TFLongformerForMultipleChoice,
TFLongformerForQuestionAnswering,
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
TFLongformerSelfAttention,
)
......
......@@ -92,7 +92,10 @@ from ..funnel.modeling_tf_funnel import (
from ..gpt2.modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
from ..longformer.modeling_tf_longformer import (
TFLongformerForMaskedLM,
TFLongformerForMultipleChoice,
TFLongformerForQuestionAnswering,
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
)
from ..lxmert.modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
......@@ -314,6 +317,7 @@ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
(AlbertConfig, TFAlbertForSequenceClassification),
(CamembertConfig, TFCamembertForSequenceClassification),
(XLMRobertaConfig, TFXLMRobertaForSequenceClassification),
(LongformerConfig, TFLongformerForSequenceClassification),
(RobertaConfig, TFRobertaForSequenceClassification),
(BertConfig, TFBertForSequenceClassification),
(XLNetConfig, TFXLNetForSequenceClassification),
......@@ -353,6 +357,7 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
(FlaubertConfig, TFFlaubertForTokenClassification),
(XLMConfig, TFXLMForTokenClassification),
(XLMRobertaConfig, TFXLMRobertaForTokenClassification),
(LongformerConfig, TFLongformerForTokenClassification),
(RobertaConfig, TFRobertaForTokenClassification),
(BertConfig, TFBertForTokenClassification),
(MobileBertConfig, TFMobileBertForTokenClassification),
......@@ -368,6 +373,7 @@ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
(CamembertConfig, TFCamembertForMultipleChoice),
(XLMConfig, TFXLMForMultipleChoice),
(XLMRobertaConfig, TFXLMRobertaForMultipleChoice),
(LongformerConfig, TFLongformerForMultipleChoice),
(RobertaConfig, TFRobertaForMultipleChoice),
(BertConfig, TFBertForMultipleChoice),
(DistilBertConfig, TFDistilBertForMultipleChoice),
......
......@@ -26,7 +26,10 @@ if is_tf_available():
from .modeling_tf_longformer import (
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLongformerForMaskedLM,
TFLongformerForMultipleChoice,
TFLongformerForQuestionAnswering,
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
TFLongformerSelfAttention,
)
......@@ -31,7 +31,6 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
......@@ -151,17 +150,15 @@ class LongformerBaseModelOutputWithPooling(ModelOutput):
@dataclass
class LongformerMultipleChoiceModelOutput(ModelOutput):
class LongformerMaskedLMOutput(ModelOutput):
"""
Base class for outputs of multiple choice Longformer models.
Base class for masked language models outputs.
Args:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Masked language modeling (MLM) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -249,6 +246,149 @@ class LongformerQuestionAnsweringModelOutput(ModelOutput):
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LongformerSequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sentence classification models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
mask.
Local attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token in the sequence to every token with
global attention (first ``x`` values) and to every token in the attention window (remaining
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
attention weights. If a token has global attention, the attention weights to all other tokens in
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LongformerMultipleChoiceModelOutput(ModelOutput):
"""
Base class for outputs of multiple choice Longformer models.
Args:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
mask.
Local attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token in the sequence to every token with
global attention (first ``x`` values) and to every token in the attention window (remaining
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
attention weights. If a token has global attention, the attention weights to all other tokens in
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LongformerTokenClassifierOutput(ModelOutput):
"""
Base class for outputs of token classification models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
mask.
Local attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token in the sequence to every token with
global attention (first ``x`` values) and to every token in the attention window (remaining
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
attention weights. If a token has global attention, the attention weights to all other tokens in
:obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
Global attentions weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
def _get_question_end_index(input_ids, sep_token_id):
"""
Computes the index of the first occurance of `sep_token_id`.
......@@ -1495,7 +1635,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
return self.lm_head.decoder
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=LongformerMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
......@@ -1561,7 +1701,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
return LongformerMaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
......@@ -1593,7 +1733,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="allenai/longformer-base-4096",
output_type=SequenceClassifierOutput,
output_type=LongformerSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
......@@ -1651,7 +1791,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
return LongformerSequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
......@@ -1837,7 +1977,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="allenai/longformer-base-4096",
output_type=TokenClassifierOutput,
output_type=LongformerTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
......@@ -1895,7 +2035,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
return LongformerTokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
......
......@@ -751,15 +751,15 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
super().build(input_shape)
def call(self, features):
x = self.dense(features)
x = self.act(x)
x = self.layer_norm(x)
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.layer_norm(hidden_states)
# project back to size of vocabulary with bias
x = self.decoder(x, mode="linear") + self.bias
hidden_states = self.decoder(hidden_states, mode="linear") + self.bias
return x
return hidden_states
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
......
......@@ -812,6 +812,15 @@ class TFLongformerForMaskedLM:
requires_tf(self)
class TFLongformerForMultipleChoice:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFLongformerForQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_tf(self)
......@@ -821,6 +830,24 @@ class TFLongformerForQuestionAnswering:
requires_tf(self)
class TFLongformerForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFLongformerForTokenClassification:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFLongformerModel:
def __init__(self, *args, **kwargs):
requires_tf(self)
......
......@@ -129,7 +129,7 @@ class LongformerModelTester:
output_without_mask = model(input_ids)["last_hidden_state"]
self.parent.assertTrue(torch.allclose(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], atol=1e-4))
def create_and_check_longformer_model(
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LongformerModel(config=config)
......@@ -141,7 +141,7 @@ class LongformerModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_longformer_model_with_global_attention_mask(
def create_and_check_model_with_global_attention_mask(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LongformerModel(config=config)
......@@ -163,7 +163,7 @@ class LongformerModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_longformer_for_masked_lm(
def create_and_check_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LongformerForMaskedLM(config=config)
......@@ -172,7 +172,7 @@ class LongformerModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_longformer_for_question_answering(
def create_and_check_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LongformerForQuestionAnswering(config=config)
......@@ -189,7 +189,7 @@ class LongformerModelTester:
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_longformer_for_sequence_classification(
def create_and_check_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
......@@ -199,7 +199,7 @@ class LongformerModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_longformer_for_token_classification(
def create_and_check_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
......@@ -209,7 +209,7 @@ class LongformerModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_longformer_for_multiple_choice(
def create_and_check_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
......@@ -296,37 +296,37 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
def test_config(self):
self.config_tester.run_common_tests()
def test_longformer_model(self):
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_model(*config_and_inputs)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_longformer_model_attention_mask_determinism(self):
def test_model_attention_mask_determinism(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)
def test_longformer_model_global_attention_mask(self):
def test_model_global_attention_mask(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
self.model_tester.create_and_check_model_with_global_attention_mask(*config_and_inputs)
def test_longformer_for_masked_lm(self):
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_longformer_for_question_answering(self):
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_sequence_classification(*config_and_inputs)
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_token_classification(*config_and_inputs)
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_multiple_choice(*config_and_inputs)
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
@require_torch
......@@ -691,7 +691,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
) # long input
input_ids = input_ids.to(torch_device)
loss, prediction_scores = model(input_ids, labels=input_ids)
loss, prediction_scores = model(input_ids, labels=input_ids).to_tuple()
expected_loss = torch.tensor(0.0074, device=torch_device)
expected_prediction_scores_sum = torch.tensor(-6.1048e08, device=torch_device)
......
......@@ -29,7 +29,10 @@ if is_tf_available():
from transformers import (
LongformerConfig,
TFLongformerForMaskedLM,
TFLongformerForMultipleChoice,
TFLongformerForQuestionAnswering,
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
TFLongformerSelfAttention,
)
......@@ -130,7 +133,7 @@ class TFLongformerModelTester:
output_without_mask = model(input_ids)[0]
tf.debugging.assert_near(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], rtol=1e-4)
def create_and_check_longformer_model(
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.return_dict = True
......@@ -144,7 +147,7 @@ class TFLongformerModelTester:
)
self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])
def create_and_check_longformer_model_with_global_attention_mask(
def create_and_check_model_with_global_attention_mask(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.return_dict = True
......@@ -172,7 +175,7 @@ class TFLongformerModelTester:
)
self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])
def create_and_check_longformer_for_masked_lm(
def create_and_check_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.return_dict = True
......@@ -180,7 +183,7 @@ class TFLongformerModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertListEqual(shape_list(result.logits), [self.batch_size, self.seq_length, self.vocab_size])
def create_and_check_longformer_for_question_answering(
def create_and_check_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.return_dict = True
......@@ -196,6 +199,41 @@ class TFLongformerModelTester:
self.parent.assertListEqual(shape_list(result.start_logits), [self.batch_size, self.seq_length])
self.parent.assertListEqual(shape_list(result.end_logits), [self.batch_size, self.seq_length])
def create_and_check_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = TFLongformerForSequenceClassification(config=config)
output = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
).logits
self.parent.assertListEqual(shape_list(output), [self.batch_size, self.num_labels])
def create_and_check_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = TFLongformerForTokenClassification(config=config)
output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels).logits
self.parent.assertListEqual(shape_list(output), [self.batch_size, self.seq_length, self.num_labels])
def create_and_check_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
model = TFLongformerForMultipleChoice(config=config)
multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
output = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
global_attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels,
).logits
self.parent.assertListEqual(list(output.shape), [self.batch_size, self.num_choices])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
......@@ -252,6 +290,9 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
TFLongformerModel,
TFLongformerForMaskedLM,
TFLongformerForQuestionAnswering,
TFLongformerForSequenceClassification,
TFLongformerForMultipleChoice,
TFLongformerForTokenClassification,
)
if is_tf_available()
else ()
......@@ -264,25 +305,37 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
def test_config(self):
self.config_tester.run_common_tests()
def test_longformer_model_attention_mask_determinism(self):
def test_model_attention_mask_determinism(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)
def test_longformer_model(self):
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_model(*config_and_inputs)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_longformer_model_global_attention_mask(self):
def test_model_global_attention_mask(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
self.model_tester.create_and_check_model_with_global_attention_mask(*config_and_inputs)
def test_longformer_for_masked_lm(self):
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_longformer_for_question_answering(self):
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
@slow
def test_saved_model_with_attentions_output(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment