Commit a75c64d8 authored by Lysandre

Black 20 release

parent e78c1103
@@ -329,7 +329,13 @@ class TFGenerationMixin:
         if self.config.is_encoder_decoder:
             # create empty decoder_input_ids
-            input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id
+            input_ids = (
+                tf.ones(
+                    (effective_batch_size * num_beams, 1),
+                    dtype=tf.int32,
+                )
+                * decoder_start_token_id
+            )
             cur_len = 1
             assert (
@@ -422,8 +428,8 @@ class TFGenerationMixin:
         attention_mask,
         use_cache,
     ):
-        """ Generate sequences for each example without beam search (num_beams == 1).
+        """Generate sequences for each example without beam search (num_beams == 1).
         All returned sequence are generated independantly.
         """
         # length of generated sentences / unfinished sentences
@@ -587,8 +593,7 @@ class TFGenerationMixin:
         attention_mask,
         use_cache,
     ):
-        """ Generate sequences for each example with beam search.
-        """
+        """Generate sequences for each example with beam search."""
         # generated hypotheses
         generated_hyps = [
@@ -960,14 +965,14 @@ def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
 def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
-    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
     Args:
         logits: logits distribution shape (batch size, vocabulary size)
         if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
         if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
             Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
         Make sure we keep at least min_tokens_to_keep per batch example in the output
     From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
     """
     logits_shape = shape_list(logits)
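As a usage note for the filter documented above: entries removed by the filter are set to `filter_value`, so the returned logits can be sampled from directly. A rough sketch of a single TF sampling step, assuming `tf_top_k_top_p_filtering` is importable from the top-level `transformers` package and `next_token_logits` is a `(batch_size, vocab_size)` tensor:

    import tensorflow as tf
    from transformers import tf_top_k_top_p_filtering

    next_token_logits = tf.random.normal((2, 50257))  # stand-in for model outputs
    filtered_logits = tf_top_k_top_p_filtering(next_token_logits, top_k=50, top_p=0.9)
    next_tokens = tf.random.categorical(filtered_logits, num_samples=1, dtype=tf.int32)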
@@ -1001,7 +1006,8 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
         # Shift the indices to the right to keep also the first token above the threshold
         sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
         sorted_indices_to_remove = tf.concat(
-            [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]], -1,
+            [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]],
+            -1,
         )
         # scatter sorted tensors to original indexing
         indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
@@ -1027,9 +1033,9 @@ def set_tensor_by_indices_to_value(tensor, indices, value):
 def sample_without_replacement(logits, num_samples):
     """
     categorical sampling witouth replacement is currently not implemented
     the gumbel-max trick will do for now
     see https://github.com/tensorflow/tensorflow/issues/9260 for more info
     """
     z = -tf.math.log(tf.random.uniform(shape_list(logits), 0, 1))
     _, indices = tf.nn.top_k(logits + z, num_samples)
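The Gumbel-max trick mentioned in this docstring works because adding independent Gumbel(0, 1) noise to the logits and taking the indices of the top-k perturbed scores yields k distinct samples from the categorical distribution. A self-contained sketch of the textbook formulation (toy values, not the helper shown above):

    import tensorflow as tf

    logits = tf.math.log(tf.constant([[0.6, 0.3, 0.1]]))               # (batch, vocab)
    uniform = tf.random.uniform(tf.shape(logits), minval=0, maxval=1)
    gumbel = -tf.math.log(-tf.math.log(uniform))                        # Gumbel(0, 1) noise
    _, sampled_ids = tf.nn.top_k(logits + gumbel, k=2)                  # 2 distinct ids per row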
...
@@ -83,7 +83,11 @@ class GenerationMixin:
         # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
         if repetition_penalty != 1.0:
             self.enforce_repetition_penalty_(
-                scores, batch_size, num_beams, input_ids, repetition_penalty,
+                scores,
+                batch_size,
+                num_beams,
+                input_ids,
+                repetition_penalty,
             )
         # set eos token prob to zero if min_length is not reached
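The CTRL-style penalty applied by `enforce_repetition_penalty_` lowers the score of every token that already appears in the generated sequence: positive logits are divided by the penalty and negative logits are multiplied by it. A stand-alone sketch of that rule (hypothetical helper name, not the library implementation):

    import torch

    def penalize_repeats(scores: torch.Tensor, prev_ids: torch.Tensor, penalty: float) -> None:
        # scores: (batch, vocab) next-token logits, modified in place
        # prev_ids: (batch, seq_len) tokens generated so far
        for i in range(scores.size(0)):
            for token_id in set(prev_ids[i].tolist()):
                if scores[i, token_id] < 0:
                    scores[i, token_id] *= penalty  # push negative scores further down
                else:
                    scores[i, token_id] /= penalty  # shrink positive scores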
@@ -324,7 +328,10 @@ class GenerationMixin:
                 "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
             )
             input_ids = torch.full(
-                (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
+                (batch_size, 1),
+                bos_token_id,
+                dtype=torch.long,
+                device=next(self.parameters()).device,
             )
         else:
             assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
@@ -514,8 +521,8 @@ class GenerationMixin:
         use_cache,
         model_specific_kwargs,
     ):
-        """ Generate sequences for each example without beam search (num_beams == 1).
+        """Generate sequences for each example without beam search (num_beams == 1).
         All returned sequence are generated independantly.
         """
         # length of generated sentences / unfinished sentences
         unfinished_sents = input_ids.new(batch_size).fill_(1)
@@ -619,8 +626,7 @@ class GenerationMixin:
         use_cache,
         model_specific_kwargs,
     ):
-        """ Generate sequences for each example with beam search.
-        """
+        """Generate sequences for each example with beam search."""
         # generated hypotheses
         generated_hyps = [
@@ -749,7 +755,8 @@ class GenerationMixin:
                         if is_beam_token_worse_than_top_num_beams:
                             continue
                         generated_hyps[batch_idx].add(
-                            input_ids[effective_beam_id].clone(), beam_token_score.item(),
+                            input_ids[effective_beam_id].clone(),
+                            beam_token_score.item(),
                         )
                     else:
                         # add next predicted token since it is not eos_token
@@ -806,7 +813,8 @@ class GenerationMixin:
                 assert torch.all(
                     next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
                 ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
-                    next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
+                    next_scores[:, :num_beams][batch_idx],
+                    beam_scores.view(batch_size, num_beams)[batch_idx],
                 )
             # need to add best num_beams hypotheses to generated hyps
@@ -916,7 +924,7 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
 def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
-    """ Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
+    """Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
     a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...]
     Args:
         scores: logits distribution of shape (batch size, vocabulary size)
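A toy illustration of the contract this docstring describes, with made-up indices (the library function may implement the masking differently, for example without an explicit Python loop):

    import torch

    scores = torch.zeros(2, 5)                        # (batch size, vocabulary size)
    banned_tokens = [[0, 3], [1, 1]]                  # [[batch index, vocabulary position], ...]
    for batch_idx, vocab_idx in banned_tokens:
        scores[batch_idx, vocab_idx] = -float("inf")  # banned positions can never be sampled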
@@ -946,14 +954,14 @@ def top_k_top_p_filtering(
     filter_value: float = -float("Inf"),
     min_tokens_to_keep: int = 1,
 ) -> Tensor:
-    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
     Args:
         logits: logits distribution shape (batch size, vocabulary size)
         if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
         if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
             Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
         Make sure we keep at least min_tokens_to_keep per batch example in the output
     From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
     """
     if top_k > 0:
         top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
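This is the PyTorch twin of the TensorFlow filter above. A rough sketch of where it sits in a sampling loop, assuming `top_k_top_p_filtering` is exported from the top-level `transformers` package:

    import torch
    import torch.nn.functional as F
    from transformers import top_k_top_p_filtering

    next_token_logits = torch.randn(1, 50257)             # stand-in for model outputs
    next_token_logits = next_token_logits / 0.7           # temperature
    filtered = top_k_top_p_filtering(next_token_logits, top_k=50, top_p=0.95)
    probs = F.softmax(filtered, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)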
...
@@ -36,20 +36,20 @@ logger = logging.get_logger(__name__)
 class ModelCard:
-    r""" Structured Model Card class.
+    r"""Structured Model Card class.
     Store model card as well as methods for loading/downloading/saving model cards.
     Please read the following paper for details and explanation on the sections:
     "Model Cards for Model Reporting"
     by Margaret Mitchell, Simone Wu,
     Andrew Zaldivar, Parker Barnes, Lucy Vasserman, Ben Hutchinson, Elena Spitzer,
     Inioluwa Deborah Raji and Timnit Gebru for the proposal behind model cards.
     Link: https://arxiv.org/abs/1810.03993
     Note:
         A model card can be loaded and saved to disk.
     Parameters:
     """
     def __init__(self, **kwargs):
@@ -73,8 +73,7 @@ class ModelCard:
             raise err
     def save_pretrained(self, save_directory_or_file):
-        """ Save a model card object to the directory or file `save_directory_or_file`.
-        """
+        """Save a model card object to the directory or file `save_directory_or_file`."""
         if os.path.isdir(save_directory_or_file):
             # If we save using the predefined names, we can load using `from_pretrained`
             output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME)
@@ -86,7 +85,7 @@ class ModelCard:
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
-        r""" Instantiate a :class:`~transformers.ModelCard` from a pre-trained model model card.
+        r"""Instantiate a :class:`~transformers.ModelCard` from a pre-trained model model card.
         Parameters:
             pretrained_model_name_or_path: either:
...
@@ -302,7 +302,10 @@ class AlbertLayer(nn.Module):
         attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
         ffn_output = apply_chunking_to_forward(
-            self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output[0],
+            self.ff_chunk,
+            self.chunk_size_feed_forward,
+            self.seq_len_dim,
+            attention_output[0],
         )
         hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])
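`apply_chunking_to_forward` trades peak memory for extra passes: it slices its input along the sequence dimension, applies the feed-forward chunk to each slice, and concatenates the results. A rough stand-alone sketch of the idea (simplified signature, not the library helper):

    import torch

    def chunked_forward(forward_fn, chunk_size, seq_dim, tensor):
        if chunk_size == 0:                             # chunking disabled: single full pass
            return forward_fn(tensor)
        chunks = tensor.split(chunk_size, dim=seq_dim)  # slice along the sequence dimension
        return torch.cat([forward_fn(c) for c in chunks], dim=seq_dim)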
@@ -397,8 +400,8 @@ class AlbertTransformer(nn.Module):
 class AlbertPreTrainedModel(PreTrainedModel):
-    """ An abstract class to handle weights initialization and
+    """An abstract class to handle weights initialization and
     a simple interface for downloading and loading pretrained models.
     """
     config_class = AlbertConfig
@@ -406,8 +409,7 @@ class AlbertPreTrainedModel(PreTrainedModel):
     authorized_missing_keys = [r"position_ids"]
     def _init_weights(self, module):
-        """ Initialize the weights.
-        """
+        """Initialize the weights."""
         if isinstance(module, (nn.Linear, nn.Embedding)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
@@ -543,17 +545,17 @@ class AlbertModel(AlbertPreTrainedModel):
         return self.embeddings.word_embeddings
     def _prune_heads(self, heads_to_prune):
-        """ Prunes heads of the model.
+        """Prunes heads of the model.
         heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
         ALBERT has a different architecture in that its layers are shared across groups, which then has inner groups.
         If an ALBERT model has 12 hidden layers and 2 hidden groups, with two inner groups, there
         is a total of 4 different layers.
         These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden layer,
         while [2,3] correspond to the two inner groups of the second hidden layer.
         Any layer with in index other than [0,1,2,3] will result in an error.
         See base class PreTrainedModel for more information about head pruning
         """
         for layer, heads in heads_to_prune.items():
             group_idx = int(layer / self.config.inner_group_num)
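A small worked example of the flattened-index arithmetic used in the loop above, with `inner_group_num = 2` as in the docstring's scenario:

    inner_group_num = 2
    for layer in [0, 1, 2, 3]:
        group_idx = layer // inner_group_num        # which shared parameter group
        inner_group_idx = layer % inner_group_num   # which inner layer inside that group
        print(layer, "->", (group_idx, inner_group_idx))
    # prints: 0 -> (0, 0)   1 -> (0, 1)   2 -> (1, 0)   3 -> (1, 1)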
@@ -672,34 +674,34 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
         **kwargs,
     ):
         r"""
         labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
             Labels for computing the masked language modeling loss.
             Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
             Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
             in ``[0, ..., config.vocab_size]``
         sentence_order_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
             Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring)
             Indices should be in ``[0, 1]``.
             ``0`` indicates original order (sequence A, then sequence B),
             ``1`` indicates switched order (sequence B, then sequence A).
         kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
             Used to hide legacy arguments that have been deprecated.
         Returns:
         Examples::
             >>> from transformers import AlbertTokenizer, AlbertForPreTraining
             >>> import torch
             >>> tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
             >>> model = AlbertForPreTraining.from_pretrained('albert-base-v2', return_dict=True)
             >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
             >>> outputs = model(input_ids)
             >>> prediction_logits = outputs.prediction_logits
             >>> sop_logits = outputs.sop_logits
         """
@@ -787,7 +789,8 @@ class AlbertSOPHead(nn.Module):
 @add_start_docstrings(
-    "Albert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING,
+    "Albert Model with a `language modeling` head on top.",
+    ALBERT_START_DOCSTRING,
 )
 class AlbertForMaskedLM(AlbertPreTrainedModel):
     def __init__(self, config):
@@ -952,7 +955,10 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
             return ((loss,) + output) if loss is not None else output
         return SequenceClassifierOutput(
-            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
@@ -1033,7 +1039,10 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
             return ((loss,) + output) if loss is not None else output
         return TokenClassifierOutput(
-            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
@@ -1215,5 +1224,8 @@ class AlbertForMultipleChoice(AlbertPreTrainedModel):
             return ((loss,) + output) if loss is not None else output
         return MultipleChoiceModelOutput(
-            loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
+            loss=loss,
+            logits=reshaped_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
@@ -374,12 +374,12 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
 class AutoModel:
     r"""
     :class:`~transformers.AutoModel` is a generic model class
     that will be instantiated as one of the base model classes of the library
     when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
     or the `AutoModel.from_config(config)` class methods.
     This class cannot be instantiated using `__init__()` (throws an error).
     """
     def __init__(self):
@@ -391,7 +391,7 @@ class AutoModel:
     @classmethod
     def from_config(cls, config):
-        r""" Instantiates one of the base model classes of the library
+        r"""Instantiates one of the base model classes of the library
         from a configuration.
         Note:
@@ -433,7 +433,7 @@ class AutoModel:
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r""" Instantiates one of the base model classes of the library
+        r"""Instantiates one of the base model classes of the library
         from a pre-trained model configuration.
         The `from_pretrained()` method takes care of returning the correct model class instance
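The sentence above is the crux of the whole `Auto*` family: instantiation dispatches on the configuration's `model_type`. A short usage sketch (checkpoint name chosen purely for illustration):

    from transformers import AutoConfig, AutoModel

    model = AutoModel.from_pretrained("bert-base-uncased")   # resolves to BertModel
    config = AutoConfig.from_pretrained("bert-base-uncased")
    untrained = AutoModel.from_config(config)                # same architecture, fresh weights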
@@ -531,11 +531,11 @@ class AutoModel:
 class AutoModelForPreTraining:
     r"""
     :class:`~transformers.AutoModelForPreTraining` is a generic model class
     that will be instantiated as one of the model classes of the library -with the architecture used for pretraining this model– when created with the `AutoModelForPreTraining.from_pretrained(pretrained_model_name_or_path)`
     class method.
     This class cannot be instantiated using `__init__()` (throws an error).
     """
     def __init__(self):
@@ -547,7 +547,7 @@ class AutoModelForPreTraining:
     @classmethod
     def from_config(cls, config):
-        r""" Instantiates one of the base model classes of the library
+        r"""Instantiates one of the base model classes of the library
         from a configuration.
         Note:
@@ -589,7 +589,7 @@ class AutoModelForPreTraining:
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r""" Instantiates one of the model classes of the library -with the architecture used for pretraining this model– from a pre-trained model configuration.
+        r"""Instantiates one of the model classes of the library -with the architecture used for pretraining this model– from a pre-trained model configuration.
         The `from_pretrained()` method takes care of returning the correct model class instance
         based on the `model_type` property of the config object, or when it's missing,
@@ -680,12 +680,12 @@ class AutoModelForPreTraining:
 class AutoModelWithLMHead:
     r"""
     :class:`~transformers.AutoModelWithLMHead` is a generic model class
     that will be instantiated as one of the language modeling model classes of the library
     when created with the `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)`
     class method.
     This class cannot be instantiated using `__init__()` (throws an error).
     """
     def __init__(self):
@@ -697,7 +697,7 @@ class AutoModelWithLMHead:
     @classmethod
     def from_config(cls, config):
-        r""" Instantiates one of the base model classes of the library
+        r"""Instantiates one of the base model classes of the library
         from a configuration.
         Note:
@@ -743,7 +743,7 @@ class AutoModelWithLMHead:
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r""" Instantiates one of the language modeling model classes of the library
+        r"""Instantiates one of the language modeling model classes of the library
         from a pre-trained model configuration.
         The `from_pretrained()` method takes care of returning the correct model class instance
@@ -839,12 +839,12 @@ class AutoModelWithLMHead:
 class AutoModelForCausalLM:
     r"""
     :class:`~transformers.AutoModelForCausalLM` is a generic model class
     that will be instantiated as one of the language modeling model classes of the library
     when created with the `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)`
     class method.
     This class cannot be instantiated using `__init__()` (throws an error).
     """
     def __init__(self):
@@ -856,7 +856,7 @@ class AutoModelForCausalLM:
     @classmethod
     def from_config(cls, config):
-        r""" Instantiates one of the base model classes of the library
+        r"""Instantiates one of the base model classes of the library
         from a configuration.
         Note:
@@ -893,7 +893,7 @@ class AutoModelForCausalLM:
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r""" Instantiates one of the language modeling model classes of the library
+        r"""Instantiates one of the language modeling model classes of the library
         from a pre-trained model configuration.
         The `from_pretrained()` method takes care of returning the correct model class instance
@@ -976,12 +976,12 @@ class AutoModelForCausalLM:
 class AutoModelForMaskedLM:
     r"""
     :class:`~transformers.AutoModelForMaskedLM` is a generic model class
     that will be instantiated as one of the language modeling model classes of the library
     when created with the `AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)`
     class method.
     This class cannot be instantiated using `__init__()` (throws an error).
     """
     def __init__(self):
@@ -993,7 +993,7 @@ class AutoModelForMaskedLM:
     @classmethod
     def from_config(cls, config):
-        r""" Instantiates one of the base model classes of the library
+        r"""Instantiates one of the base model classes of the library
         from a configuration.
         Note:
@@ -1033,7 +1033,7 @@ class AutoModelForMaskedLM:
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r""" Instantiates one of the language modeling model classes of the library
+        r"""Instantiates one of the language modeling model classes of the library
         from a pre-trained model configuration.
         The `from_pretrained()` method takes care of returning the correct model class instance
@@ -1119,12 +1119,12 @@ class AutoModelForMaskedLM:
 class AutoModelForSeq2SeqLM:
     r"""
     :class:`~transformers.AutoModelForSeq2SeqLM` is a generic model class
     that will be instantiated as one of the language modeling model classes of the library
     when created with the `AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path)`
     class method.
     This class cannot be instantiated using `__init__()` (throws an error).
     """
     def __init__(self):
@@ -1136,7 +1136,7 @@ class AutoModelForSeq2SeqLM:
     @classmethod
     def from_config(cls, config):
-        r""" Instantiates one of the base model classes of the library
+        r"""Instantiates one of the base model classes of the library
         from a configuration.
         Note:
@@ -1172,7 +1172,7 @@ class AutoModelForSeq2SeqLM:
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r""" Instantiates one of the language modeling model classes of the library
+        r"""Instantiates one of the language modeling model classes of the library
         from a pre-trained model configuration.
         The `from_pretrained()` method takes care of returning the correct model class instance
@@ -1254,12 +1254,12 @@ class AutoModelForSeq2SeqLM:
 class AutoModelForSequenceClassification:
     r"""
     :class:`~transformers.AutoModelForSequenceClassification` is a generic model class
     that will be instantiated as one of the sequence classification model classes of the library
     when created with the `AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)`
     class method.
     This class cannot be instantiated using `__init__()` (throws an error).
     """
     def __init__(self):
@@ -1271,7 +1271,7 @@ class AutoModelForSequenceClassification:
     @classmethod
     def from_config(cls, config):
-        r""" Instantiates one of the base model classes of the library
+        r"""Instantiates one of the base model classes of the library
         from a configuration.
         Note:
@@ -1313,7 +1313,7 @@ class AutoModelForSequenceClassification:
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r""" Instantiates one of the sequence classification model classes of the library
+        r"""Instantiates one of the sequence classification model classes of the library
         from a pre-trained model configuration.
         The `from_pretrained()` method takes care of returning the correct model class instance
@@ -1406,12 +1406,12 @@ class AutoModelForSequenceClassification:
 class AutoModelForQuestionAnswering:
     r"""
     :class:`~transformers.AutoModelForQuestionAnswering` is a generic model class
     that will be instantiated as one of the question answering model classes of the library
     when created with the `AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)`
     class method.
     This class cannot be instantiated using `__init__()` (throws an error).
     """
     def __init__(self):
@@ -1423,7 +1423,7 @@ class AutoModelForQuestionAnswering:
     @classmethod
     def from_config(cls, config):
-        r""" Instantiates one of the base model classes of the library
+        r"""Instantiates one of the base model classes of the library
         from a configuration.
         Note:
@@ -1462,7 +1462,7 @@ class AutoModelForQuestionAnswering:
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r""" Instantiates one of the question answering model classes of the library
+        r"""Instantiates one of the question answering model classes of the library
         from a pre-trained model configuration.
         The `from_pretrained()` method takes care of returning the correct model class instance
@@ -1551,12 +1551,12 @@ class AutoModelForQuestionAnswering:
 class AutoModelForTokenClassification:
     r"""
     :class:`~transformers.AutoModelForTokenClassification` is a generic model class
     that will be instantiated as one of the token classification model classes of the library
     when created with the `AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path)`
     class method.
     This class cannot be instantiated using `__init__()` (throws an error).
     """
     def __init__(self):
@@ -1568,7 +1568,7 @@ class AutoModelForTokenClassification:
     @classmethod
     def from_config(cls, config):
-        r""" Instantiates one of the base model classes of the library
+        r"""Instantiates one of the base model classes of the library
         from a configuration.
         Note:
@@ -1611,7 +1611,7 @@ class AutoModelForTokenClassification:
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r""" Instantiates one of the question answering model classes of the library
+        r"""Instantiates one of the question answering model classes of the library
         from a pre-trained model configuration.
         The `from_pretrained()` method takes care of returning the correct model class instance
@@ -1703,12 +1703,12 @@ class AutoModelForTokenClassification:
 class AutoModelForMultipleChoice:
     r"""
     :class:`~transformers.AutoModelForMultipleChoice` is a generic model class
     that will be instantiated as one of the multiple choice model classes of the library
     when created with the `AutoModelForMultipleChoice.from_pretrained(pretrained_model_name_or_path)`
     class method.
     This class cannot be instantiated using `__init__()` (throws an error).
     """
     def __init__(self):
...
@@ -223,7 +223,9 @@ class EncoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.d_model
         self.self_attn = SelfAttention(
-            self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout,
+            self.embed_dim,
+            config.encoder_attention_heads,
+            dropout=config.attention_dropout,
         )
         self.normalize_before = config.normalize_before
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
@@ -297,7 +299,10 @@ class BartEncoder(nn.Module):
             )
         else:
             self.embed_positions = LearnedPositionalEmbedding(
-                config.max_position_embeddings, embed_dim, self.padding_idx, config.extra_pos_embeddings,
+                config.max_position_embeddings,
+                embed_dim,
+                self.padding_idx,
+                config.extra_pos_embeddings,
             )
         self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
@@ -370,7 +375,9 @@ class DecoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.d_model
         self.self_attn = SelfAttention(
-            embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
+            embed_dim=self.embed_dim,
+            num_heads=config.decoder_attention_heads,
+            dropout=config.attention_dropout,
         )
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
@@ -477,7 +484,10 @@ class BartDecoder(nn.Module):
             )
         else:
             self.embed_positions = LearnedPositionalEmbedding(
-                config.max_position_embeddings, config.d_model, self.padding_idx, config.extra_pos_embeddings,
+                config.max_position_embeddings,
+                config.d_model,
+                self.padding_idx,
+                config.extra_pos_embeddings,
             )
         self.layers = nn.ModuleList(
             [DecoderLayer(config) for _ in range(config.decoder_layers)]
@@ -695,7 +705,10 @@ class SelfAttention(nn.Module):
         # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
         if key_padding_mask is not None and key_padding_mask.dim() == 0:
             key_padding_mask = None
-        assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len,)
+        assert key_padding_mask is None or key_padding_mask.size()[:2] == (
+            bsz,
+            src_len,
+        )
         if key_padding_mask is not None:  # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
@@ -703,7 +716,11 @@ class SelfAttention(nn.Module):
             attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
         attn_weights = F.softmax(attn_weights, dim=-1)
-        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
+        attn_probs = F.dropout(
+            attn_weights,
+            p=self.dropout,
+            training=self.training,
+        )
         assert v is not None
         attn_output = torch.bmm(attn_probs, v)
@@ -754,7 +771,11 @@ class BartClassificationHead(nn.Module):
     # This can trivially be shared with RobertaClassificationHead
     def __init__(
-        self, input_dim, inner_dim, num_classes, pooler_dropout,
+        self,
+        input_dim,
+        inner_dim,
+        num_classes,
+        pooler_dropout,
     ):
         super().__init__()
         self.dense = nn.Linear(input_dim, inner_dim)
@@ -819,7 +840,8 @@ def _get_shape(t):
 @add_start_docstrings(
-    "The bare BART Model outputting raw hidden-states without any specific head on top.", BART_START_DOCSTRING,
+    "The bare BART Model outputting raw hidden-states without any specific head on top.",
+    BART_START_DOCSTRING,
 )
 class BartModel(PretrainedBartModel):
     def __init__(self, config: BartConfig):
@@ -981,31 +1003,31 @@ class BartForConditionalGeneration(PretrainedBartModel):
         **unused,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
             Labels for computing the masked language modeling loss.
             Indices should either be in ``[0, ..., config.vocab_size]`` or -100 (see ``input_ids`` docstring).
             Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens
             with labels in ``[0, ..., config.vocab_size]``.
         Returns:
         Conditional generation example::
             # Mask filling only works for bart-large
             from transformers import BartTokenizer, BartForConditionalGeneration
             tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
             TXT = "My friends are <mask> but they eat too many carbs."
             model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
             input_ids = tokenizer([TXT], return_tensors='pt')['input_ids']
             logits = model(input_ids).logits
             masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
             probs = logits[0, masked_index].softmax(dim=0)
             values, predictions = probs.topk(5)
             tokenizer.decode(predictions).split()
             # ['good', 'great', 'all', 'really', 'very']
         """
         if "lm_labels" in unused:
             warnings.warn(
@@ -1116,7 +1138,10 @@ class BartForSequenceClassification(PretrainedBartModel):
         super().__init__(config, **kwargs)
         self.model = BartModel(config)
         self.classification_head = BartClassificationHead(
-            config.d_model, config.d_model, config.num_labels, config.classif_dropout,
+            config.d_model,
+            config.d_model,
+            config.num_labels,
+            config.classif_dropout,
         )
         self.model._init_weights(self.classification_head.dense)
         self.model._init_weights(self.classification_head.out_proj)
@@ -1279,7 +1304,10 @@ class BartForQuestionAnswering(PretrainedBartModel):
             total_loss = (start_loss + end_loss) / 2
         if not return_dict:
-            output = (start_logits, end_logits,) + outputs[1:]
+            output = (
+                start_logits,
+                end_logits,
+            ) + outputs[1:]
             return ((total_loss,) + output) if total_loss is not None else output
         return Seq2SeqQuestionAnsweringModelOutput(
@@ -1307,7 +1335,7 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
     @staticmethod
     def _init_weight(out: nn.Parameter):
         """Identical to the XLM create_sinusoidal_embeddings except features are not interleaved.
         The cos features are in the 2nd half of the vector. [dim // 2:]
         """
         n_pos, dim = out.shape
         position_enc = np.array(
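The non-interleaved layout described in `_init_weight` (sin features in the first half of each vector, cos features in the second half) can be reproduced in a few lines of NumPy. A rough sketch, not the library code:

    import numpy as np

    def sinusoidal_table(n_pos, dim):
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        out = np.zeros((n_pos, dim))
        sentinel = dim // 2 + dim % 2
        out[:, :sentinel] = np.sin(position_enc[:, 0::2])  # sin features in the first half
        out[:, sentinel:] = np.cos(position_enc[:, 1::2])  # cos features in the second half
        return out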
...
@@ -89,8 +89,7 @@ BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
 def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
-    """ Load tf checkpoints in a pytorch model.
-    """
+    """Load tf checkpoints in a pytorch model."""
     try:
         import re
@@ -174,8 +173,7 @@ BertLayerNorm = torch.nn.LayerNorm
 class BertEmbeddings(nn.Module):
-    """Construct the embeddings from word, position and token_type embeddings.
-    """
+    """Construct the embeddings from word, position and token_type embeddings."""
     def __init__(self, config):
         super().__init__()
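As a reminder of what that one-line docstring summarizes: BERT's input embedding is the element-wise sum of word, position, and token-type embeddings, followed by LayerNorm and dropout. A compact sketch with illustrative sizes (not the module being reformatted here):

    import torch
    import torch.nn as nn

    class TinyBertEmbeddings(nn.Module):
        def __init__(self, vocab_size=30522, hidden=768, max_pos=512, type_vocab=2, dropout=0.1):
            super().__init__()
            self.word = nn.Embedding(vocab_size, hidden)
            self.position = nn.Embedding(max_pos, hidden)
            self.token_type = nn.Embedding(type_vocab, hidden)
            self.norm = nn.LayerNorm(hidden)
            self.dropout = nn.Dropout(dropout)

        def forward(self, input_ids, token_type_ids):
            positions = torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0)
            x = self.word(input_ids) + self.position(positions) + self.token_type(token_type_ids)
            return self.dropout(self.norm(x))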
...@@ -343,7 +341,12 @@ class BertAttention(nn.Module): ...@@ -343,7 +341,12 @@ class BertAttention(nn.Module):
output_attentions=False, output_attentions=False,
): ):
self_outputs = self.self( self_outputs = self.self(
hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions, hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
) )
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
...@@ -403,7 +406,10 @@ class BertLayer(nn.Module): ...@@ -403,7 +406,10 @@ class BertLayer(nn.Module):
output_attentions=False, output_attentions=False,
): ):
self_attention_outputs = self.attention( self_attention_outputs = self.attention(
hidden_states, attention_mask, head_mask, output_attentions=output_attentions, hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
) )
attention_output = self_attention_outputs[0] attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
...@@ -582,8 +588,8 @@ class BertPreTrainingHeads(nn.Module): ...@@ -582,8 +588,8 @@ class BertPreTrainingHeads(nn.Module):
class BertPreTrainedModel(PreTrainedModel): class BertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
config_class = BertConfig config_class = BertConfig
...@@ -733,9 +739,9 @@ class BertModel(BertPreTrainedModel): ...@@ -733,9 +739,9 @@ class BertModel(BertPreTrainedModel):
self.embeddings.word_embeddings = value self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model. """Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel See base class PreTrainedModel
""" """
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads) self.encoder.layer[layer].attention.prune_heads(heads)
...@@ -877,34 +883,34 @@ class BertForPreTraining(BertPreTrainedModel): ...@@ -877,34 +883,34 @@ class BertForPreTraining(BertPreTrainedModel):
**kwargs **kwargs
): ):
r""" r"""
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`): labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss. Labels for computing the masked language modeling loss.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]`` in ``[0, ..., config.vocab_size]``
next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`): next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring) Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring)
Indices should be in ``[0, 1]``. Indices should be in ``[0, 1]``.
``0`` indicates sequence B is a continuation of sequence A, ``0`` indicates sequence B is a continuation of sequence A,
``1`` indicates sequence B is a random sequence. ``1`` indicates sequence B is a random sequence.
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
Used to hide legacy arguments that have been deprecated. Used to hide legacy arguments that have been deprecated.
Returns: Returns:
Examples:: Examples::
>>> from transformers import BertTokenizer, BertForPreTraining >>> from transformers import BertTokenizer, BertForPreTraining
>>> import torch >>> import torch
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
>>> model = BertForPreTraining.from_pretrained('bert-base-uncased', return_dict=True) >>> model = BertForPreTraining.from_pretrained('bert-base-uncased', return_dict=True)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs) >>> outputs = model(**inputs)
>>> prediction_logits = outputs.prediction_logits >>> prediction_logits = outputs.prediction_logits
>>> seq_relationship_logits = outputs.seq_relationship_logits >>> seq_relationship_logits = outputs.seq_relationship_logits
""" """
if "masked_lm_labels" in kwargs: if "masked_lm_labels" in kwargs:
warnings.warn( warnings.warn(
...@@ -986,36 +992,36 @@ class BertLMHeadModel(BertPreTrainedModel): ...@@ -986,36 +992,36 @@ class BertLMHeadModel(BertPreTrainedModel):
return_dict=None, return_dict=None,
): ):
r""" r"""
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
if the model is configured as a decoder. if the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask Mask to avoid performing attention on the padding token indices of the encoder input. This mask
is used in the cross-attention if the model is configured as a decoder. is used in the cross-attention if the model is configured as a decoder.
Mask values selected in ``[0, 1]``: Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the left-to-right language modeling loss (next word prediction). Labels for computing the left-to-right language modeling loss (next word prediction).
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]`` in ``[0, ..., config.vocab_size]``
Returns: Returns:
Example:: Example::
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
>>> import torch >>> import torch
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
>>> config = BertConfig.from_pretrained("bert-base-cased") >>> config = BertConfig.from_pretrained("bert-base-cased")
>>> config.is_decoder = True >>> config.is_decoder = True
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config, return_dict=True) >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config, return_dict=True)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs) >>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits >>> prediction_logits = outputs.logits
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...@@ -1049,7 +1055,10 @@ class BertLMHeadModel(BertPreTrainedModel): ...@@ -1049,7 +1055,10 @@ class BertLMHeadModel(BertPreTrainedModel):
return ((lm_loss,) + output) if lm_loss is not None else output return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutput( return CausalLMOutput(
loss=lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
...@@ -1173,7 +1182,8 @@ class BertForMaskedLM(BertPreTrainedModel): ...@@ -1173,7 +1182,8 @@ class BertForMaskedLM(BertPreTrainedModel):
@add_start_docstrings( @add_start_docstrings(
"""Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING, """Bert Model with a `next sentence prediction (classification)` head on top. """,
BERT_START_DOCSTRING,
) )
class BertForNextSentencePrediction(BertPreTrainedModel): class BertForNextSentencePrediction(BertPreTrainedModel):
def __init__(self, config): def __init__(self, config):
...@@ -1200,29 +1210,29 @@ class BertForNextSentencePrediction(BertPreTrainedModel): ...@@ -1200,29 +1210,29 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
return_dict=None, return_dict=None,
): ):
r""" r"""
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring) Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
Indices should be in ``[0, 1]``. Indices should be in ``[0, 1]``.
``0`` indicates sequence B is a continuation of sequence A, ``0`` indicates sequence B is a continuation of sequence A,
``1`` indicates sequence B is a random sequence. ``1`` indicates sequence B is a random sequence.
Returns: Returns:
Example:: Example::
>>> from transformers import BertTokenizer, BertForNextSentencePrediction >>> from transformers import BertTokenizer, BertForNextSentencePrediction
>>> import torch >>> import torch
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
>>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased', return_dict=True) >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased', return_dict=True)
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt') >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
>>> outputs = model(**encoding, next_sentence_label=torch.LongTensor([1])) >>> outputs = model(**encoding, next_sentence_label=torch.LongTensor([1]))
>>> logits = outputs.logits >>> logits = outputs.logits
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...@@ -1336,7 +1346,10 @@ class BertForSequenceClassification(BertPreTrainedModel): ...@@ -1336,7 +1346,10 @@ class BertForSequenceClassification(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput( return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
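The pair of returns above is the pattern this diff reformats throughout: with `return_dict=False` the loss is prepended to a plain output tuple only when labels are given, while with `return_dict=True` a `SequenceClassifierOutput` with named fields is returned. A hedged sketch of how callers consume both forms (illustrative, not part of the commit):

import torch
from transformers import BertTokenizer, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
labels = torch.tensor([1])

outputs = model(**inputs, labels=labels, return_dict=True)
loss, logits = outputs.loss, outputs.logits   # named access on the ModelOutput

outputs = model(**inputs, labels=labels, return_dict=False)
loss, logits = outputs[0], outputs[1]         # positional access on the tuple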
...@@ -1422,7 +1435,10 @@ class BertForMultipleChoice(BertPreTrainedModel): ...@@ -1422,7 +1435,10 @@ class BertForMultipleChoice(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput( return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
...@@ -1505,7 +1521,10 @@ class BertForTokenClassification(BertPreTrainedModel): ...@@ -1505,7 +1521,10 @@ class BertForTokenClassification(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput( return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
......
...@@ -68,7 +68,8 @@ class CamembertModel(RobertaModel): ...@@ -68,7 +68,8 @@ class CamembertModel(RobertaModel):
@add_start_docstrings( @add_start_docstrings(
"""CamemBERT Model with a `language modeling` head on top. """, CAMEMBERT_START_DOCSTRING, """CamemBERT Model with a `language modeling` head on top. """,
CAMEMBERT_START_DOCSTRING,
) )
class CamembertForMaskedLM(RobertaForMaskedLM): class CamembertForMaskedLM(RobertaForMaskedLM):
""" """
......
...@@ -212,16 +212,15 @@ class EncoderLayer(torch.nn.Module): ...@@ -212,16 +212,15 @@ class EncoderLayer(torch.nn.Module):
class CTRLPreTrainedModel(PreTrainedModel): class CTRLPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
config_class = CTRLConfig config_class = CTRLConfig
base_model_prefix = "transformer" base_model_prefix = "transformer"
def _init_weights(self, module): def _init_weights(self, module):
""" Initialize the weights. """Initialize the weights."""
"""
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)): if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
# Slightly different from the TF version which uses truncated_normal for initialization # Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617 # cf https://github.com/pytorch/pytorch/pull/5617
...@@ -331,8 +330,8 @@ class CTRLModel(CTRLPreTrainedModel): ...@@ -331,8 +330,8 @@ class CTRLModel(CTRLPreTrainedModel):
self.w = new_embeddings self.w = new_embeddings
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model. """Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
""" """
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.h[layer].multi_head_attention.prune_heads(heads) self.h[layer].multi_head_attention.prune_heads(heads)
......
...@@ -261,7 +261,12 @@ class TransformerBlock(nn.Module): ...@@ -261,7 +261,12 @@ class TransformerBlock(nn.Module):
""" """
# Self-Attention # Self-Attention
sa_output = self.attention( sa_output = self.attention(
query=x, key=x, value=x, mask=attn_mask, head_mask=head_mask, output_attentions=output_attentions, query=x,
key=x,
value=x,
mask=attn_mask,
head_mask=head_mask,
output_attentions=output_attentions,
) )
if output_attentions: if output_attentions:
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length) sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
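Illustrative sketch (not part of the commit) of inspecting those attention weights from outside the block, via `DistilBertModel` with `output_attentions=True`:

from transformers import DistilBertModel, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased", return_dict=True)
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs, output_attentions=True)
# one (bs, n_heads, seq_length, seq_length) tensor per transformer block
print(len(outputs.attentions), outputs.attentions[0].shape)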
...@@ -343,8 +348,8 @@ class Transformer(nn.Module): ...@@ -343,8 +348,8 @@ class Transformer(nn.Module):
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL # # INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class DistilBertPreTrainedModel(PreTrainedModel): class DistilBertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
config_class = DistilBertConfig config_class = DistilBertConfig
...@@ -352,8 +357,7 @@ class DistilBertPreTrainedModel(PreTrainedModel): ...@@ -352,8 +357,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
base_model_prefix = "distilbert" base_model_prefix = "distilbert"
def _init_weights(self, module): def _init_weights(self, module):
""" Initialize the weights. """Initialize the weights."""
"""
if isinstance(module, nn.Embedding): if isinstance(module, nn.Embedding):
if module.weight.requires_grad: if module.weight.requires_grad:
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
...@@ -432,9 +436,9 @@ class DistilBertModel(DistilBertPreTrainedModel): ...@@ -432,9 +436,9 @@ class DistilBertModel(DistilBertPreTrainedModel):
self.embeddings.word_embeddings = new_embeddings self.embeddings.word_embeddings = new_embeddings
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model. """Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel See base class PreTrainedModel
""" """
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.transformer.layer[layer].attention.prune_heads(heads) self.transformer.layer[layer].attention.prune_heads(heads)
...@@ -493,7 +497,8 @@ class DistilBertModel(DistilBertPreTrainedModel): ...@@ -493,7 +497,8 @@ class DistilBertModel(DistilBertPreTrainedModel):
@add_start_docstrings( @add_start_docstrings(
"""DistilBert Model with a `masked language modeling` head on top. """, DISTILBERT_START_DOCSTRING, """DistilBert Model with a `masked language modeling` head on top. """,
DISTILBERT_START_DOCSTRING,
) )
class DistilBertForMaskedLM(DistilBertPreTrainedModel): class DistilBertForMaskedLM(DistilBertPreTrainedModel):
def __init__(self, config): def __init__(self, config):
...@@ -829,7 +834,10 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel): ...@@ -829,7 +834,10 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput( return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
...@@ -863,32 +871,32 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel): ...@@ -863,32 +871,32 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
return_dict=None, return_dict=None,
): ):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss. Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above) of the input tensors. (see `input_ids` above)
Returns: Returns:
Examples:: Examples::
>>> from transformers import DistilBertTokenizer, DistilBertForMultipleChoice >>> from transformers import DistilBertTokenizer, DistilBertForMultipleChoice
>>> import torch >>> import torch
>>> tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased') >>> tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
>>> model = DistilBertForMultipleChoice.from_pretrained('distilbert-base-cased', return_dict=True) >>> model = DistilBertForMultipleChoice.from_pretrained('distilbert-base-cased', return_dict=True)
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> choice0 = "It is eaten with a fork and a knife." >>> choice0 = "It is eaten with a fork and a knife."
>>> choice1 = "It is eaten while held in the hand." >>> choice1 = "It is eaten while held in the hand."
>>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1 >>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
>>> encoding = tokenizer([[prompt, choice0], [prompt, choice1]], return_tensors='pt', padding=True) >>> encoding = tokenizer([[prompt, choice0], [prompt, choice1]], return_tensors='pt', padding=True)
>>> outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels) # batch size is 1 >>> outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels) # batch size is 1
>>> # the linear classifier still needs to be trained >>> # the linear classifier still needs to be trained
>>> loss = outputs.loss >>> loss = outputs.loss
>>> logits = outputs.logits >>> logits = outputs.logits
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
...@@ -930,5 +938,8 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel): ...@@ -930,5 +938,8 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput( return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
...@@ -265,8 +265,8 @@ class DPRSpanPredictor(PreTrainedModel): ...@@ -265,8 +265,8 @@ class DPRSpanPredictor(PreTrainedModel):
class DPRPretrainedContextEncoder(PreTrainedModel): class DPRPretrainedContextEncoder(PreTrainedModel):
""" An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
config_class = DPRConfig config_class = DPRConfig
...@@ -278,8 +278,8 @@ class DPRPretrainedContextEncoder(PreTrainedModel): ...@@ -278,8 +278,8 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
class DPRPretrainedQuestionEncoder(PreTrainedModel): class DPRPretrainedQuestionEncoder(PreTrainedModel):
""" An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
config_class = DPRConfig config_class = DPRConfig
...@@ -291,8 +291,8 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel): ...@@ -291,8 +291,8 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
class DPRPretrainedReader(PreTrainedModel): class DPRPretrainedReader(PreTrainedModel):
""" An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
config_class = DPRConfig config_class = DPRConfig
...@@ -421,15 +421,15 @@ class DPRContextEncoder(DPRPretrainedContextEncoder): ...@@ -421,15 +421,15 @@ class DPRContextEncoder(DPRPretrainedContextEncoder):
return_dict=None, return_dict=None,
) -> Union[DPRContextEncoderOutput, Tuple[Tensor, ...]]: ) -> Union[DPRContextEncoderOutput, Tuple[Tensor, ...]]:
r""" r"""
Return: Return:
Examples:: Examples::
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base') tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
model = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base', return_dict=True) model = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base', return_dict=True)
input_ids = tokenizer("Hello, is my dog cute ?", return_tensors='pt')["input_ids"] input_ids = tokenizer("Hello, is my dog cute ?", return_tensors='pt')["input_ids"]
embeddings = model(input_ids).pooler_output embeddings = model(input_ids).pooler_output
""" """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
...@@ -499,15 +499,15 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder): ...@@ -499,15 +499,15 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
return_dict=None, return_dict=None,
) -> Union[DPRQuestionEncoderOutput, Tuple[Tensor, ...]]: ) -> Union[DPRQuestionEncoderOutput, Tuple[Tensor, ...]]:
r""" r"""
Return: Return:
Examples:: Examples::
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base') tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
model = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base', return_dict=True) model = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base', return_dict=True)
input_ids = tokenizer("Hello, is my dog cute ?", return_tensors='pt')["input_ids"] input_ids = tokenizer("Hello, is my dog cute ?", return_tensors='pt')["input_ids"]
embeddings = model(input_ids).pooler_output embeddings = model(input_ids).pooler_output
""" """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
...@@ -553,7 +553,8 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder): ...@@ -553,7 +553,8 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
@add_start_docstrings( @add_start_docstrings(
"The bare DPRReader transformer outputting span predictions.", DPR_START_DOCSTRING, "The bare DPRReader transformer outputting span predictions.",
DPR_START_DOCSTRING,
) )
class DPRReader(DPRPretrainedReader): class DPRReader(DPRPretrainedReader):
def __init__(self, config: DPRConfig): def __init__(self, config: DPRConfig):
...@@ -574,23 +575,23 @@ class DPRReader(DPRPretrainedReader): ...@@ -574,23 +575,23 @@ class DPRReader(DPRPretrainedReader):
return_dict=None, return_dict=None,
) -> Union[DPRReaderOutput, Tuple[Tensor, ...]]: ) -> Union[DPRReaderOutput, Tuple[Tensor, ...]]:
r""" r"""
Return: Return:
Examples:: Examples::
from transformers import DPRReader, DPRReaderTokenizer from transformers import DPRReader, DPRReaderTokenizer
tokenizer = DPRReaderTokenizer.from_pretrained('facebook/dpr-reader-single-nq-base') tokenizer = DPRReaderTokenizer.from_pretrained('facebook/dpr-reader-single-nq-base')
model = DPRReader.from_pretrained('facebook/dpr-reader-single-nq-base', return_dict=True) model = DPRReader.from_pretrained('facebook/dpr-reader-single-nq-base', return_dict=True)
encoded_inputs = tokenizer( encoded_inputs = tokenizer(
questions=["What is love ?"], questions=["What is love ?"],
titles=["Haddaway"], titles=["Haddaway"],
texts=["'What Is Love' is a song recorded by the artist Haddaway"], texts=["'What Is Love' is a song recorded by the artist Haddaway"],
return_tensors='pt' return_tensors='pt'
) )
outputs = model(**encoded_inputs) outputs = model(**encoded_inputs)
start_logits = outputs.start_logits start_logits = outputs.start_logits
end_logits = outputs.end_logits end_logits = outputs.end_logits
relevance_logits = outputs.relevance_logits relevance_logits = outputs.relevance_logits
""" """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
......
...@@ -46,8 +46,7 @@ ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -46,8 +46,7 @@ ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_or_generator="discriminator"): def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_or_generator="discriminator"):
""" Load tf checkpoints in a pytorch model. """Load tf checkpoints in a pytorch model."""
"""
try: try:
import re import re
...@@ -179,8 +178,8 @@ class ElectraGeneratorPredictions(nn.Module): ...@@ -179,8 +178,8 @@ class ElectraGeneratorPredictions(nn.Module):
class ElectraPreTrainedModel(BertPreTrainedModel): class ElectraPreTrainedModel(BertPreTrainedModel):
""" An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
config_class = ElectraConfig config_class = ElectraConfig
...@@ -311,9 +310,9 @@ class ElectraModel(ElectraPreTrainedModel): ...@@ -311,9 +310,9 @@ class ElectraModel(ElectraPreTrainedModel):
self.embeddings.word_embeddings = value self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model. """Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel See base class PreTrainedModel
""" """
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads) self.encoder.layer[layer].attention.prune_heads(heads)
...@@ -512,24 +511,24 @@ class ElectraForPreTraining(ElectraPreTrainedModel): ...@@ -512,24 +511,24 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
return_dict=None, return_dict=None,
): ):
r""" r"""
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`): labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
Labels for computing the ELECTRA loss. Input should be a sequence of tokens (see :obj:`input_ids` docstring) Labels for computing the ELECTRA loss. Input should be a sequence of tokens (see :obj:`input_ids` docstring)
Indices should be in ``[0, 1]``. Indices should be in ``[0, 1]``.
``0`` indicates the token is an original token, ``0`` indicates the token is an original token,
``1`` indicates the token was replaced. ``1`` indicates the token was replaced.
Returns: Returns:
Examples:: Examples::
>>> from transformers import ElectraTokenizer, ElectraForPreTraining >>> from transformers import ElectraTokenizer, ElectraForPreTraining
>>> import torch >>> import torch
>>> tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator') >>> tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
>>> model = ElectraForPreTraining.from_pretrained('google/electra-small-discriminator', return_dict=True) >>> model = ElectraForPreTraining.from_pretrained('google/electra-small-discriminator', return_dict=True)
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
>>> logits = model(input_ids).logits >>> logits = model(input_ids).logits
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...@@ -836,7 +835,10 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel): ...@@ -836,7 +835,10 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
total_loss = (start_loss + end_loss) / 2 total_loss = (start_loss + end_loss) / 2
if not return_dict: if not return_dict:
output = (start_logits, end_logits,) + discriminator_hidden_states[1:] output = (
start_logits,
end_logits,
) + discriminator_hidden_states[1:]
return ((total_loss,) + output) if total_loss is not None else output return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput( return QuestionAnsweringModelOutput(
......
...@@ -28,11 +28,11 @@ logger = logging.get_logger(__name__) ...@@ -28,11 +28,11 @@ logger = logging.get_logger(__name__)
class EncoderDecoderModel(PreTrainedModel): class EncoderDecoderModel(PreTrainedModel):
r""" r"""
:class:`~transformers.EncoderDecoder` is a generic model class that will be :class:`~transformers.EncoderDecoder` is a generic model class that will be
instantiated as a transformer architecture with one of the base model instantiated as a transformer architecture with one of the base model
classes of the library as encoder and another one as classes of the library as encoder and another one as
decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)` decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
class method for the encoder and `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` class method for the decoder. class method for the encoder and `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` class method for the decoder.
""" """
config_class = EncoderDecoderConfig config_class = EncoderDecoderConfig
base_model_prefix = "encoder_decoder" base_model_prefix = "encoder_decoder"
...@@ -103,7 +103,7 @@ class EncoderDecoderModel(PreTrainedModel): ...@@ -103,7 +103,7 @@ class EncoderDecoderModel(PreTrainedModel):
*model_args, *model_args,
**kwargs **kwargs
) -> PreTrainedModel: ) -> PreTrainedModel:
r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints. r"""Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
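Illustrative usage sketch for this class method (not part of the commit), pairing a BERT encoder with a BERT decoder:

from transformers import BertTokenizer, EncoderDecoderModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=inputs["input_ids"])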
......
...@@ -240,7 +240,11 @@ class FlaubertModel(XLMModel): ...@@ -240,7 +240,11 @@ class FlaubertModel(XLMModel):
# self attention # self attention
if not self.pre_norm: if not self.pre_norm:
attn_outputs = self.attentions[i]( attn_outputs = self.attentions[i](
tensor, attn_mask, cache=cache, head_mask=head_mask[i], output_attentions=output_attentions, tensor,
attn_mask,
cache=cache,
head_mask=head_mask[i],
output_attentions=output_attentions,
) )
attn = attn_outputs[0] attn = attn_outputs[0]
if output_attentions: if output_attentions:
......
...@@ -61,8 +61,7 @@ GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -61,8 +61,7 @@ GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
""" Load tf checkpoints in a pytorch model """Load tf checkpoints in a pytorch model"""
"""
try: try:
import re import re
...@@ -324,8 +323,8 @@ class Block(nn.Module): ...@@ -324,8 +323,8 @@ class Block(nn.Module):
class GPT2PreTrainedModel(PreTrainedModel): class GPT2PreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
config_class = GPT2Config config_class = GPT2Config
...@@ -336,8 +335,7 @@ class GPT2PreTrainedModel(PreTrainedModel): ...@@ -336,8 +335,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
super().__init__(*inputs, **kwargs) super().__init__(*inputs, **kwargs)
def _init_weights(self, module): def _init_weights(self, module):
""" Initialize the weights. """Initialize the weights."""
"""
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)): if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
# Slightly different from the TF version which uses truncated_normal for initialization # Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617 # cf https://github.com/pytorch/pytorch/pull/5617
...@@ -483,8 +481,8 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -483,8 +481,8 @@ class GPT2Model(GPT2PreTrainedModel):
self.wte = new_embeddings self.wte = new_embeddings
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model. """Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
""" """
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.h[layer].attn.prune_heads(heads) self.h[layer].attn.prune_heads(heads)
...@@ -800,47 +798,47 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): ...@@ -800,47 +798,47 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
**kwargs, **kwargs,
): ):
r""" r"""
mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, defaults to index of the last token of the input) mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, defaults to index of the last token of the input)
Index of the classification token in each input sequence. Index of the classification token in each input sequence.
Selected in the range ``[0, input_ids.size(-1) - 1[``. Selected in the range ``[0, input_ids.size(-1) - 1[``.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`) labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`)
Labels for language modeling. Labels for language modeling.
Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids`` Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
Indices are selected in ``[-1, 0, ..., config.vocab_size]`` Indices are selected in ``[-1, 0, ..., config.vocab_size]``
All labels set to ``-100`` are ignored (masked), the loss is only All labels set to ``-100`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]`` computed for labels in ``[0, ..., config.vocab_size]``
mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`, defaults to :obj:`None`) mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`, defaults to :obj:`None`)
Labels for computing the multiple choice classification loss. Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above) of the input tensors. (see `input_ids` above)
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
Used to hide legacy arguments that have been deprecated. Used to hide legacy arguments that have been deprecated.
Return: Return:
Examples:: Examples::
>>> import torch >>> import torch
>>> from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel >>> from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel
>>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2') >>> tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
>>> model = GPT2DoubleHeadsModel.from_pretrained('gpt2', return_dict=True) >>> model = GPT2DoubleHeadsModel.from_pretrained('gpt2', return_dict=True)
>>> # Add a [CLS] to the vocabulary (we should train it also!) >>> # Add a [CLS] to the vocabulary (we should train it also!)
>>> num_added_tokens = tokenizer.add_special_tokens({'cls_token': '[CLS]'}) >>> num_added_tokens = tokenizer.add_special_tokens({'cls_token': '[CLS]'})
>>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) # Update the model embeddings with the new vocabulary size >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) # Update the model embeddings with the new vocabulary size
>>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
>>> encoded_choices = [tokenizer.encode(s) for s in choices] >>> encoded_choices = [tokenizer.encode(s) for s in choices]
>>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
>>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
>>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
>>> outputs = model(input_ids, mc_token_ids=mc_token_ids) >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
>>> lm_logits = outputs.lm_logits >>> lm_logits = outputs.lm_logits
>>> mc_logits = outputs.mc_logits >>> mc_logits = outputs.mc_logits
""" """
if "lm_labels" in kwargs: if "lm_labels" in kwargs:
......
...@@ -66,7 +66,7 @@ LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -66,7 +66,7 @@ LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
def _get_question_end_index(input_ids, sep_token_id): def _get_question_end_index(input_ids, sep_token_id):
""" """
Computes the index of the first occurrence of `sep_token_id`. Computes the index of the first occurrence of `sep_token_id`.
""" """
sep_token_indices = (input_ids == sep_token_id).nonzero() sep_token_indices = (input_ids == sep_token_id).nonzero()
...@@ -81,9 +81,9 @@ def _get_question_end_index(input_ids, sep_token_id): ...@@ -81,9 +81,9 @@ def _get_question_end_index(input_ids, sep_token_id):
def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=True): def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=True):
""" """
Computes global attention mask by putting attention on all tokens Computes global attention mask by putting attention on all tokens
before `sep_token_id` if `before_sep_token is True` else after before `sep_token_id` if `before_sep_token is True` else after
`sep_token_id`. `sep_token_id`.
""" """
question_end_index = _get_question_end_index(input_ids, sep_token_id) question_end_index = _get_question_end_index(input_ids, sep_token_id)
question_end_index = question_end_index.unsqueeze(dim=1) # size: batch_size x 1 question_end_index = question_end_index.unsqueeze(dim=1) # size: batch_size x 1
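Toy re-implementation of the idea (a sketch only, assuming every row contains at least one sep token; this is not the library's exact code): mark every position strictly before the first `sep_token_id`, i.e. the positions that should receive global attention when `before_sep_token` is True:

import torch

def global_attention_before_sep(input_ids: torch.Tensor, sep_token_id: int) -> torch.Tensor:
    # count of positions before the first sep token, per row
    before_first_sep = (input_ids == sep_token_id).int().cumsum(dim=1).eq(0).sum(dim=1, keepdim=True)
    positions = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0)
    return (positions < before_first_sep).int()  # 1 = give this token global attention

print(global_attention_before_sep(torch.tensor([[101, 7592, 102, 2088, 102]]), sep_token_id=102))
# tensor([[1, 1, 0, 0, 0]])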
...@@ -135,7 +135,10 @@ class LongformerSelfAttention(nn.Module): ...@@ -135,7 +135,10 @@ class LongformerSelfAttention(nn.Module):
self.one_sided_attn_window_size = attention_window // 2 self.one_sided_attn_window_size = attention_window // 2
def forward( def forward(
self, hidden_states, attention_mask=None, output_attentions=False, self,
hidden_states,
attention_mask=None,
output_attentions=False,
): ):
""" """
LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`. LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`.
...@@ -314,17 +317,17 @@ class LongformerSelfAttention(nn.Module): ...@@ -314,17 +317,17 @@ class LongformerSelfAttention(nn.Module):
@staticmethod @staticmethod
def _pad_and_diagonalize(chunked_hidden_states): def _pad_and_diagonalize(chunked_hidden_states):
"""shift every row 1 step right, converting columns into diagonals. """shift every row 1 step right, converting columns into diagonals.
Example: Example:
chunked_hidden_states: [ 0.4983, 2.6918, -0.0071, 1.0492, chunked_hidden_states: [ 0.4983, 2.6918, -0.0071, 1.0492,
-1.8348, 0.7672, 0.2986, 0.0285, -1.8348, 0.7672, 0.2986, 0.0285,
-0.7584, 0.4206, -0.0405, 0.1599, -0.7584, 0.4206, -0.0405, 0.1599,
2.0514, -1.1600, 0.5372, 0.2629 ] 2.0514, -1.1600, 0.5372, 0.2629 ]
window_overlap = num_rows = 4 window_overlap = num_rows = 4
(pad & diagonalize) => (pad & diagonalize) =>
[ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000 [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000
0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000 0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000
0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000 0.0000, 0.0000, -0.7584, 0.4206, -0.0405, 0.1599, 0.0000
0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ] 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
""" """
total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size() total_num_heads, num_chunks, window_overlap, hidden_dim = chunked_hidden_states.size()
chunked_hidden_states = F.pad( chunked_hidden_states = F.pad(
...@@ -442,7 +445,7 @@ class LongformerSelfAttention(nn.Module): ...@@ -442,7 +445,7 @@ class LongformerSelfAttention(nn.Module):
self, attn_probs: torch.Tensor, value: torch.Tensor, window_overlap: int self, attn_probs: torch.Tensor, value: torch.Tensor, window_overlap: int
): ):
"""Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. """Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors.
Returned tensor will be of the same shape as `attn_probs`""" Returned tensor will be of the same shape as `attn_probs`"""
batch_size, seq_len, num_heads, head_dim = value.size() batch_size, seq_len, num_heads, head_dim = value.size()
assert seq_len % (window_overlap * 2) == 0 assert seq_len % (window_overlap * 2) == 0
...@@ -622,7 +625,10 @@ class LongformerSelfAttention(nn.Module): ...@@ -622,7 +625,10 @@ class LongformerSelfAttention(nn.Module):
is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], : is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], :
] = -10000.0 ] = -10000.0
global_attn_scores = global_attn_scores.masked_fill(is_index_masked[:, None, None, :], -10000.0,) global_attn_scores = global_attn_scores.masked_fill(
is_index_masked[:, None, None, :],
-10000.0,
)
global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len) global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
...@@ -676,9 +682,16 @@ class LongformerAttention(nn.Module): ...@@ -676,9 +682,16 @@ class LongformerAttention(nn.Module):
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
def forward( def forward(
self, hidden_states, attention_mask=None, output_attentions=False, self,
hidden_states,
attention_mask=None,
output_attentions=False,
): ):
self_outputs = self.self(hidden_states, attention_mask, output_attentions,) self_outputs = self.self(
hidden_states,
attention_mask,
output_attentions,
)
attn_output = self.output(self_outputs[0], hidden_states) attn_output = self.output(self_outputs[0], hidden_states)
outputs = (attn_output,) + self_outputs[1:] # add attentions if we output them outputs = (attn_output,) + self_outputs[1:] # add attentions if we output them
return outputs return outputs
...@@ -694,9 +707,16 @@ class LongformerLayer(nn.Module): ...@@ -694,9 +707,16 @@ class LongformerLayer(nn.Module):
self.seq_len_dim = 1 self.seq_len_dim = 1
def forward( def forward(
self, hidden_states, attention_mask=None, output_attentions=False, self,
hidden_states,
attention_mask=None,
output_attentions=False,
): ):
self_attn_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions,) self_attn_outputs = self.attention(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
attn_output = self_attn_outputs[0] attn_output = self_attn_outputs[0]
outputs = self_attn_outputs[1:] # add self attentions if we output attention weights outputs = self_attn_outputs[1:] # add self attentions if we output attention weights
...@@ -741,10 +761,16 @@ class LongformerEncoder(nn.Module): ...@@ -741,10 +761,16 @@ class LongformerEncoder(nn.Module):
return custom_forward return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint( layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module), hidden_states, attention_mask, create_custom_forward(layer_module),
hidden_states,
attention_mask,
) )
else: else:
layer_outputs = layer_module(hidden_states, attention_mask, output_attentions,) layer_outputs = layer_module(
hidden_states,
attention_mask,
output_attentions,
)
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if output_attentions: if output_attentions:
...@@ -762,9 +788,9 @@ class LongformerEncoder(nn.Module): ...@@ -762,9 +788,9 @@ class LongformerEncoder(nn.Module):
class LongformerPreTrainedModel(PreTrainedModel): class LongformerPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained a simple interface for downloading and loading pretrained
models. models.
""" """
config_class = LongformerConfig config_class = LongformerConfig
...@@ -896,9 +922,9 @@ class LongformerModel(LongformerPreTrainedModel): ...@@ -896,9 +922,9 @@ class LongformerModel(LongformerPreTrainedModel):
self.embeddings.word_embeddings = value self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model. """Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel See base class PreTrainedModel
""" """
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads) self.encoder.layer[layer].attention.prune_heads(heads)
...@@ -938,7 +964,9 @@ class LongformerModel(LongformerPreTrainedModel): ...@@ -938,7 +964,9 @@ class LongformerModel(LongformerPreTrainedModel):
position_ids = F.pad(position_ids, (0, padding_len), value=pad_token_id) position_ids = F.pad(position_ids, (0, padding_len), value=pad_token_id)
if inputs_embeds is not None: if inputs_embeds is not None:
input_ids_padding = inputs_embeds.new_full( input_ids_padding = inputs_embeds.new_full(
(batch_size, padding_len), self.config.pad_token_id, dtype=torch.long, (batch_size, padding_len),
self.config.pad_token_id,
dtype=torch.long,
) )
inputs_embeds_padding = self.embeddings(input_ids_padding) inputs_embeds_padding = self.embeddings(input_ids_padding)
inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2) inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)
...@@ -976,28 +1004,28 @@ class LongformerModel(LongformerPreTrainedModel): ...@@ -976,28 +1004,28 @@ class LongformerModel(LongformerPreTrainedModel):
): ):
r""" r"""
Returns: Returns:
Examples:: Examples::
>>> import torch >>> import torch
>>> from transformers import LongformerModel, LongformerTokenizer >>> from transformers import LongformerModel, LongformerTokenizer
>>> model = LongformerModel.from_pretrained('allenai/longformer-base-4096', return_dict=True) >>> model = LongformerModel.from_pretrained('allenai/longformer-base-4096', return_dict=True)
>>> tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096') >>> tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
>>> SAMPLE_TEXT = ' '.join(['Hello world! '] * 1000) # long input document >>> SAMPLE_TEXT = ' '.join(['Hello world! '] * 1000) # long input document
>>> input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1 >>> input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1
>>> # Attention mask values -- 0: no attention, 1: local attention, 2: global attention >>> # Attention mask values -- 0: no attention, 1: local attention, 2: global attention
>>> attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to local attention >>> attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to local attention
>>> attention_mask[:, [1, 4, 21,]] = 2 # Set global attention based on the task. For example, >>> attention_mask[:, [1, 4, 21,]] = 2 # Set global attention based on the task. For example,
... # classification: the <s> token ... # classification: the <s> token
... # QA: question tokens ... # QA: question tokens
... # LM: potentially on the beginning of sentences and paragraphs ... # LM: potentially on the beginning of sentences and paragraphs
>>> outputs = model(input_ids, attention_mask=attention_mask) >>> outputs = model(input_ids, attention_mask=attention_mask)
>>> sequence_output = outputs.last_hidden_state >>> sequence_output = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output >>> pooled_output = outputs.pooler_output
""" """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
...@@ -1102,32 +1130,32 @@ class LongformerForMaskedLM(LongformerPreTrainedModel): ...@@ -1102,32 +1130,32 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
**kwargs **kwargs
): ):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss. Labels for computing the masked language modeling loss.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]`` in ``[0, ..., config.vocab_size]``
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
Used to hide legacy arguments that have been deprecated. Used to hide legacy arguments that have been deprecated.
Returns: Returns:
Examples:: Examples::
>>> import torch >>> import torch
>>> from transformers import LongformerForMaskedLM, LongformerTokenizer >>> from transformers import LongformerForMaskedLM, LongformerTokenizer
>>> model = LongformerForMaskedLM.from_pretrained('allenai/longformer-base-4096', return_dict=True) >>> model = LongformerForMaskedLM.from_pretrained('allenai/longformer-base-4096', return_dict=True)
>>> tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096') >>> tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
>>> SAMPLE_TEXT = ' '.join(['Hello world! '] * 1000) # long input document >>> SAMPLE_TEXT = ' '.join(['Hello world! '] * 1000) # long input document
>>> input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1 >>> input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1
>>> attention_mask = None # default is local attention everywhere, which is a good choice for MaskedLM >>> attention_mask = None # default is local attention everywhere, which is a good choice for MaskedLM
... # check ``LongformerModel.forward`` for more details on how to set `attention_mask` ... # check ``LongformerModel.forward`` for more details on how to set `attention_mask`
>>> outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids) >>> outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
>>> loss = outputs.loss >>> loss = outputs.loss
>>> prediction_logits = outputs.logits >>> prediction_logits = outputs.logits
""" """
if "masked_lm_labels" in kwargs: if "masked_lm_labels" in kwargs:
...@@ -1252,7 +1280,10 @@ class LongformerForSequenceClassification(BertPreTrainedModel): ...@@ -1252,7 +1280,10 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput( return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
...@@ -1310,39 +1341,39 @@ class LongformerForQuestionAnswering(BertPreTrainedModel): ...@@ -1310,39 +1341,39 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
return_dict=None, return_dict=None,
): ):
r""" r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the start of the labelled span for computing the token classification loss. Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss. Position outside of the sequence are not taken into account for computing the loss.
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for position (index) of the end of the labelled span for computing the token classification loss. Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions are clamped to the length of the sequence (`sequence_length`).
Positions outside of the sequence are not taken into account for computing the loss. Positions outside of the sequence are not taken into account for computing the loss.
Returns: Returns:
Examples:: Examples::
>>> from transformers import LongformerTokenizer, LongformerForQuestionAnswering >>> from transformers import LongformerTokenizer, LongformerForQuestionAnswering
>>> import torch >>> import torch
>>> tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-large-4096-finetuned-triviaqa") >>> tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-large-4096-finetuned-triviaqa")
>>> model = LongformerForQuestionAnswering.from_pretrained("allenai/longformer-large-4096-finetuned-triviaqa", return_dict=True) >>> model = LongformerForQuestionAnswering.from_pretrained("allenai/longformer-large-4096-finetuned-triviaqa", return_dict=True)
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> encoding = tokenizer(question, text, return_tensors="pt") >>> encoding = tokenizer(question, text, return_tensors="pt")
>>> input_ids = encoding["input_ids"] >>> input_ids = encoding["input_ids"]
>>> # default is local attention everywhere >>> # default is local attention everywhere
>>> # the forward method will automatically set global attention on question tokens >>> # the forward method will automatically set global attention on question tokens
>>> attention_mask = encoding["attention_mask"] >>> attention_mask = encoding["attention_mask"]
>>> outputs = model(input_ids, attention_mask=attention_mask) >>> outputs = model(input_ids, attention_mask=attention_mask)
>>> start_logits = outputs.start_logits >>> start_logits = outputs.start_logits
>>> end_logits = outputs.end_logits >>> end_logits = outputs.end_logits
>>> all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist()) >>> all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
>>> answer_tokens = all_tokens[torch.argmax(start_logits) :torch.argmax(end_logits)+1] >>> answer_tokens = all_tokens[torch.argmax(start_logits) :torch.argmax(end_logits)+1]
>>> answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens)) # remove space prepended to the tokens >>> answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens)) # remove space prepended to the tokens
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...@@ -1487,7 +1518,10 @@ class LongformerForTokenClassification(BertPreTrainedModel): ...@@ -1487,7 +1518,10 @@ class LongformerForTokenClassification(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput( return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
...@@ -1592,5 +1626,8 @@ class LongformerForMultipleChoice(BertPreTrainedModel): ...@@ -1592,5 +1626,8 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput( return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
...@@ -32,8 +32,7 @@ _CONFIG_FOR_DOC = "MMBTConfig" ...@@ -32,8 +32,7 @@ _CONFIG_FOR_DOC = "MMBTConfig"
class ModalEmbeddings(nn.Module): class ModalEmbeddings(nn.Module):
"""Generic Modal Embeddings which takes in an encoder, and a transformer embedding. """Generic Modal Embeddings which takes in an encoder, and a transformer embedding."""
"""
def __init__(self, config, encoder, embeddings): def __init__(self, config, encoder, embeddings):
super().__init__() super().__init__()
...@@ -154,7 +153,8 @@ MMBT_INPUTS_DOCSTRING = r""" Inputs: ...@@ -154,7 +153,8 @@ MMBT_INPUTS_DOCSTRING = r""" Inputs:
@add_start_docstrings( @add_start_docstrings(
"The bare MMBT Model outputting raw hidden-states without any specific head on top.", MMBT_START_DOCSTRING, "The bare MMBT Model outputting raw hidden-states without any specific head on top.",
MMBT_START_DOCSTRING,
) )
class MMBTModel(nn.Module, ModuleUtilsMixin): class MMBTModel(nn.Module, ModuleUtilsMixin):
def __init__(self, config, transformer, encoder): def __init__(self, config, transformer, encoder):
...@@ -288,34 +288,34 @@ class MMBTModel(nn.Module, ModuleUtilsMixin): ...@@ -288,34 +288,34 @@ class MMBTModel(nn.Module, ModuleUtilsMixin):
) )
class MMBTForClassification(nn.Module): class MMBTForClassification(nn.Module):
r""" r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the sequence classification/regression loss. Labels for computing the sequence classification/regression loss.
Indices should be in ``[0, ..., config.num_labels - 1]``. Indices should be in ``[0, ..., config.num_labels - 1]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification (or regression if config.num_labels==1) loss. Classification (or regression if config.num_labels==1) loss.
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)`` **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax). Classification (or regression if config.num_labels==1) scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``output_hidden_states=True``) **hidden_states**: (`optional`, returned when ``output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``: of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``output_attentions=True``) **attentions**: (`optional`, returned when ``output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads. Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples:: Examples::
# For example purposes. Not runnable. # For example purposes. Not runnable.
transformer = BertModel.from_pretrained('bert-base-uncased') transformer = BertModel.from_pretrained('bert-base-uncased')
encoder = ImageEncoder(args) encoder = ImageEncoder(args)
model = MMBTForClassification(config, transformer, encoder) model = MMBTForClassification(config, transformer, encoder)
outputs = model(input_modal, input_ids, labels=labels) outputs = model(input_modal, input_ids, labels=labels)
loss, logits = outputs[:2] loss, logits = outputs[:2]
""" """
def __init__(self, config, transformer, encoder): def __init__(self, config, transformer, encoder):
super().__init__() super().__init__()
...@@ -378,5 +378,8 @@ class MMBTForClassification(nn.Module): ...@@ -378,5 +378,8 @@ class MMBTForClassification(nn.Module):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput( return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
...@@ -64,8 +64,7 @@ MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = ["google/mobilebert-uncased"] ...@@ -64,8 +64,7 @@ MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = ["google/mobilebert-uncased"]
def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path): def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model. """Load tf checkpoints in a pytorch model."""
"""
try: try:
import re import re
...@@ -161,8 +160,7 @@ NORM2FN = {"layer_norm": torch.nn.LayerNorm, "no_norm": NoNorm} ...@@ -161,8 +160,7 @@ NORM2FN = {"layer_norm": torch.nn.LayerNorm, "no_norm": NoNorm}
class MobileBertEmbeddings(nn.Module): class MobileBertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings. """Construct the embeddings from word, position and token_type embeddings."""
"""
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
...@@ -663,8 +661,8 @@ class MobileBertPreTrainingHeads(nn.Module): ...@@ -663,8 +661,8 @@ class MobileBertPreTrainingHeads(nn.Module):
class MobileBertPreTrainedModel(PreTrainedModel): class MobileBertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
config_class = MobileBertConfig config_class = MobileBertConfig
...@@ -788,7 +786,7 @@ MOBILEBERT_INPUTS_DOCSTRING = r""" ...@@ -788,7 +786,7 @@ MOBILEBERT_INPUTS_DOCSTRING = r"""
) )
class MobileBertModel(MobileBertPreTrainedModel): class MobileBertModel(MobileBertPreTrainedModel):
""" """
https://arxiv.org/pdf/2004.02984.pdf https://arxiv.org/pdf/2004.02984.pdf
""" """
authorized_missing_keys = [r"position_ids"] authorized_missing_keys = [r"position_ids"]
...@@ -809,9 +807,9 @@ class MobileBertModel(MobileBertPreTrainedModel): ...@@ -809,9 +807,9 @@ class MobileBertModel(MobileBertPreTrainedModel):
self.embeddings.word_embeddings = value self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model. """Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel See base class PreTrainedModel
""" """
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads) self.encoder.layer[layer].attention.prune_heads(heads)
...@@ -965,31 +963,31 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel): ...@@ -965,31 +963,31 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
return_dict=None, return_dict=None,
): ):
r""" r"""
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`): labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss. Labels for computing the masked language modeling loss.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]`` in ``[0, ..., config.vocab_size]``
next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`): next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
Labels for computing the next sentence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring) Labels for computing the next sentence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring)
Indices should be in ``[0, 1]``. Indices should be in ``[0, 1]``.
``0`` indicates sequence B is a continuation of sequence A, ``0`` indicates sequence B is a continuation of sequence A,
``1`` indicates sequence B is a random sequence. ``1`` indicates sequence B is a random sequence.
Returns: Returns:
Examples:: Examples::
>>> from transformers import MobileBertTokenizer, MobileBertForPreTraining >>> from transformers import MobileBertTokenizer, MobileBertForPreTraining
>>> import torch >>> import torch
>>> tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased") >>> tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
>>> model = MobileBertForPreTraining.from_pretrained("google/mobilebert-uncased", return_dict=True) >>> model = MobileBertForPreTraining.from_pretrained("google/mobilebert-uncased", return_dict=True)
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
>>> outputs = model(input_ids) >>> outputs = model(input_ids)
>>> prediction_logits = outputs.prediction_logits >>> prediction_logits = outputs.prediction_logits
>>> seq_relationship_logits = outputs.seq_relationship_logits >>> seq_relationship_logits = outputs.seq_relationship_logits
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...@@ -1176,29 +1174,29 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel): ...@@ -1176,29 +1174,29 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
return_dict=None, return_dict=None,
): ):
r""" r"""
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the next sentence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring) Labels for computing the next sentence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
Indices should be in ``[0, 1]``. Indices should be in ``[0, 1]``.
``0`` indicates sequence B is a continuation of sequence A, ``0`` indicates sequence B is a continuation of sequence A,
``1`` indicates sequence B is a random sequence. ``1`` indicates sequence B is a random sequence.
Returns: Returns:
Examples:: Examples::
>>> from transformers import MobileBertTokenizer, MobileBertForNextSentencePrediction >>> from transformers import MobileBertTokenizer, MobileBertForNextSentencePrediction
>>> import torch >>> import torch
>>> tokenizer = MobileBertTokenizer.from_pretrained('google/mobilebert-uncased') >>> tokenizer = MobileBertTokenizer.from_pretrained('google/mobilebert-uncased')
>>> model = MobileBertForNextSentencePrediction.from_pretrained('google/mobilebert-uncased', return_dict=True) >>> model = MobileBertForNextSentencePrediction.from_pretrained('google/mobilebert-uncased', return_dict=True)
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt') >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
>>> outputs = model(**encoding, next_sentence_label=torch.LongTensor([1])) >>> outputs = model(**encoding, next_sentence_label=torch.LongTensor([1]))
>>> loss = outputs.loss >>> loss = outputs.loss
>>> logits = outputs.logits >>> logits = outputs.logits
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...@@ -1308,7 +1306,10 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel): ...@@ -1308,7 +1306,10 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput( return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
...@@ -1491,7 +1492,10 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel): ...@@ -1491,7 +1492,10 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput( return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
...@@ -1574,5 +1578,8 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel): ...@@ -1574,5 +1578,8 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput( return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
...@@ -59,8 +59,7 @@ OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -59,8 +59,7 @@ OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path): def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here) """Load tf pre-trained weights in a pytorch model (from NumPy arrays here)"""
"""
import re import re
import numpy as np import numpy as np
...@@ -257,7 +256,10 @@ class Block(nn.Module): ...@@ -257,7 +256,10 @@ class Block(nn.Module):
def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False): def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
attn_outputs = self.attn( attn_outputs = self.attn(
x, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, x,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
) )
a = attn_outputs[0] a = attn_outputs[0]
...@@ -270,8 +272,8 @@ class Block(nn.Module): ...@@ -270,8 +272,8 @@ class Block(nn.Module):
class OpenAIGPTPreTrainedModel(PreTrainedModel): class OpenAIGPTPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
config_class = OpenAIGPTConfig config_class = OpenAIGPTConfig
...@@ -280,8 +282,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel): ...@@ -280,8 +282,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
authorized_missing_keys = [r"position_ids"] authorized_missing_keys = [r"position_ids"]
def _init_weights(self, module): def _init_weights(self, module):
""" Initialize the weights. """Initialize the weights."""
"""
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)): if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
# Slightly different from the TF version which uses truncated_normal for initialization # Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617 # cf https://github.com/pytorch/pytorch/pull/5617
...@@ -408,8 +409,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -408,8 +409,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
self.tokens_embed = new_embeddings self.tokens_embed = new_embeddings
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model. """Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
""" """
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.h[layer].attn.prune_heads(heads) self.h[layer].attn.prune_heads(heads)
...@@ -506,7 +507,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -506,7 +507,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput( return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions, last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
) )
...@@ -632,41 +635,41 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -632,41 +635,41 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
**kwargs **kwargs
): ):
r""" r"""
mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, defaults to the index of the last token of the input) mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, defaults to the index of the last token of the input)
Index of the classification token in each input sequence. Index of the classification token in each input sequence.
Selected in the range ``[0, input_ids.size(-1) - 1]``. Selected in the range ``[0, input_ids.size(-1) - 1]``.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`) labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`)
Labels for language modeling. Labels for language modeling.
Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids`` Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
Indices are selected in ``[-100, 0, ..., config.vocab_size]`` Indices are selected in ``[-100, 0, ..., config.vocab_size]``
All labels set to ``-100`` are ignored (masked), the loss is only All labels set to ``-100`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]`` computed for labels in ``[0, ..., config.vocab_size]``
mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`, defaults to :obj:`None`) mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`, defaults to :obj:`None`)
Labels for computing the multiple choice classification loss. Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above) of the input tensors. (see `input_ids` above)
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`): kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
Used to hide legacy arguments that have been deprecated. Used to hide legacy arguments that have been deprecated.
Return: Return:
Examples:: Examples::
from transformers import OpenAIGPTTokenizer, OpenAIGPTDoubleHeadsModel from transformers import OpenAIGPTTokenizer, OpenAIGPTDoubleHeadsModel
import torch import torch
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt') tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
model = OpenAIGPTDoubleHeadsModel.from_pretrained('openai-gpt', return_dict=True) model = OpenAIGPTDoubleHeadsModel.from_pretrained('openai-gpt', return_dict=True)
tokenizer.add_special_tokens({'cls_token': '[CLS]'}) # Add a [CLS] to the vocabulary (we should train it also!) tokenizer.add_special_tokens({'cls_token': '[CLS]'}) # Add a [CLS] to the vocabulary (we should train it also!)
model.resize_token_embeddings(len(tokenizer)) model.resize_token_embeddings(len(tokenizer))
choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
mc_token_ids = torch.tensor([input_ids.size(-1)-1, input_ids.size(-1)-1]).unsqueeze(0) # Batch size 1 mc_token_ids = torch.tensor([input_ids.size(-1)-1, input_ids.size(-1)-1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, mc_token_ids=mc_token_ids) outputs = model(input_ids, mc_token_ids=mc_token_ids)
lm_logits = outputs.lm_logits lm_logits = outputs.lm_logits
mc_logits = outputs.mc_logits mc_logits = outputs.mc_logits
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if "lm_labels" in kwargs: if "lm_labels" in kwargs:
......
...@@ -78,7 +78,8 @@ ReformerBackwardOutput = namedtuple( ...@@ -78,7 +78,8 @@ ReformerBackwardOutput = namedtuple(
"ReformerBackwardOutput", ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"] "ReformerBackwardOutput", ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"]
) )
ReformerEncoderOutput = namedtuple( ReformerEncoderOutput = namedtuple(
"ReformerEncoderOutput", ["hidden_states", "all_hidden_states", "all_attentions", "past_buckets_states"], "ReformerEncoderOutput",
["hidden_states", "all_hidden_states", "all_attentions", "past_buckets_states"],
) )
...@@ -192,7 +193,9 @@ class AxialPositionEmbeddings(nn.Module): ...@@ -192,7 +193,9 @@ class AxialPositionEmbeddings(nn.Module):
assert ( assert (
reduce(mul, self.axial_pos_shape) >= sequence_length reduce(mul, self.axial_pos_shape) >= sequence_length
), "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, least_common_mult_chunk_length): max({}, {})".format( ), "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, least_common_mult_chunk_length): max({}, {})".format(
self.axial_pos_shape, sequence_length, self.least_common_mult_chunk_length, self.axial_pos_shape,
sequence_length,
self.least_common_mult_chunk_length,
) )
# compute how many columns are needed # compute how many columns are needed
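For intuition, the assertion above enforces that the product of ``config.axial_pos_shape`` covers the (padded) sequence length, since axial position embeddings factor the position axis into a grid. A minimal sketch with assumed values, not taken from the library::

    from functools import reduce
    from operator import mul

    axial_pos_shape = (128, 512)   # hypothetical axial factors of the position axis
    sequence_length = 16384        # hypothetical padded input length
    # 128 * 512 = 65536 >= 16384, so every position gets a (row, column) embedding pair
    assert reduce(mul, axial_pos_shape) >= sequence_length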
...@@ -218,8 +221,7 @@ class AxialPositionEmbeddings(nn.Module): ...@@ -218,8 +221,7 @@ class AxialPositionEmbeddings(nn.Module):
class PositionEmbeddings(nn.Module): class PositionEmbeddings(nn.Module):
"""Constructs conventional position embeddings of shape `[max_pos_embeddings, hidden_size]`. """Constructs conventional position embeddings of shape `[max_pos_embeddings, hidden_size]`."""
"""
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
...@@ -233,8 +235,7 @@ class PositionEmbeddings(nn.Module): ...@@ -233,8 +235,7 @@ class PositionEmbeddings(nn.Module):
class ReformerEmbeddings(nn.Module): class ReformerEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings. """Construct the embeddings from word, position and token_type embeddings."""
"""
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
...@@ -285,16 +286,16 @@ class EfficientAttentionMixin: ...@@ -285,16 +286,16 @@ class EfficientAttentionMixin:
""" """
def _look_adjacent(self, vectors, num_chunks_before, num_chunks_after): def _look_adjacent(self, vectors, num_chunks_before, num_chunks_after):
""" Used to implement attention between consecutive chunks. """Used to implement attention between consecutive chunks.
Args: Args:
vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...] vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...]
num_chunks_before: chunks before current chunk to include in attention num_chunks_before: chunks before current chunk to include in attention
num_chunks_after: chunks after current chunk to include in attention num_chunks_after: chunks after current chunk to include in attention
Returns: Returns:
tensor of shape [num_chunks, N * chunk_length, ...], where tensor of shape [num_chunks, N * chunk_length, ...], where
N = (1 + num_chunks_before + num_chunks_after). N = (1 + num_chunks_before + num_chunks_after).
""" """
if num_chunks_before == 0 and num_chunks_after == 0: if num_chunks_before == 0 and num_chunks_after == 0:
return vectors return vectors
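The chunk adjacency described in the docstring can be pictured as concatenating each chunk with rolled copies of its neighbours along the chunk-length axis. The sketch below is an illustrative assumption, not the library's exact implementation::

    import torch

    def look_adjacent_sketch(vectors, num_chunks_before, num_chunks_after):
        # vectors: [batch, heads, n_chunks, chunk_len, head_size]
        slices = []
        for offset in range(-num_chunks_before, num_chunks_after + 1):
            # roll along the chunk axis so that position i holds chunk i + offset
            slices.append(torch.roll(vectors, shifts=-offset, dims=2))
        # each chunk now attends over (1 + before + after) * chunk_len positions
        return torch.cat(slices, dim=3)

    x = torch.randn(1, 2, 4, 8, 16)
    assert look_adjacent_sketch(x, 1, 0).shape == (1, 2, 4, 16, 16)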
...@@ -309,7 +310,7 @@ class EfficientAttentionMixin: ...@@ -309,7 +310,7 @@ class EfficientAttentionMixin:
def _split_hidden_size_dim(self, x, num_attn_heads, attn_head_size): def _split_hidden_size_dim(self, x, num_attn_heads, attn_head_size):
""" """
splits hidden_size dim into attn_head_size and num_attn_heads splits hidden_size dim into attn_head_size and num_attn_heads
""" """
new_x_shape = x.size()[:-1] + (num_attn_heads, attn_head_size) new_x_shape = x.size()[:-1] + (num_attn_heads, attn_head_size)
x = x.view(*new_x_shape) x = x.view(*new_x_shape)
...@@ -317,14 +318,14 @@ class EfficientAttentionMixin: ...@@ -317,14 +318,14 @@ class EfficientAttentionMixin:
def _merge_hidden_size_dims(self, x, num_attn_heads, attn_head_size): def _merge_hidden_size_dims(self, x, num_attn_heads, attn_head_size):
""" """
merges attn_head_size dim and num_attn_heads dim into hidden_size merges attn_head_size dim and num_attn_heads dim into hidden_size
""" """
x = x.permute(0, 2, 1, 3) x = x.permute(0, 2, 1, 3)
return torch.reshape(x, (x.size()[0], -1, num_attn_heads * attn_head_size)) return torch.reshape(x, (x.size()[0], -1, num_attn_heads * attn_head_size))
def _split_seq_length_dim_to(self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, attn_head_size=None): def _split_seq_length_dim_to(self, vectors, dim_factor_1, dim_factor_2, num_attn_heads, attn_head_size=None):
""" """
splits sequence length dim of vectors into `dim_factor_1` and `dim_factor_2` dims splits sequence length dim of vectors into `dim_factor_1` and `dim_factor_2` dims
""" """
batch_size = vectors.shape[0] batch_size = vectors.shape[0]
split_dim_shape = (batch_size, num_attn_heads, dim_factor_1, dim_factor_2) split_dim_shape = (batch_size, num_attn_heads, dim_factor_1, dim_factor_2)
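The split and merge helpers above are plain reshapes between ``[batch, seq_len, hidden]`` and ``[batch, heads, seq_len, head_size]``; a minimal round-trip sketch with assumed shapes (not the library's code)::

    import torch

    batch, seq_len, num_heads, head_size = 2, 16, 4, 8
    hidden = torch.randn(batch, seq_len, num_heads * head_size)

    # split hidden_size into (num_heads, head_size) and move the head dim forward
    per_head = hidden.view(batch, seq_len, num_heads, head_size).permute(0, 2, 1, 3)

    # merge back: undo the permutation and flatten the last two dims again
    merged = per_head.permute(0, 2, 1, 3).reshape(batch, seq_len, num_heads * head_size)
    assert torch.equal(merged, hidden)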
...@@ -418,10 +419,18 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -418,10 +419,18 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# split key & value vectors by num hashes to apply # split key & value vectors by num hashes to apply
# self attention on each separately # self attention on each separately
query_key_vectors = self._split_seq_length_dim_to( query_key_vectors = self._split_seq_length_dim_to(
query_key_vectors, num_hashes, -1, self.num_attention_heads, self.attention_head_size, query_key_vectors,
num_hashes,
-1,
self.num_attention_heads,
self.attention_head_size,
) )
value_vectors = self._split_seq_length_dim_to( value_vectors = self._split_seq_length_dim_to(
value_vectors, num_hashes, -1, self.num_attention_heads, self.attention_head_size, value_vectors,
num_hashes,
-1,
self.num_attention_heads,
self.attention_head_size,
) )
# repeat query vectors across hash dimension # repeat query vectors across hash dimension
query_vectors = query_vectors.unsqueeze(2).repeat(1, 1, num_hashes, 1, 1) query_vectors = query_vectors.unsqueeze(2).repeat(1, 1, num_hashes, 1, 1)
...@@ -496,10 +505,18 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -496,10 +505,18 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx_per_hash, num_hashes) query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx_per_hash, num_hashes)
value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx_per_hash, num_hashes) value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx_per_hash, num_hashes)
query_key_vectors = self._split_seq_length_dim_to( query_key_vectors = self._split_seq_length_dim_to(
query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, query_key_vectors,
-1,
self.chunk_length,
self.num_attention_heads,
self.attention_head_size,
) )
value_vectors = self._split_seq_length_dim_to( value_vectors = self._split_seq_length_dim_to(
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, value_vectors,
-1,
self.chunk_length,
self.num_attention_heads,
self.attention_head_size,
) )
if self.chunk_length is None: if self.chunk_length is None:
...@@ -548,10 +565,18 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -548,10 +565,18 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# sum up all hash rounds # sum up all hash rounds
if num_hashes > 1: if num_hashes > 1:
out_vectors = self._split_seq_length_dim_to( out_vectors = self._split_seq_length_dim_to(
out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, out_vectors,
num_hashes,
sequence_length,
self.num_attention_heads,
self.attention_head_size,
) )
logits = self._split_seq_length_dim_to( logits = self._split_seq_length_dim_to(
logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, logits,
num_hashes,
sequence_length,
self.num_attention_heads,
self.attention_head_size,
).unsqueeze(-1) ).unsqueeze(-1)
probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True))
...@@ -697,7 +722,8 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -697,7 +722,8 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# factorize `num_buckets` if `num_buckets` becomes too large # factorize `num_buckets` if `num_buckets` becomes too large
num_buckets_limit = 2 * max( num_buckets_limit = 2 * max(
int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length, int((self.max_position_embeddings // self.chunk_length) ** (0.5)),
self.chunk_length,
) )
if num_buckets > num_buckets_limit: if num_buckets > num_buckets_limit:
num_buckets = [2 ** (num_buckets_pow_2 // 2), 2 ** (num_buckets_pow_2 - num_buckets_pow_2 // 2)] num_buckets = [2 ** (num_buckets_pow_2 // 2), 2 ** (num_buckets_pow_2 - num_buckets_pow_2 // 2)]
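A small numeric illustration of the factorization above, using an assumed value: if ``num_buckets`` would be 512 (``num_buckets_pow_2 == 9``), it is split into two factors whose product is unchanged::

    num_buckets_pow_2 = 9  # hypothetical: num_buckets == 2 ** 9 == 512
    num_buckets = [2 ** (num_buckets_pow_2 // 2), 2 ** (num_buckets_pow_2 - num_buckets_pow_2 // 2)]
    assert num_buckets == [16, 32] and 16 * 32 == 512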
...@@ -946,7 +972,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -946,7 +972,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
def _len_and_dim_norm(self, vectors): def _len_and_dim_norm(self, vectors):
""" """
length and attention head size dim normalization length and attention head size dim normalization
""" """
vectors = self._len_norm(vectors) vectors = self._len_norm(vectors)
vectors = vectors * torch.rsqrt( vectors = vectors * torch.rsqrt(
...@@ -956,7 +982,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -956,7 +982,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
def _len_norm(self, x, epsilon=1e-6): def _len_norm(self, x, epsilon=1e-6):
""" """
length normalization length normalization
""" """
variance = torch.mean(x ** 2, -1, keepdim=True) variance = torch.mean(x ** 2, -1, keepdim=True)
norm_x = x * torch.rsqrt(variance + epsilon) norm_x = x * torch.rsqrt(variance + epsilon)
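As the two lines above show, ``_len_norm`` rescales each vector so that its mean squared value is approximately one, i.e. an RMS-style normalization. A standalone sketch::

    import torch

    x = torch.randn(2, 64)
    epsilon = 1e-6
    variance = torch.mean(x ** 2, dim=-1, keepdim=True)
    norm_x = x * torch.rsqrt(variance + epsilon)
    print(norm_x.pow(2).mean(dim=-1))  # close to 1.0 for every row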
...@@ -964,7 +990,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -964,7 +990,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
def _gather_by_expansion(self, vectors, idxs, num_hashes): def _gather_by_expansion(self, vectors, idxs, num_hashes):
""" """
expand dims of idxs and vectors for all hashes and gather expand dims of idxs and vectors for all hashes and gather
""" """
expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size) expanded_idxs = idxs.unsqueeze(-1).expand(-1, -1, -1, self.attention_head_size)
vectors = vectors.repeat(1, 1, num_hashes, 1) vectors = vectors.repeat(1, 1, num_hashes, 1)
...@@ -973,11 +999,11 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -973,11 +999,11 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
class ReverseSort(Function): class ReverseSort(Function):
""" """
After chunked attention is applied to the sorted clusters, After chunked attention is applied to the sorted clusters,
the original ordering has to be restored. the original ordering has to be restored.
Since a customized backward function is used for Reformer, Since a customized backward function is used for Reformer,
the gradients of the output vectors have to be explicitly the gradients of the output vectors have to be explicitly
sorted here. sorted here.
""" """
@staticmethod @staticmethod
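The un-sorting the docstring describes boils down to gathering with the inverse permutation; a minimal 1-D sketch (a simplification, not the library's code)::

    import torch

    x = torch.tensor([10.0, 30.0, 20.0])
    sorted_idx = torch.argsort(x)         # order in which chunked attention processed the buckets
    undo_idx = torch.argsort(sorted_idx)  # inverse permutation
    restored = x[sorted_idx][undo_idx]    # outputs (and their gradients) back in the original order
    assert torch.equal(restored, x)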
...@@ -1113,13 +1139,25 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -1113,13 +1139,25 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
# chunk vectors # chunk vectors
# B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size
query_vectors = self._split_seq_length_dim_to( query_vectors = self._split_seq_length_dim_to(
query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, query_vectors,
-1,
self.chunk_length,
self.num_attention_heads,
self.attention_head_size,
) )
key_vectors = self._split_seq_length_dim_to( key_vectors = self._split_seq_length_dim_to(
key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, key_vectors,
-1,
self.chunk_length,
self.num_attention_heads,
self.attention_head_size,
) )
value_vectors = self._split_seq_length_dim_to( value_vectors = self._split_seq_length_dim_to(
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, value_vectors,
-1,
self.chunk_length,
self.num_attention_heads,
self.attention_head_size,
) )
# chunk indices # chunk indices
...@@ -1179,7 +1217,12 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -1179,7 +1217,12 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
if not do_standard_self_attention: if not do_standard_self_attention:
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,) assert out_vectors.shape == (
batch_size,
self.num_attention_heads,
sequence_length,
self.attention_head_size,
)
out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size) out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size)
...@@ -1321,7 +1364,9 @@ class ReformerAttention(nn.Module): ...@@ -1321,7 +1364,9 @@ class ReformerAttention(nn.Module):
attention_output = self.output(self_attention_outputs.hidden_states) attention_output = self.output(self_attention_outputs.hidden_states)
return AttentionOutput( return AttentionOutput(
hidden_states=attention_output, attention_probs=self_attention_outputs.attention_probs, buckets=buckets, hidden_states=attention_output,
attention_probs=self_attention_outputs.attention_probs,
buckets=buckets,
) )
...@@ -1369,7 +1414,10 @@ class ChunkReformerFeedForward(nn.Module): ...@@ -1369,7 +1414,10 @@ class ChunkReformerFeedForward(nn.Module):
def forward(self, attention_output): def forward(self, attention_output):
return apply_chunking_to_forward( return apply_chunking_to_forward(
self.forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output, self.forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output,
) )
def forward_chunk(self, hidden_states): def forward_chunk(self, hidden_states):
...@@ -1391,11 +1439,11 @@ class ReformerLayer(nn.Module): ...@@ -1391,11 +1439,11 @@ class ReformerLayer(nn.Module):
def _init_attention_seed(self): def _init_attention_seed(self):
""" """
This function sets a new seed for the This function sets a new seed for the
attention layer to make dropout deterministic attention layer to make dropout deterministic
for both forward calls: 1 normal forward for both forward calls: 1 normal forward
call and 1 forward call in backward call and 1 forward call in backward
to recalculate activations. to recalculate activations.
""" """
# randomize seeds # randomize seeds
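Caching a seed works because dropout can then be replayed exactly when the forward pass is recomputed during the backward pass of the reversible layers. A minimal sketch of the idea (not the library's code)::

    import torch

    seed = torch.seed()  # draw a fresh seed and remember it
    first = torch.nn.functional.dropout(torch.ones(8), p=0.5, training=True)

    torch.manual_seed(seed)  # restore the cached seed before recomputation
    replayed = torch.nn.functional.dropout(torch.ones(8), p=0.5, training=True)
    assert torch.equal(first, replayed)  # identical dropout mask both times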
...@@ -1412,11 +1460,11 @@ class ReformerLayer(nn.Module): ...@@ -1412,11 +1460,11 @@ class ReformerLayer(nn.Module):
def _init_feed_forward_seed(self): def _init_feed_forward_seed(self):
""" """
This function sets a new seed for the This function sets a new seed for the
feed forward layer to make dropout deterministic feed forward layer to make dropout deterministic
for both forward calls: 1 normal forward for both forward calls: 1 normal forward
call and 1 forward call in backward call and 1 forward call in backward
to recalculate activations. to recalculate activations.
""" """
# randomize seeds # randomize seeds
# use cuda generator if available # use cuda generator if available
...@@ -1520,7 +1568,10 @@ class ReformerLayer(nn.Module): ...@@ -1520,7 +1568,10 @@ class ReformerLayer(nn.Module):
# f(X_2) # f(X_2)
# use cached buckets for backprop if buckets is not None for LSHSelfAttention # use cached buckets for backprop if buckets is not None for LSHSelfAttention
output = self.attention( output = self.attention(
hidden_states=hidden_states, head_mask=head_mask, attention_mask=attention_mask, buckets=buckets, hidden_states=hidden_states,
head_mask=head_mask,
attention_mask=attention_mask,
buckets=buckets,
).hidden_states ).hidden_states
output.backward(grad_attn_output, retain_graph=True) output.backward(grad_attn_output, retain_graph=True)
...@@ -1738,8 +1789,8 @@ class ReformerOnlyLMHead(nn.Module): ...@@ -1738,8 +1789,8 @@ class ReformerOnlyLMHead(nn.Module):
class ReformerPreTrainedModel(PreTrainedModel): class ReformerPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
config_class = ReformerConfig config_class = ReformerConfig
...@@ -1947,9 +1998,9 @@ class ReformerModel(ReformerPreTrainedModel): ...@@ -1947,9 +1998,9 @@ class ReformerModel(ReformerPreTrainedModel):
self.embeddings.word_embeddings = value self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model. """Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel See base class PreTrainedModel
""" """
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads) self.encoder.layer[layer].attention.prune_heads(heads)
...@@ -2099,7 +2150,10 @@ class ReformerModel(ReformerPreTrainedModel): ...@@ -2099,7 +2150,10 @@ class ReformerModel(ReformerPreTrainedModel):
) )
padded_input_ids = torch.full( padded_input_ids = torch.full(
(input_shape[0], padding_length), self.config.pad_token_id, device=device, dtype=torch.long, (input_shape[0], padding_length),
self.config.pad_token_id,
device=device,
dtype=torch.long,
) )
# Extend `attention_mask` # Extend `attention_mask`
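The padding step above appends ``pad_token_id`` on the right so that the sequence length becomes a multiple of the chunk length; a minimal sketch with assumed values::

    import torch

    pad_token_id, padding_length = 0, 3          # hypothetical values
    input_ids = torch.tensor([[5, 6, 7, 8, 9]])  # length 5, padded up to 8
    padding = torch.full((input_ids.shape[0], padding_length), pad_token_id, dtype=torch.long)
    padded_input_ids = torch.cat([input_ids, padding], dim=-1)
    assert padded_input_ids.shape == (1, 8)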
...@@ -2369,11 +2423,11 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel): ...@@ -2369,11 +2423,11 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
return_dict=None, return_dict=None,
): ):
r""" r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the sequence classification/regression loss. Labels for computing the sequence classification/regression loss.
Indices should be in :obj:`[0, ..., config.num_labels - 1]`. Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...@@ -2407,7 +2461,10 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel): ...@@ -2407,7 +2461,10 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput( return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
......