Commit 0603564e authored by Sylvain Gugger's avatar Sylvain Gugger

Merge remote-tracking branch 'origin/master'

parents 1e08af38 d86b5ffc
@@ -82,7 +82,7 @@ class TFDPRContextEncoderOutput(ModelOutput):
             heads.
     """
-    pooler_output: tf.Tensor
+    pooler_output: tf.Tensor = None
     hidden_states: Optional[Tuple[tf.Tensor]] = None
     attentions: Optional[Tuple[tf.Tensor]] = None
@@ -110,7 +110,7 @@ class TFDPRQuestionEncoderOutput(ModelOutput):
             heads.
     """
-    pooler_output: tf.Tensor
+    pooler_output: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
     attentions: Optional[Tuple[tf.Tensor]] = None
@@ -141,7 +141,7 @@ class TFDPRReaderOutput(ModelOutput):
             heads.
     """
-    start_logits: tf.Tensor
+    start_logits: tf.Tensor = None
     end_logits: tf.Tensor = None
     relevance_logits: tf.Tensor = None
     hidden_states: Optional[Tuple[tf.Tensor]] = None
@@ -181,7 +181,7 @@ class TFDPREncoder(TFPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.bert_model.return_dict
         outputs = self.bert_model(
-            inputs=input_ids,
+            input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             inputs_embeds=inputs_embeds,
@@ -228,7 +228,8 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
     def call(
         self,
         input_ids: Tensor,
-        attention_mask: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        token_type_ids: Optional[Tensor] = None,
         inputs_embeds: Optional[Tensor] = None,
         output_attentions: bool = False,
         output_hidden_states: bool = False,
@@ -242,6 +243,7 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
         outputs = self.encoder(
             input_ids,
             attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -474,19 +476,21 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
-            inputs_embeds = inputs[2] if len(inputs) > 2 else inputs_embeds
-            output_attentions = inputs[3] if len(inputs) > 3 else output_attentions
-            output_hidden_states = inputs[4] if len(inputs) > 4 else output_hidden_states
-            return_dict = inputs[5] if len(inputs) > 5 else return_dict
-            assert len(inputs) <= 6, "Too many inputs."
+            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+            inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
+            output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
+            output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
+            return_dict = inputs[6] if len(inputs) > 6 else return_dict
+            assert len(inputs) <= 7, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
+            token_type_ids = inputs.get("token_type_ids", token_type_ids)
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
             output_attentions = inputs.get("output_attentions", output_attentions)
             output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
             return_dict = inputs.get("return_dict", return_dict)
-            assert len(inputs) <= 6, "Too many inputs."
+            assert len(inputs) <= 7, "Too many inputs."
         else:
             input_ids = inputs
@@ -573,19 +577,21 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
-            inputs_embeds = inputs[2] if len(inputs) > 2 else inputs_embeds
-            output_attentions = inputs[3] if len(inputs) > 3 else output_attentions
-            output_hidden_states = inputs[4] if len(inputs) > 4 else output_hidden_states
-            return_dict = inputs[5] if len(inputs) > 5 else return_dict
-            assert len(inputs) <= 6, "Too many inputs."
+            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+            inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
+            output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
+            output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
+            return_dict = inputs[6] if len(inputs) > 6 else return_dict
+            assert len(inputs) <= 7, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
+            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
             output_attentions = inputs.get("output_attentions", output_attentions)
             output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
             return_dict = inputs.get("return_dict", return_dict)
-            assert len(inputs) <= 6, "Too many inputs."
+            assert len(inputs) <= 7, "Too many inputs."
         else:
             input_ids = inputs
@@ -650,6 +656,7 @@ class TFDPRReader(TFDPRPretrainedReader):
         self,
         inputs,
         attention_mask: Optional[Tensor] = None,
+        token_type_ids: Optional[Tensor] = None,
         inputs_embeds: Optional[Tensor] = None,
         output_attentions: bool = None,
         output_hidden_states: bool = None,
@@ -679,19 +686,21 @@ class TFDPRReader(TFDPRPretrainedReader):
         if isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
             attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
-            inputs_embeds = inputs[2] if len(inputs) > 2 else inputs_embeds
-            output_attentions = inputs[3] if len(inputs) > 3 else output_attentions
-            output_hidden_states = inputs[4] if len(inputs) > 4 else output_hidden_states
-            return_dict = inputs[5] if len(inputs) > 5 else return_dict
-            assert len(inputs) <= 6, "Too many inputs."
+            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
+            inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
+            output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
+            output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
+            return_dict = inputs[6] if len(inputs) > 6 else return_dict
+            assert len(inputs) <= 7, "Too many inputs."
         elif isinstance(inputs, (dict, BatchEncoding)):
             input_ids = inputs.get("input_ids")
             attention_mask = inputs.get("attention_mask", attention_mask)
+            token_type_ids = inputs.get("token_type_ids", token_type_ids)
             inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
             output_attentions = inputs.get("output_attentions", output_attentions)
             output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
             return_dict = inputs.get("return_dict", return_dict)
-            assert len(inputs) <= 6, "Too many inputs."
+            assert len(inputs) <= 7, "Too many inputs."
         else:
             input_ids = inputs
@@ -713,9 +722,13 @@ class TFDPRReader(TFDPRPretrainedReader):
         if attention_mask is None:
             attention_mask = tf.ones(input_shape, dtype=tf.dtypes.int32)
+        if token_type_ids is None:
+            token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
         return self.span_predictor(
             input_ids,
-            attention_mask,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
             inputs_embeds=inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
...
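The hunks above thread `token_type_ids` through the TF DPR encoders and reader: it becomes slot 2 when inputs are passed as a tuple/list (shifting the later slots by one), is read from dict inputs, and defaults to zeros when omitted. A minimal sketch of exercising the new argument with a randomly initialized reader; the config values and dummy ids below are illustrative only, not part of this commit:

import tensorflow as tf
from transformers import DPRConfig, TFDPRReader

# Randomly initialized model, just to exercise the new call signature.
model = TFDPRReader(DPRConfig())

input_ids = tf.constant([[101, 2054, 2003, 2293, 102]])  # dummy token ids
attention_mask = tf.ones_like(input_ids)
token_type_ids = tf.zeros_like(input_ids)  # may now be passed explicitly

# token_type_ids sits at position 2 for tuple/list inputs and is also accepted
# as a keyword or dict entry; when omitted it falls back to zeros.
outputs = model(
    input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
    return_dict=True,
)
print(outputs.start_logits.shape, outputs.relevance_logits.shape)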
@@ -120,6 +120,7 @@ class GPT2Config(PretrainedConfig):
     """
     model_type = "gpt2"
+    keys_to_ignore_at_inference = ["past_key_values"]
     def __init__(
         self,
...
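`keys_to_ignore_at_inference` is a new config-level class attribute naming output keys (here the `past_key_values` cache) that the Trainer's prediction loop should skip when gathering logits; the Trainer hunks further down read it via `getattr`. A small sketch of inspecting the attribute on a default `GPT2Config`:

from transformers import GPT2Config

config = GPT2Config()
# New class attribute consumed by Trainer.prediction_step when ignore_keys is not given.
print(getattr(config, "keys_to_ignore_at_inference", []))  # expected: ['past_key_values']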
@@ -26,7 +26,10 @@ if is_tf_available():
     from .modeling_tf_longformer import (
         TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
         TFLongformerForMaskedLM,
+        TFLongformerForMultipleChoice,
         TFLongformerForQuestionAnswering,
+        TFLongformerForSequenceClassification,
+        TFLongformerForTokenClassification,
         TFLongformerModel,
         TFLongformerSelfAttention,
     )
@@ -31,7 +31,6 @@ from ...file_utils import (
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from ...modeling_outputs import MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
 from ...modeling_utils import (
     PreTrainedModel,
     apply_chunking_to_forward,
@@ -151,17 +150,15 @@ class LongformerBaseModelOutputWithPooling(ModelOutput):
 @dataclass
-class LongformerMultipleChoiceModelOutput(ModelOutput):
+class LongformerMaskedLMOutput(ModelOutput):
     """
-    Base class for outputs of multiple choice Longformer models.
+    Base class for masked language models outputs.
     Args:
-        loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
-            Classification loss.
-        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
-            `num_choices` is the second dimension of the input tensors. (see `input_ids` above).
-            Classification scores (before SoftMax).
+        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
+            Masked language modeling (MLM) loss.
+        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
         hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
             Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
             of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -249,6 +246,149 @@ class LongformerQuestionAnsweringModelOutput(ModelOutput):
     global_attentions: Optional[Tuple[torch.FloatTensor]] = None
+@dataclass
+class LongformerSequenceClassifierOutput(ModelOutput):
+    """
+    Base class for outputs of sentence classification models.
+    Args:
+        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
+            mask.
+            Local attentions weights after the attention softmax, used to compute the weighted average in the
+            self-attention heads. Those are the attention weights from every token in the sequence to every token with
+            global attention (first ``x`` values) and to every token in the attention window (remaining
+            ``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
+            the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
+            attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
+            ``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
+            / 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
+            attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
+            attention weights. If a token has global attention, the attention weights to all other tokens in
+            :obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
+        global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
+            Global attentions weights after the attention softmax, used to compute the weighted average in the
+            self-attention heads. Those are the attention weights from every token with global attention to every token
+            in the sequence.
+    """
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    global_attentions: Optional[Tuple[torch.FloatTensor]] = None
+@dataclass
+class LongformerMultipleChoiceModelOutput(ModelOutput):
+    """
+    Base class for outputs of multiple choice Longformer models.
+    Args:
+        loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
+            Classification loss.
+        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
+            `num_choices` is the second dimension of the input tensors. (see `input_ids` above).
+            Classification scores (before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
+            mask.
+            Local attentions weights after the attention softmax, used to compute the weighted average in the
+            self-attention heads. Those are the attention weights from every token in the sequence to every token with
+            global attention (first ``x`` values) and to every token in the attention window (remaining
+            ``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
+            the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
+            attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
+            ``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
+            / 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
+            attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
+            attention weights. If a token has global attention, the attention weights to all other tokens in
+            :obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
+        global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
+            Global attentions weights after the attention softmax, used to compute the weighted average in the
+            self-attention heads. Those are the attention weights from every token with global attention to every token
+            in the sequence.
+    """
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    global_attentions: Optional[Tuple[torch.FloatTensor]] = None
+@dataclass
+class LongformerTokenClassifierOutput(ModelOutput):
+    """
+    Base class for outputs of token classification models.
+    Args:
+        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
+            Classification loss.
+        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
+            Classification scores (before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
+            mask.
+            Local attentions weights after the attention softmax, used to compute the weighted average in the
+            self-attention heads. Those are the attention weights from every token in the sequence to every token with
+            global attention (first ``x`` values) and to every token in the attention window (remaining
+            ``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
+            the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
+            attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
+            ``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
+            / 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
+            attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
+            attention weights. If a token has global attention, the attention weights to all other tokens in
+            :obj:`attentions` is set to 0, the values should be accessed from :obj:`global_attentions`.
+        global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
+            Global attentions weights after the attention softmax, used to compute the weighted average in the
+            self-attention heads. Those are the attention weights from every token with global attention to every token
+            in the sequence.
+    """
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    global_attentions: Optional[Tuple[torch.FloatTensor]] = None
 def _get_question_end_index(input_ids, sep_token_id):
     """
     Computes the index of the first occurance of `sep_token_id`.
@@ -1495,7 +1635,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
         return self.lm_head.decoder
     @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=LongformerMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids=None,
@@ -1561,7 +1701,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
             output = (prediction_scores,) + outputs[2:]
             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
-        return MaskedLMOutput(
+        return LongformerMaskedLMOutput(
             loss=masked_lm_loss,
             logits=prediction_scores,
             hidden_states=outputs.hidden_states,
@@ -1593,7 +1733,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="allenai/longformer-base-4096",
-        output_type=SequenceClassifierOutput,
+        output_type=LongformerSequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
@@ -1651,7 +1791,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
-        return SequenceClassifierOutput(
+        return LongformerSequenceClassifierOutput(
             loss=loss,
             logits=logits,
             hidden_states=outputs.hidden_states,
@@ -1837,7 +1977,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="allenai/longformer-base-4096",
-        output_type=TokenClassifierOutput,
+        output_type=LongformerTokenClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
@@ -1895,7 +2035,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
-        return TokenClassifierOutput(
+        return LongformerTokenClassifierOutput(
             loss=loss,
             logits=logits,
             hidden_states=outputs.hidden_states,
...
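With the Longformer-specific output classes added above, the task heads return `LongformerSequenceClassifierOutput`, `LongformerTokenClassifierOutput`, etc., which expose `global_attentions` next to the usual `loss`/`logits`/`hidden_states`/`attentions` fields. A rough sketch with a tiny, randomly initialized model; all config values and tensor shapes below are arbitrary placeholders, chosen only to keep the example small:

import torch
from transformers import LongformerConfig, LongformerForSequenceClassification

# Tiny illustrative config (not a real checkpoint); attention_window must be even
# and the sequence length a multiple of it.
config = LongformerConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
    intermediate_size=64, attention_window=[8, 8], max_position_embeddings=512,
    num_labels=3,
)
model = LongformerForSequenceClassification(config)

input_ids = torch.randint(0, 100, (1, 16))
outputs = model(input_ids, labels=torch.tensor([1]), output_attentions=True, return_dict=True)

# The head now returns a Longformer-specific output class instead of the generic
# SequenceClassifierOutput; outputs.global_attentions holds the per-layer global
# attention weights when output_attentions=True.
print(type(outputs).__name__, outputs.logits.shape)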
@@ -97,3 +97,4 @@ class MarianConfig(BartConfig):
     """
     model_type = "marian"
+    keys_to_ignore_at_inference = ["past_key_values"]
@@ -102,3 +102,4 @@ class MBartConfig(BartConfig):
     """
     model_type = "mbart"
+    keys_to_ignore_at_inference = ["past_key_values"]
@@ -62,6 +62,7 @@ class MT5Config(PretrainedConfig):
             Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`.
     """
     model_type = "mt5"
+    keys_to_ignore_at_inference = ["past_key_values"]
     def __init__(
         self,
...
@@ -141,4 +141,5 @@ class PegasusConfig(BartConfig):
     """
     model_type = "pegasus"
+    keys_to_ignore_at_inference = ["past_key_values"]
     # The implementation of the config object is in BartConfig
@@ -92,6 +92,7 @@ class ProphetNetConfig(PretrainedConfig):
             smoothing is performed.
     """
     model_type = "prophetnet"
+    keys_to_ignore_at_inference = ["past_key_values"]
     def __init__(
         self,
...
@@ -153,6 +153,7 @@ class ReformerConfig(PretrainedConfig):
         >>> configuration = model.config
     """
     model_type = "reformer"
+    keys_to_ignore_at_inference = ["past_buckets_states"]
     def __init__(
         self,
...
@@ -751,15 +751,15 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
         super().build(input_shape)
-    def call(self, features):
-        x = self.dense(features)
-        x = self.act(x)
-        x = self.layer_norm(x)
+    def call(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.layer_norm(hidden_states)
         # project back to size of vocabulary with bias
-        x = self.decoder(x, mode="linear") + self.bias
-        return x
+        hidden_states = self.decoder(hidden_states, mode="linear") + self.bias
+        return hidden_states
 @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
...
@@ -71,6 +71,7 @@ class T5Config(PretrainedConfig):
             the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`.
     """
     model_type = "t5"
+    keys_to_ignore_at_inference = ["past_key_values"]
     def __init__(
         self,
...
@@ -105,6 +105,7 @@ class TransfoXLConfig(PretrainedConfig):
     """
     model_type = "transfo-xl"
+    keys_to_ignore_at_inference = ["mems"]
     def __init__(
         self,
...
@@ -128,6 +128,7 @@ class XLNetConfig(PretrainedConfig):
     """
     model_type = "xlnet"
+    keys_to_ignore_at_inference = ["mems"]
     def __init__(
         self,
...
@@ -470,7 +470,7 @@ class CaptureLogger:
         >>> msg = "Testing 1, 2, 3"
         >>> logging.set_verbosity_info()
-        >>> logger = logging.get_logger("transformers.tokenization_bart")
+        >>> logger = logging.get_logger("transformers.models.bart.tokenization_bart")
         >>> with CaptureLogger(logger) as cl:
         ...     logger.info(msg)
         >>> assert cl.out, msg+"\n"
...
@@ -1098,10 +1098,11 @@ class Trainer:
         """
         outputs = model(**inputs)
         # Save past state if it exists
+        # TODO: this needs to be fixed and made cleaner later.
         if self.args.past_index >= 0:
             self._past = outputs[self.args.past_index]
         # We don't use .loss here since the model may return tuples instead of ModelOutput.
-        return outputs[0]
+        return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
     def is_local_process_zero(self) -> bool:
         """
@@ -1220,7 +1221,9 @@ class Trainer:
                 logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
                 shutil.rmtree(checkpoint)
-    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
+    def evaluate(
+        self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None
+    ) -> Dict[str, float]:
         """
         Run evaluation and returns metrics.
@@ -1234,6 +1237,9 @@ class Trainer:
                 Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
                 columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the
                 :obj:`__len__` method.
+            ignore_keys (:obj:`Lst[str]`, `optional`):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
         Returns:
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
@@ -1250,6 +1256,7 @@ class Trainer:
             # No point gathering the predictions if there are no metrics, otherwise we defer to
             # self.args.prediction_loss_only
             prediction_loss_only=True if self.compute_metrics is None else None,
+            ignore_keys=ignore_keys,
         )
         self.log(output.metrics)
@@ -1261,7 +1268,7 @@ class Trainer:
         self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
         return output.metrics
-    def predict(self, test_dataset: Dataset) -> PredictionOutput:
+    def predict(self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None) -> PredictionOutput:
         """
         Run prediction and returns predictions and potential metrics.
@@ -1272,6 +1279,9 @@ class Trainer:
             test_dataset (:obj:`Dataset`):
                 Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
                 ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
+            ignore_keys (:obj:`Lst[str]`, `optional`):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
         .. note::
@@ -1291,10 +1301,14 @@ class Trainer:
         test_dataloader = self.get_test_dataloader(test_dataset)
-        return self.prediction_loop(test_dataloader, description="Prediction")
+        return self.prediction_loop(test_dataloader, description="Prediction", ignore_keys=ignore_keys)
     def prediction_loop(
-        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
+        self,
+        dataloader: DataLoader,
+        description: str,
+        prediction_loss_only: Optional[bool] = None,
+        ignore_keys: Optional[List[str]] = None,
     ) -> PredictionOutput:
         """
         Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
@@ -1346,7 +1360,7 @@ class Trainer:
         self.callback_handler.eval_dataloader = dataloader
         for step, inputs in enumerate(dataloader):
-            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
+            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
             if loss is not None:
                 losses = loss.repeat(batch_size)
                 losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
@@ -1410,7 +1424,11 @@ class Trainer:
         return nested_numpify(tensors)
     def prediction_step(
-        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
     ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
         """
         Perform an evaluation step on :obj:`model` using obj:`inputs`.
@@ -1427,6 +1445,9 @@ class Trainer:
                 argument :obj:`labels`. Check your model's documentation for all accepted arguments.
             prediction_loss_only (:obj:`bool`):
                 Whether or not to return the loss only.
+            ignore_keys (:obj:`Lst[str]`, `optional`):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
         Return:
             Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
@@ -1434,6 +1455,11 @@ class Trainer:
         """
         has_labels = all(inputs.get(k) is not None for k in self.label_names)
         inputs = self._prepare_inputs(inputs)
+        if ignore_keys is None:
+            if hasattr(self.model, "config"):
+                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
+            else:
+                ignore_keys = []
         with torch.no_grad():
             if self.args.fp16 and _use_native_amp:
@@ -1442,16 +1468,21 @@ class Trainer:
             else:
                 outputs = model(**inputs)
             if has_labels:
-                loss = outputs[0].mean().detach()
-                logits = outputs[1:]
+                if isinstance(outputs, dict):
+                    loss = outputs["loss"].mean().detach()
+                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
+                else:
+                    loss = outputs[0].mean().detach()
+                    logits = outputs[1:]
             else:
                 loss = None
-                # Slicing so we get a tuple even if `outputs` is a `ModelOutput`.
-                logits = outputs[:]
+                if isinstance(outputs, dict):
+                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
+                else:
+                    logits = outputs
+            # TODO: this needs to be fixed and made cleaner later.
             if self.args.past_index >= 0:
                 self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
+                # Remove the past from the logits.
+                logits = logits[: self.args.past_index - 1] + logits[self.args.past_index :]
             if prediction_loss_only:
                 return (loss, None, None)
...
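Taken together, the Trainer changes let `evaluate`, `predict`, `prediction_loop` and `prediction_step` drop selected output keys when gathering predictions, defaulting to the model config's `keys_to_ignore_at_inference`. A usage sketch, assuming `trainer`, `eval_dataset` and `test_dataset` objects already exist (those names are placeholders, not part of this commit):

# With no argument, prediction_step falls back to
# model.config.keys_to_ignore_at_inference, so cache tensors such as
# past_key_values are no longer gathered as part of the predictions.
metrics = trainer.evaluate(eval_dataset)

# The keys to skip can also be passed explicitly to evaluate() and predict().
metrics = trainer.evaluate(eval_dataset, ignore_keys=["past_key_values"])
predictions = trainer.predict(test_dataset, ignore_keys=["past_key_values"])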
@@ -812,6 +812,15 @@ class TFLongformerForMaskedLM:
         requires_tf(self)
+class TFLongformerForMultipleChoice:
+    def __init__(self, *args, **kwargs):
+        requires_tf(self)
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_tf(self)
 class TFLongformerForQuestionAnswering:
     def __init__(self, *args, **kwargs):
         requires_tf(self)
@@ -821,6 +830,24 @@ class TFLongformerForQuestionAnswering:
         requires_tf(self)
+class TFLongformerForSequenceClassification:
+    def __init__(self, *args, **kwargs):
+        requires_tf(self)
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_tf(self)
+class TFLongformerForTokenClassification:
+    def __init__(self, *args, **kwargs):
+        requires_tf(self)
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_tf(self)
 class TFLongformerModel:
     def __init__(self, *args, **kwargs):
         requires_tf(self)
...
@@ -129,7 +129,7 @@ class LongformerModelTester:
         output_without_mask = model(input_ids)["last_hidden_state"]
         self.parent.assertTrue(torch.allclose(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], atol=1e-4))
-    def create_and_check_longformer_model(
+    def create_and_check_model(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         model = LongformerModel(config=config)
@@ -141,7 +141,7 @@ class LongformerModelTester:
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
         self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
-    def create_and_check_longformer_model_with_global_attention_mask(
+    def create_and_check_model_with_global_attention_mask(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         model = LongformerModel(config=config)
@@ -163,7 +163,7 @@ class LongformerModelTester:
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
         self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
-    def create_and_check_longformer_for_masked_lm(
+    def create_and_check_for_masked_lm(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         model = LongformerForMaskedLM(config=config)
@@ -172,7 +172,7 @@ class LongformerModelTester:
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
-    def create_and_check_longformer_for_question_answering(
+    def create_and_check_for_question_answering(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         model = LongformerForQuestionAnswering(config=config)
@@ -189,7 +189,7 @@ class LongformerModelTester:
         self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
         self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
-    def create_and_check_longformer_for_sequence_classification(
+    def create_and_check_for_sequence_classification(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         config.num_labels = self.num_labels
@@ -199,7 +199,7 @@ class LongformerModelTester:
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
-    def create_and_check_longformer_for_token_classification(
+    def create_and_check_for_token_classification(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         config.num_labels = self.num_labels
@@ -209,7 +209,7 @@ class LongformerModelTester:
         result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
-    def create_and_check_longformer_for_multiple_choice(
+    def create_and_check_for_multiple_choice(
         self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     ):
         config.num_choices = self.num_choices
@@ -296,37 +296,37 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
     def test_config(self):
         self.config_tester.run_common_tests()
-    def test_longformer_model(self):
+    def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_model(*config_and_inputs)
+        self.model_tester.create_and_check_model(*config_and_inputs)
-    def test_longformer_model_attention_mask_determinism(self):
+    def test_model_attention_mask_determinism(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)
-    def test_longformer_model_global_attention_mask(self):
+    def test_model_global_attention_mask(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
+        self.model_tester.create_and_check_model_with_global_attention_mask(*config_and_inputs)
-    def test_longformer_for_masked_lm(self):
+    def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
+        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
-    def test_longformer_for_question_answering(self):
+    def test_for_question_answering(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
-        self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
+        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
     def test_for_sequence_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_for_sequence_classification(*config_and_inputs)
+        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
     def test_for_token_classification(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_for_token_classification(*config_and_inputs)
+        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
     def test_for_multiple_choice(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        self.model_tester.create_and_check_longformer_for_multiple_choice(*config_and_inputs)
+        self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
 @require_torch
@@ -691,7 +691,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
         )  # long input
         input_ids = input_ids.to(torch_device)
-        loss, prediction_scores = model(input_ids, labels=input_ids)
+        loss, prediction_scores = model(input_ids, labels=input_ids).to_tuple()
         expected_loss = torch.tensor(0.0074, device=torch_device)
         expected_prediction_scores_sum = torch.tensor(-6.1048e08, device=torch_device)
...
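The integration-test change above reflects that the model call now returns a `ModelOutput` object rather than a plain tuple, so positional unpacking goes through `.to_tuple()`. A small sketch of the same pattern, assuming `model` and `input_ids` as prepared earlier in that test (elided here):

# ModelOutput behaves like an ordered mapping; to_tuple() restores the old tuple
# interface so the result can still be unpacked positionally.
output = model(input_ids, labels=input_ids)
loss, prediction_scores = output.to_tuple()[:2]
print(float(loss), prediction_scores.shape)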
@@ -340,6 +340,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
             self.assertTrue(layer.split("_")[0] in ["dropout", "classifier"])
+@require_tf
 class TFBertModelIntegrationTest(unittest.TestCase):
     @slow
     def test_inference_masked_lm(self):
...