Commit 0603564e authored by Sylvain Gugger's avatar Sylvain Gugger

Merge remote-tracking branch 'origin/master'

parents 1e08af38 d86b5ffc
......@@ -82,7 +82,7 @@ class TFDPRContextEncoderOutput(ModelOutput):
heads.
"""
pooler_output: tf.Tensor
pooler_output: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
......@@ -110,7 +110,7 @@ class TFDPRQuestionEncoderOutput(ModelOutput):
heads.
"""
pooler_output: tf.Tensor
pooler_output: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
......@@ -141,7 +141,7 @@ class TFDPRReaderOutput(ModelOutput):
heads.
"""
start_logits: tf.Tensor
start_logits: tf.Tensor = None
end_logits: tf.Tensor = None
relevance_logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
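For context, giving the leading tensor fields a `None` default keeps every field of these output dataclasses optional at construction time, matching how the other TF `ModelOutput` subclasses are declared. A minimal standalone sketch of the pattern (not the library classes themselves):

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import tensorflow as tf


@dataclass
class ReaderOutputSketch:
    # With defaults on every field, the output can be built keyword-by-keyword
    # and fields that a given call does not produce simply stay None.
    start_logits: tf.Tensor = None
    end_logits: tf.Tensor = None
    relevance_logits: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None


out = ReaderOutputSketch(start_logits=tf.zeros((1, 8)), end_logits=tf.zeros((1, 8)))
print(out.relevance_logits)  # None until the model fills it in
```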
......@@ -181,7 +181,7 @@ class TFDPREncoder(TFPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.bert_model.return_dict
outputs = self.bert_model(
inputs=input_ids,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
......@@ -228,7 +228,8 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
def call(
self,
input_ids: Tensor,
attention_mask: Tensor,
attention_mask: Optional[Tensor] = None,
token_type_ids: Optional[Tensor] = None,
inputs_embeds: Optional[Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
......@@ -242,6 +243,7 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
outputs = self.encoder(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
......@@ -474,19 +476,21 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
inputs_embeds = inputs[2] if len(inputs) > 2 else inputs_embeds
output_attentions = inputs[3] if len(inputs) > 3 else output_attentions
output_hidden_states = inputs[4] if len(inputs) > 4 else output_hidden_states
return_dict = inputs[5] if len(inputs) > 5 else return_dict
assert len(inputs) <= 6, "Too many inputs."
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
return_dict = inputs[6] if len(inputs) > 6 else return_dict
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 6, "Too many inputs."
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
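A hedged usage sketch for the context encoder after this change (assumes the public `facebook/dpr-ctx_encoder-single-nq-base` checkpoint, loaded from its PyTorch weights via `from_pt=True`): the tokenizer's `token_type_ids` now flow through, whether the encoding is passed as a dict or positionally.

```python
from transformers import DPRContextEncoderTokenizer, TFDPRContextEncoder

tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
model = TFDPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", from_pt=True)

# The BatchEncoding dict contains input_ids, attention_mask and token_type_ids;
# all three are now unpacked by call().
inputs = tokenizer("Hello, is my dog cute?", return_tensors="tf")
outputs = model(inputs, return_dict=True)
embedding = outputs.pooler_output  # shape (1, hidden_size)
```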
......@@ -573,19 +577,21 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
inputs_embeds = inputs[2] if len(inputs) > 2 else inputs_embeds
output_attentions = inputs[3] if len(inputs) > 3 else output_attentions
output_hidden_states = inputs[4] if len(inputs) > 4 else output_hidden_states
return_dict = inputs[5] if len(inputs) > 5 else return_dict
assert len(inputs) <= 6, "Too many inputs."
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
return_dict = inputs[6] if len(inputs) > 6 else return_dict
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 6, "Too many inputs."
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
......@@ -650,6 +656,7 @@ class TFDPRReader(TFDPRPretrainedReader):
self,
inputs,
attention_mask: Optional[Tensor] = None,
token_type_ids: Optional[Tensor] = None,
inputs_embeds: Optional[Tensor] = None,
output_attentions: bool = None,
output_hidden_states: bool = None,
......@@ -679,19 +686,21 @@ class TFDPRReader(TFDPRPretrainedReader):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
inputs_embeds = inputs[2] if len(inputs) > 2 else inputs_embeds
output_attentions = inputs[3] if len(inputs) > 3 else output_attentions
output_hidden_states = inputs[4] if len(inputs) > 4 else output_hidden_states
return_dict = inputs[5] if len(inputs) > 5 else return_dict
assert len(inputs) <= 6, "Too many inputs."
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
output_attentions = inputs[4] if len(inputs) > 4 else output_attentions
output_hidden_states = inputs[5] if len(inputs) > 5 else output_hidden_states
return_dict = inputs[6] if len(inputs) > 6 else return_dict
assert len(inputs) <= 7, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 6, "Too many inputs."
assert len(inputs) <= 7, "Too many inputs."
else:
input_ids = inputs
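The positional layout now mirrors the keyword order, with `token_type_ids` in slot 2. A standalone sketch of the unpacking convention (the helper name and values are illustrative only, not library code):

```python
def unpack_reader_inputs(inputs, **defaults):
    # Positional layout after this change:
    # (input_ids, attention_mask, token_type_ids, inputs_embeds,
    #  output_attentions, output_hidden_states, return_dict)
    names = [
        "input_ids", "attention_mask", "token_type_ids", "inputs_embeds",
        "output_attentions", "output_hidden_states", "return_dict",
    ]
    assert len(inputs) <= len(names), "Too many inputs."
    resolved = dict(defaults)
    resolved.update(dict(zip(names, inputs)))
    return resolved


print(unpack_reader_inputs(([[101, 102]], [[1, 1]], [[0, 0]]), return_dict=True))
```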
......@@ -713,9 +722,13 @@ class TFDPRReader(TFDPRPretrainedReader):
if attention_mask is None:
attention_mask = tf.ones(input_shape, dtype=tf.dtypes.int32)
if token_type_ids is None:
token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)
return self.span_predictor(
input_ids,
attention_mask,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
......
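A hedged end-to-end sketch for the reader (assumes the public `facebook/dpr-reader-single-nq-base` checkpoint, loaded from its PyTorch weights): the all-zero `token_type_ids` default added above is created automatically when the tokenizer output is passed through.

```python
from transformers import DPRReaderTokenizer, TFDPRReader

tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
model = TFDPRReader.from_pretrained("facebook/dpr-reader-single-nq-base", from_pt=True)

encoded = tokenizer(
    questions="What is love?",
    titles="Haddaway",
    texts="'What Is Love' is a song recorded by the artist Haddaway.",
    return_tensors="tf",
)
outputs = model(encoded, return_dict=True)
print(outputs.start_logits.shape, outputs.end_logits.shape, outputs.relevance_logits.shape)
```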
......@@ -120,6 +120,7 @@ class GPT2Config(PretrainedConfig):
"""
model_type = "gpt2"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
......
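For context, `keys_to_ignore_at_inference` is a plain class attribute that the Trainer's prediction step (further down in this diff) reads off the model config when no explicit `ignore_keys` is passed. A quick check after this change:

```python
from transformers import GPT2Config

config = GPT2Config()
# Used as the default ignore list when gathering predictions.
print(getattr(config, "keys_to_ignore_at_inference", []))  # ['past_key_values']
```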
......@@ -26,7 +26,10 @@ if is_tf_available():
from .modeling_tf_longformer import (
TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFLongformerForMaskedLM,
TFLongformerForMultipleChoice,
TFLongformerForQuestionAnswering,
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
TFLongformerSelfAttention,
)
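A hedged usage sketch for one of the newly exported TF heads (the base checkpoint ships no classification head, so that layer is freshly initialized; `from_pt=True` assumes the PyTorch weights and torch are available):

```python
from transformers import LongformerTokenizer, TFLongformerForSequenceClassification

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = TFLongformerForSequenceClassification.from_pretrained(
    "allenai/longformer-base-4096", from_pt=True
)

inputs = tokenizer("TF Longformer now has task-specific heads.", return_tensors="tf")
outputs = model(inputs, return_dict=True)
print(outputs.logits.shape)  # (1, config.num_labels)
```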
......@@ -31,7 +31,6 @@ from ...file_utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_outputs import MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
......@@ -151,17 +150,15 @@ class LongformerBaseModelOutputWithPooling(ModelOutput):
@dataclass
class LongformerMultipleChoiceModelOutput(ModelOutput):
class LongformerMaskedLMOutput(ModelOutput):
"""
Base class for outputs of multiple choice Longformer models.
Base class for masked language model outputs.
Args:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Masked language modeling (MLM) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -249,6 +246,149 @@ class LongformerQuestionAnsweringModelOutput(ModelOutput):
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LongformerSequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sentence classification models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
mask.
Local attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token in the sequence to every token with
global attention (first ``x`` values) and to every token in the attention window (remaining
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
attention weights. If a token has global attention, the attention weights to all other tokens in
:obj:`attentions` are set to 0; the values should be accessed from :obj:`global_attentions`.
global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
Global attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LongformerMultipleChoiceModelOutput(ModelOutput):
"""
Base class for outputs of multiple choice Longformer models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
mask.
Local attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token in the sequence to every token with
global attention (first ``x`` values) and to every token in the attention window (remaining
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
attention weights. If a token has global attention, the attention weights to all other tokens in
:obj:`attentions` are set to 0; the values should be accessed from :obj:`global_attentions`.
global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
Global attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class LongformerTokenClassifierOutput(ModelOutput):
"""
Base class for outputs of token classification models.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x + attention_window + 1)`, where ``x`` is the number of tokens with global attention
mask.
Local attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token in the sequence to every token with
global attention (first ``x`` values) and to every token in the attention window (remaining
``attention_window + 1`` values). Note that the first ``x`` values refer to tokens with fixed positions in
the text, but the remaining ``attention_window + 1`` values refer to tokens with relative positions: the
attention weight of a token to itself is located at index ``x + attention_window / 2`` and the
``attention_window / 2`` preceding (succeeding) values are the attention weights to the ``attention_window
/ 2`` preceding (succeeding) tokens. If the attention window contains a token with global attention, the
attention weight at the corresponding index is set to 0; the value should be accessed from the first ``x``
attention weights. If a token has global attention, the attention weights to all other tokens in
:obj:`attentions` are set to 0; the values should be accessed from :obj:`global_attentions`.
global_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, x)`, where ``x`` is the number of tokens with global attention mask.
Global attention weights after the attention softmax, used to compute the weighted average in the
self-attention heads. Those are the attention weights from every token with global attention to every token
in the sequence.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
global_attentions: Optional[Tuple[torch.FloatTensor]] = None
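A hedged usage sketch of the new output classes (public base checkpoint; a global-attention token is set explicitly so that `global_attentions` is actually populated):

```python
import torch
from transformers import LongformerForMaskedLM, LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("Longformer handles long documents.", return_tensors="pt")
global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1  # give the <s> token global attention

with torch.no_grad():
    outputs = model(
        **inputs,
        global_attention_mask=global_attention_mask,
        output_attentions=True,
        return_dict=True,
    )

print(type(outputs).__name__)    # LongformerMaskedLMOutput after this change
print(outputs.logits.shape)      # (batch_size, sequence_length, vocab_size)
print(outputs.global_attentions[0].shape)
```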
def _get_question_end_index(input_ids, sep_token_id):
"""
Computes the index of the first occurrence of `sep_token_id`.
......@@ -1495,7 +1635,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
return self.lm_head.decoder
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=LongformerMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
......@@ -1561,7 +1701,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
output = (prediction_scores,) + outputs[2:]
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
return MaskedLMOutput(
return LongformerMaskedLMOutput(
loss=masked_lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
......@@ -1593,7 +1733,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="allenai/longformer-base-4096",
output_type=SequenceClassifierOutput,
output_type=LongformerSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
......@@ -1651,7 +1791,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
return LongformerSequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
......@@ -1837,7 +1977,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="allenai/longformer-base-4096",
output_type=TokenClassifierOutput,
output_type=LongformerTokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
......@@ -1895,7 +2035,7 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
return LongformerTokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
......
......@@ -97,3 +97,4 @@ class MarianConfig(BartConfig):
"""
model_type = "marian"
keys_to_ignore_at_inference = ["past_key_values"]
......@@ -102,3 +102,4 @@ class MBartConfig(BartConfig):
"""
model_type = "mbart"
keys_to_ignore_at_inference = ["past_key_values"]
......@@ -62,6 +62,7 @@ class MT5Config(PretrainedConfig):
Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`.
"""
model_type = "mt5"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
......
......@@ -141,4 +141,5 @@ class PegasusConfig(BartConfig):
"""
model_type = "pegasus"
keys_to_ignore_at_inference = ["past_key_values"]
# The implementation of the config object is in BartConfig
......@@ -92,6 +92,7 @@ class ProphetNetConfig(PretrainedConfig):
smoothing is performed.
"""
model_type = "prophetnet"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
......
......@@ -153,6 +153,7 @@ class ReformerConfig(PretrainedConfig):
>>> configuration = model.config
"""
model_type = "reformer"
keys_to_ignore_at_inference = ["past_buckets_states"]
def __init__(
self,
......
......@@ -751,15 +751,15 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
super().build(input_shape)
def call(self, features):
x = self.dense(features)
x = self.act(x)
x = self.layer_norm(x)
def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.layer_norm(hidden_states)
# project back to size of vocabulary with bias
x = self.decoder(x, mode="linear") + self.bias
hidden_states = self.decoder(hidden_states, mode="linear") + self.bias
return x
return hidden_states
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
......
......@@ -71,6 +71,7 @@ class T5Config(PretrainedConfig):
the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`.
"""
model_type = "t5"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
......
......@@ -105,6 +105,7 @@ class TransfoXLConfig(PretrainedConfig):
"""
model_type = "transfo-xl"
keys_to_ignore_at_inference = ["mems"]
def __init__(
self,
......
......@@ -128,6 +128,7 @@ class XLNetConfig(PretrainedConfig):
"""
model_type = "xlnet"
keys_to_ignore_at_inference = ["mems"]
def __init__(
self,
......
......@@ -470,7 +470,7 @@ class CaptureLogger:
>>> msg = "Testing 1, 2, 3"
>>> logging.set_verbosity_info()
>>> logger = logging.get_logger("transformers.tokenization_bart")
>>> logger = logging.get_logger("transformers.models.bart.tokenization_bart")
>>> with CaptureLogger(logger) as cl:
... logger.info(msg)
>>> assert cl.out, msg+"\n"
......
......@@ -1098,10 +1098,11 @@ class Trainer:
"""
outputs = model(**inputs)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
# We don't use .loss here since the model may return tuples instead of ModelOutput.
return outputs[0]
return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
def is_local_process_zero(self) -> bool:
"""
......@@ -1220,7 +1221,9 @@ class Trainer:
logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
shutil.rmtree(checkpoint)
def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
def evaluate(
self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None
) -> Dict[str, float]:
"""
Run evaluation and returns metrics.
......@@ -1234,6 +1237,9 @@ class Trainer:
Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the
:obj:`__len__` method.
ignore_keys (:obj:`List[str]`, `optional`):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
......@@ -1250,6 +1256,7 @@ class Trainer:
# No point gathering the predictions if there are no metrics, otherwise we defer to
# self.args.prediction_loss_only
prediction_loss_only=True if self.compute_metrics is None else None,
ignore_keys=ignore_keys,
)
self.log(output.metrics)
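A hedged usage sketch of the new argument (the `trainer` and dataset objects are placeholders for an already configured `Trainer`): cache-like outputs can be dropped explicitly when they are not already listed in the model config's `keys_to_ignore_at_inference`; `predict()` and `prediction_loop()` below gain the same knob.

```python
# Assumes `trainer` is an existing transformers.Trainer and `eval_ds` / `test_ds`
# are datasets it can consume.
metrics = trainer.evaluate(eval_dataset=eval_ds, ignore_keys=["past_key_values"])
print(metrics["eval_loss"])

predictions = trainer.predict(test_ds, ignore_keys=["past_key_values"])
print(type(predictions.predictions))
```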
......@@ -1261,7 +1268,7 @@ class Trainer:
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
return output.metrics
def predict(self, test_dataset: Dataset) -> PredictionOutput:
def predict(self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None) -> PredictionOutput:
"""
Run prediction and returns predictions and potential metrics.
......@@ -1272,6 +1279,9 @@ class Trainer:
test_dataset (:obj:`Dataset`):
Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
ignore_keys (:obj:`List[str]`, `optional`):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
.. note::
......@@ -1291,10 +1301,14 @@ class Trainer:
test_dataloader = self.get_test_dataloader(test_dataset)
return self.prediction_loop(test_dataloader, description="Prediction")
return self.prediction_loop(test_dataloader, description="Prediction", ignore_keys=ignore_keys)
def prediction_loop(
self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None,
) -> PredictionOutput:
"""
Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
......@@ -1346,7 +1360,7 @@ class Trainer:
self.callback_handler.eval_dataloader = dataloader
for step, inputs in enumerate(dataloader):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
if loss is not None:
losses = loss.repeat(batch_size)
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
......@@ -1410,7 +1424,11 @@ class Trainer:
return nested_numpify(tensors)
def prediction_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on :obj:`model` using obj:`inputs`.
......@@ -1427,6 +1445,9 @@ class Trainer:
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
prediction_loss_only (:obj:`bool`):
Whether or not to return the loss only.
ignore_keys (:obj:`List[str]`, `optional`):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
Return:
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
......@@ -1434,6 +1455,11 @@ class Trainer:
"""
has_labels = all(inputs.get(k) is not None for k in self.label_names)
inputs = self._prepare_inputs(inputs)
if ignore_keys is None:
if hasattr(self.model, "config"):
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
else:
ignore_keys = []
with torch.no_grad():
if self.args.fp16 and _use_native_amp:
......@@ -1442,16 +1468,21 @@ class Trainer:
else:
outputs = model(**inputs)
if has_labels:
loss = outputs[0].mean().detach()
logits = outputs[1:]
if isinstance(outputs, dict):
loss = outputs["loss"].mean().detach()
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:
loss = outputs[0].mean().detach()
logits = outputs[1:]
else:
loss = None
# Slicing so we get a tuple even if `outputs` is a `ModelOutput`.
logits = outputs[:]
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
else:
logits = outputs
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
# Remove the past from the logits.
logits = logits[: self.args.past_index - 1] + logits[self.args.past_index :]
if prediction_loss_only:
return (loss, None, None)
......
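To make the dict handling above concrete, here is a minimal standalone sketch (not the Trainer itself; the helper name is illustrative) of how a dict output is reduced to a logits tuple once `ignore_keys` has been resolved:

```python
from typing import Dict, List, Tuple

import torch


def extract_logits(
    outputs: Dict[str, torch.Tensor], ignore_keys: List[str], has_labels: bool
) -> Tuple[torch.Tensor, ...]:
    # Mirrors the dict branch of prediction_step: drop ignored keys, and the
    # loss as well when labels were provided.
    dropped = set(ignore_keys) | ({"loss"} if has_labels else set())
    return tuple(v for k, v in outputs.items() if k not in dropped)


outputs = {
    "loss": torch.tensor(0.5),
    "logits": torch.randn(2, 3),
    "past_key_values": torch.randn(2, 4),  # stand-in for a cache we do not want to gather
}
logits = extract_logits(outputs, ignore_keys=["past_key_values"], has_labels=True)
assert len(logits) == 1 and logits[0].shape == (2, 3)
```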
......@@ -812,6 +812,15 @@ class TFLongformerForMaskedLM:
requires_tf(self)
class TFLongformerForMultipleChoice:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFLongformerForQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_tf(self)
......@@ -821,6 +830,24 @@ class TFLongformerForQuestionAnswering:
requires_tf(self)
class TFLongformerForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFLongformerForTokenClassification:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFLongformerModel:
def __init__(self, *args, **kwargs):
requires_tf(self)
......
......@@ -129,7 +129,7 @@ class LongformerModelTester:
output_without_mask = model(input_ids)["last_hidden_state"]
self.parent.assertTrue(torch.allclose(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], atol=1e-4))
def create_and_check_longformer_model(
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LongformerModel(config=config)
......@@ -141,7 +141,7 @@ class LongformerModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_longformer_model_with_global_attention_mask(
def create_and_check_model_with_global_attention_mask(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LongformerModel(config=config)
......@@ -163,7 +163,7 @@ class LongformerModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_longformer_for_masked_lm(
def create_and_check_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LongformerForMaskedLM(config=config)
......@@ -172,7 +172,7 @@ class LongformerModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_longformer_for_question_answering(
def create_and_check_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LongformerForQuestionAnswering(config=config)
......@@ -189,7 +189,7 @@ class LongformerModelTester:
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_longformer_for_sequence_classification(
def create_and_check_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
......@@ -199,7 +199,7 @@ class LongformerModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_longformer_for_token_classification(
def create_and_check_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
......@@ -209,7 +209,7 @@ class LongformerModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_longformer_for_multiple_choice(
def create_and_check_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
......@@ -296,37 +296,37 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
def test_config(self):
self.config_tester.run_common_tests()
def test_longformer_model(self):
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_model(*config_and_inputs)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_longformer_model_attention_mask_determinism(self):
def test_model_attention_mask_determinism(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)
def test_longformer_model_global_attention_mask(self):
def test_model_global_attention_mask(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
self.model_tester.create_and_check_model_with_global_attention_mask(*config_and_inputs)
def test_longformer_for_masked_lm(self):
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_longformer_for_question_answering(self):
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_sequence_classification(*config_and_inputs)
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_token_classification(*config_and_inputs)
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_multiple_choice(*config_and_inputs)
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
@require_torch
......@@ -691,7 +691,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
) # long input
input_ids = input_ids.to(torch_device)
loss, prediction_scores = model(input_ids, labels=input_ids)
loss, prediction_scores = model(input_ids, labels=input_ids).to_tuple()
expected_loss = torch.tensor(0.0074, device=torch_device)
expected_prediction_scores_sum = torch.tensor(-6.1048e08, device=torch_device)
......
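For context on the `.to_tuple()` call above: model outputs are `ModelOutput` objects rather than plain tuples, so positional unpacking needs the explicit conversion. A minimal sketch (the `ModelOutput` import path and the tiny dataclass are assumptions for illustration):

```python
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.file_utils import ModelOutput


@dataclass
class TinyOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None


out = TinyOutput(loss=torch.tensor(0.1), logits=torch.zeros(2, 3))
loss, logits = out.to_tuple()  # explicit unpacking, as the updated test does
assert torch.equal(logits, out.logits)
```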
......@@ -340,6 +340,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
self.assertTrue(layer.split("_")[0] in ["dropout", "classifier"])
@require_tf
class TFBertModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_masked_lm(self):
......