Unverified Commit 9c172564 authored by Patrick von Platen, committed by GitHub

[Longformer] Multiple choice for longformer (#4645)

* add multiple choice for longformer

* add models to docs

* adapt docstring

* add test to longformer

* add longformer for mc in init and modeling auto

* fix tests
parent 91487cbb
@@ -94,3 +94,17 @@ TFAlbertForSequenceClassification
.. autoclass:: transformers.TFAlbertForSequenceClassification
:members:
TFAlbertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFAlbertForMultipleChoice
:members:
TFAlbertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFAlbertForQuestionAnswering
:members:
@@ -74,3 +74,18 @@ LongformerForQuestionAnswering
.. autoclass:: transformers.LongformerForQuestionAnswering
:members:
LongformerForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.LongformerForMultipleChoice
:members:
LongformerForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.LongformerForTokenClassification
:members:
@@ -74,6 +74,13 @@ RobertaForSequenceClassification
:members:
RobertaForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.RobertaForMultipleChoice
:members:
RobertaForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
@@ -326,6 +326,7 @@ if is_torch_available():
LongformerModel,
LongformerForMaskedLM,
LongformerForSequenceClassification,
LongformerForMultipleChoice,
LongformerForTokenClassification,
LongformerForQuestionAnswering,
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
......
@@ -104,6 +104,7 @@ from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, G
from .modeling_longformer import (
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
LongformerForMaskedLM,
LongformerForMultipleChoice,
LongformerForQuestionAnswering,
LongformerForSequenceClassification,
LongformerForTokenClassification,
@@ -297,6 +298,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
[
(CamembertConfig, CamembertForMultipleChoice),
(XLMRobertaConfig, XLMRobertaForMultipleChoice),
(LongformerConfig, LongformerForMultipleChoice),
(RobertaConfig, RobertaForMultipleChoice),
(BertConfig, BertForMultipleChoice),
(XLNetConfig, XLNetForMultipleChoice),
......
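The mapping entry above is what lets the auto machinery resolve a Longformer config to the new multiple-choice head. A minimal sketch of that lookup, using only names visible in this diff (the default LongformerConfig here is just for illustration, not a trained checkpoint):

```python
# Sketch: resolve the multiple-choice head class from a Longformer config
# via the mapping added above (randomly initialized model, illustration only).
from transformers import LongformerConfig
from transformers.modeling_auto import MODEL_FOR_MULTIPLE_CHOICE_MAPPING

config = LongformerConfig()
model_class = MODEL_FOR_MULTIPLE_CHOICE_MAPPING[type(config)]
print(model_class.__name__)   # LongformerForMultipleChoice
model = model_class(config)   # builds the head with random weights
```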
@@ -543,7 +543,7 @@ BERT_START_DOCSTRING = r"""
BERT_INPUTS_DOCSTRING = r"""
Args:
- input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`transformers.BertTokenizer`.
@@ -551,19 +551,19 @@ BERT_INPUTS_DOCSTRING = r"""
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
`What are input IDs? <../glossary.html#input-ids>`__
- attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
`What are attention masks? <../glossary.html#attention-mask>`__
- token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
`What are token type IDs? <../glossary.html#token-type-ids>`_
- position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
@@ -632,7 +632,7 @@ class BertModel(BertPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -759,7 +759,7 @@ class BertForPreTraining(BertPreTrainedModel):
def get_output_embeddings(self):
return self.cls.predictions.decoder
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -859,7 +859,7 @@ class BertForMaskedLM(BertPreTrainedModel):
def get_output_embeddings(self):
return self.cls.predictions.decoder
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -992,7 +992,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
self.init_weights()
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -1081,7 +1081,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
self.init_weights()
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -1177,7 +1177,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
self.init_weights()
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -1278,7 +1278,7 @@ class BertForTokenClassification(BertPreTrainedModel):
self.init_weights()
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -1375,7 +1375,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
self.init_weights()
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
......
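All of the `.format(...)` edits above follow the same pattern: the shared `*_INPUTS_DOCSTRING` now carries a `{0}` placeholder for the input shape, and each head fills it in when the docstring is attached, so multiple-choice heads can document the extra `num_choices` dimension. A toy illustration of the templating (the docstring text below is a shortened stand-in, not the real library string):

```python
# Toy version of the shape-templated docstring used above.
INPUTS_DOCSTRING = r"""
    input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
        Indices of input sequence tokens in the vocabulary.
"""

# Standard heads document 2-D inputs ...
print(INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
# ... while multiple-choice heads document 3-D inputs.
print(INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
```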
@@ -411,7 +411,7 @@ LONGFORMER_START_DOCSTRING = r"""
LONGFORMER_INPUTS_DOCSTRING = r"""
Args:
- input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`transformers.LongformerTokenizer`.
@@ -419,7 +419,7 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
`What are input IDs? <../glossary.html#input-ids>`__
- attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Mask to decide the attention given to each token: local attention, global attention, or no attention (for padding tokens).
Tokens with global attention attend to all other tokens, and all other tokens attend to them. This is important for
task-specific finetuning because it makes the model more flexible at representing the task. For example,
@@ -431,13 +431,13 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
``2`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
`What are attention masks? <../glossary.html#attention-mask>`__
- token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
`What are token type IDs? <../glossary.html#token-type-ids>`_
- position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
@@ -537,7 +537,7 @@ class LongformerModel(RobertaModel):
return padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds
- @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -641,7 +641,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
self.init_weights()
- @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -729,7 +729,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
self.longformer = LongformerModel(config)
self.classifier = LongformerClassificationHead(config)
- @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -866,7 +866,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]
- @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids,
@@ -993,7 +993,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
self.init_weights()
- @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -1070,3 +1070,100 @@ class LongformerForTokenClassification(BertPreTrainedModel):
outputs = (loss,) + outputs
return outputs # (loss), scores, (hidden_states), (attentions)
@add_start_docstrings(
"""Longformer Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
LONGFORMER_START_DOCSTRING,
)
class LongformerForMultipleChoice(BertPreTrainedModel):
config_class = LongformerConfig
pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "longformer"
def __init__(self, config):
super().__init__(config)
self.longformer = LongformerModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
self.init_weights()
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
def forward(
self,
input_ids=None,
token_type_ids=None,
attention_mask=None,
labels=None,
position_ids=None,
inputs_embeds=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
Examples::
from transformers import LongformerTokenizer, LongformerForMultipleChoice
import torch
tokenizer = LongformerTokenizer.from_pretrained('longformer-base-4096')
model = LongformerForMultipleChoice.from_pretrained('longformer-base-4096')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
labels = torch.tensor(1).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
loss, classification_scores = outputs[:2]
"""
num_choices = input_ids.shape[1]
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
outputs = self.longformer(
flat_input_ids,
position_ids=flat_position_ids,
token_type_ids=flat_token_type_ids,
attention_mask=flat_attention_mask,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
outputs = (loss,) + outputs
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
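The forward pass above follows the standard multiple-choice recipe: flatten the choices into the batch dimension, encode, score each choice with a single linear unit, and reshape back to `(batch_size, num_choices)` for a cross-entropy over choices. A shape-only sketch with random tensors (the dimensions and the stand-in pooled output are made up for illustration, no real checkpoint is loaded):

```python
import torch
import torch.nn as nn

batch_size, num_choices, seq_len, hidden_size = 2, 4, 16, 8
input_ids = torch.randint(0, 100, (batch_size, num_choices, seq_len))

# Choices are folded into the batch dimension before the encoder call.
flat_input_ids = input_ids.view(-1, input_ids.size(-1))        # (batch*choices, seq_len)

# Stand-in for the encoder's pooled output (outputs[1] in the code above).
pooled_output = torch.randn(flat_input_ids.size(0), hidden_size)

classifier = nn.Linear(hidden_size, 1)                          # one score per choice
logits = classifier(pooled_output)                              # (batch*choices, 1)
reshaped_logits = logits.view(-1, num_choices)                  # (batch, num_choices)

labels = torch.tensor([1, 3])                                   # gold choice per example
loss = nn.CrossEntropyLoss()(reshaped_logits, labels)
print(reshaped_logits.shape, loss.item())
```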
@@ -95,7 +95,7 @@ ROBERTA_START_DOCSTRING = r"""
ROBERTA_INPUTS_DOCSTRING = r"""
Args:
- input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`transformers.RobertaTokenizer`.
@@ -103,19 +103,19 @@ ROBERTA_INPUTS_DOCSTRING = r"""
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
`What are input IDs? <../glossary.html#input-ids>`__
- attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
`What are attention masks? <../glossary.html#attention-mask>`__
- token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
`What are token type IDs? <../glossary.html#token-type-ids>`_
- position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
@@ -175,7 +175,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head.decoder
- @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -286,7 +286,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
self.roberta = RobertaModel(config)
self.classifier = RobertaClassificationHead(config)
- @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -379,7 +379,7 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
self.init_weights()
- @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -479,7 +479,7 @@ class RobertaForTokenClassification(BertPreTrainedModel):
self.init_weights()
- @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -598,7 +598,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
self.init_weights()
- @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids,
......
@@ -628,7 +628,7 @@ ALBERT_START_DOCSTRING = r"""
ALBERT_INPUTS_DOCSTRING = r"""
Args:
- input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
+ input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`transformers.AlbertTokenizer`.
@@ -636,19 +636,19 @@ ALBERT_INPUTS_DOCSTRING = r"""
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
`What are input IDs? <../glossary.html#input-ids>`__
- attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
`What are attention masks? <../glossary.html#attention-mask>`__
- token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
`What are token type IDs? <../glossary.html#token-type-ids>`_
- position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
@@ -676,7 +676,7 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
super().__init__(config, *inputs, **kwargs)
self.albert = TFAlbertMainLayer(config, name="albert")
- @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
r"""
Returns:
@@ -734,7 +734,7 @@ class TFAlbertForPreTraining(TFAlbertPreTrainedModel):
def get_output_embeddings(self):
return self.albert.embeddings
- @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
r"""
Return:
@@ -795,7 +795,7 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel):
def get_output_embeddings(self):
return self.albert.embeddings
- @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
r"""
Returns:
@@ -852,7 +852,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel):
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
- @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
r"""
Returns:
@@ -908,7 +908,7 @@ class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel):
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
- @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
r"""
Return:
@@ -983,7 +983,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel):
"""
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
- @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(ALBERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
def call(
self,
inputs,
......
@@ -621,7 +621,7 @@ BERT_START_DOCSTRING = r"""
BERT_INPUTS_DOCSTRING = r"""
Args:
- input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
+ input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`transformers.BertTokenizer`.
@@ -629,19 +629,19 @@ BERT_INPUTS_DOCSTRING = r"""
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
`What are input IDs? <../glossary.html#input-ids>`__
- attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
`What are attention masks? <../glossary.html#attention-mask>`__
- token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
`What are token type IDs? <../glossary.html#token-type-ids>`__
- position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
@@ -669,7 +669,7 @@ class TFBertModel(TFBertPreTrainedModel):
super().__init__(config, *inputs, **kwargs)
self.bert = TFBertMainLayer(config, name="bert")
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
r"""
Returns:
@@ -726,7 +726,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
def get_output_embeddings(self):
return self.bert.embeddings
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
r"""
Return:
@@ -782,7 +782,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
def get_output_embeddings(self):
return self.bert.embeddings
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
r"""
Return:
@@ -832,7 +832,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
self.bert = TFBertMainLayer(config, name="bert")
self.nsp = TFBertNSPHead(config, name="nsp___cls")
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
r"""
Return:
@@ -888,7 +888,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
r"""
Return:
@@ -954,7 +954,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
"""
return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
def call(
self,
inputs,
@@ -1065,7 +1065,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
)
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
r"""
Return:
@@ -1122,7 +1122,7 @@ class TFBertForQuestionAnswering(TFBertPreTrainedModel):
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
)
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(self, inputs, **kwargs):
r"""
Return:
......
@@ -506,7 +506,7 @@ XLNET_START_DOCSTRING = r"""
XLNET_INPUTS_DOCSTRING = r"""
Args:
- input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`transformers.BertTokenizer`.
@@ -514,7 +514,7 @@ XLNET_INPUTS_DOCSTRING = r"""
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
`What are input IDs? <../glossary.html#input-ids>`__
- attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
@@ -535,13 +535,13 @@ XLNET_INPUTS_DOCSTRING = r"""
Mask to indicate the output tokens to use.
If ``target_mapping[k, i, j] = 1``, the i-th predict in batch k is on the j-th token.
Only used during pretraining for partial prediction or for sequential decoding (generation).
- token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token. The classifier token should be represented by a ``2``.
`What are token type IDs? <../glossary.html#token-type-ids>`_
- input_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
+ input_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Negative of `attention_mask`, i.e. with 0 for real tokens and 1 for padding.
Kept for compatibility with the original code base.
@@ -688,7 +688,7 @@ class XLNetModel(XLNetPreTrainedModel):
pos_emb = pos_emb.to(self.device)
return pos_emb
- @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -971,7 +971,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
return inputs
- @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -1091,7 +1091,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
self.init_weights()
- @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -1196,7 +1196,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
self.init_weights()
- @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -1305,7 +1305,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
self.init_weights()
- @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -1418,7 +1418,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
self.init_weights()
- @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
@@ -1544,7 +1544,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
self.init_weights()
- @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
+ @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def forward(
self,
input_ids=None,
......
@@ -32,6 +32,7 @@ if is_torch_available():
LongformerForSequenceClassification,
LongformerForTokenClassification,
LongformerForQuestionAnswering,
LongformerForMultipleChoice,
)
@@ -228,6 +229,29 @@ class LongformerModelTester(object):
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result)
def create_and_check_longformer_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
model = LongformerForMultipleChoice(config=config)
model.to(torch_device)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels,
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -298,6 +322,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_token_classification(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_multiple_choice(*config_and_inputs)
class LongformerModelIntegrationTest(unittest.TestCase):
@slow
......