"vscode:/vscode.git/clone" did not exist on "fb560dcb075497f61880010245192e7e1fdbeca4"
Commit a75c64d8 authored by Lysandre

Black 20 release

parent e78c1103
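Judging by the commit title and the patterns below, this diff appears to be a mechanical re-run of the `black` code formatter after upgrading to its 20.x series. The most visible change is the "magic trailing comma": when a bracketed call, tuple, or signature already ends with a trailing comma, Black 20.x expands it to one element per line instead of keeping it on a single line. Most of the remaining churn is docstring whitespace cleanup (dropping the space after the opening `"""` and collapsing one-line docstrings). A minimal sketch of the trailing-comma behavior follows; `add_three` is a hypothetical helper, not code from this commit:

# Hypothetical helper used only to illustrate Black 20.x's magic trailing comma.
def add_three(first_argument, second_argument, third_argument):
    return first_argument + second_argument + third_argument

# Before: the removed lines in this diff keep such calls on one line, trailing comma included.
result = add_three(1, 2, 3,)

# After: Black 20.x treats the pre-existing trailing comma as a request to expand the call,
# one argument per line, which is the pattern repeated on the added lines throughout this diff.
result = add_three(
    1,
    2,
    3,
)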
@@ -329,7 +329,13 @@ class TFGenerationMixin:
 if self.config.is_encoder_decoder:
 # create empty decoder_input_ids
-input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id
+input_ids = (
+    tf.ones(
+        (effective_batch_size * num_beams, 1),
+        dtype=tf.int32,
+    )
+    * decoder_start_token_id
+)
 cur_len = 1
 assert (
@@ -422,7 +428,7 @@ class TFGenerationMixin:
 attention_mask,
 use_cache,
 ):
-""" Generate sequences for each example without beam search (num_beams == 1).
+"""Generate sequences for each example without beam search (num_beams == 1).
 All returned sequence are generated independantly.
 """
@@ -587,8 +593,7 @@ class TFGenerationMixin:
 attention_mask,
 use_cache,
 ):
-""" Generate sequences for each example with beam search.
-"""
+"""Generate sequences for each example with beam search."""
 # generated hypotheses
 generated_hyps = [
@@ -960,7 +965,7 @@ def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
 def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
-""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
 Args:
 logits: logits distribution shape (batch size, vocabulary size)
 if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
@@ -1001,7 +1006,8 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
 # Shift the indices to the right to keep also the first token above the threshold
 sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
 sorted_indices_to_remove = tf.concat(
-    [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]], -1,
+    [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]],
+    -1,
 )
 # scatter sorted tensors to original indexing
 indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
...
@@ -83,7 +83,11 @@ class GenerationMixin:
 # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
 if repetition_penalty != 1.0:
 self.enforce_repetition_penalty_(
-    scores, batch_size, num_beams, input_ids, repetition_penalty,
+    scores,
+    batch_size,
+    num_beams,
+    input_ids,
+    repetition_penalty,
 )
 # set eos token prob to zero if min_length is not reached
@@ -324,7 +328,10 @@ class GenerationMixin:
 "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
 )
 input_ids = torch.full(
-    (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
+    (batch_size, 1),
+    bos_token_id,
+    dtype=torch.long,
+    device=next(self.parameters()).device,
 )
 else:
 assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
@@ -514,7 +521,7 @@ class GenerationMixin:
 use_cache,
 model_specific_kwargs,
 ):
-""" Generate sequences for each example without beam search (num_beams == 1).
+"""Generate sequences for each example without beam search (num_beams == 1).
 All returned sequence are generated independantly.
 """
 # length of generated sentences / unfinished sentences
@@ -619,8 +626,7 @@ class GenerationMixin:
 use_cache,
 model_specific_kwargs,
 ):
-""" Generate sequences for each example with beam search.
-"""
+"""Generate sequences for each example with beam search."""
 # generated hypotheses
 generated_hyps = [
@@ -749,7 +755,8 @@ class GenerationMixin:
 if is_beam_token_worse_than_top_num_beams:
 continue
 generated_hyps[batch_idx].add(
-    input_ids[effective_beam_id].clone(), beam_token_score.item(),
+    input_ids[effective_beam_id].clone(),
+    beam_token_score.item(),
 )
 else:
 # add next predicted token since it is not eos_token
@@ -806,7 +813,8 @@ class GenerationMixin:
 assert torch.all(
 next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
 ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
-    next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
+    next_scores[:, :num_beams][batch_idx],
+    beam_scores.view(batch_size, num_beams)[batch_idx],
 )
 # need to add best num_beams hypotheses to generated hyps
@@ -916,7 +924,7 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
 def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
-""" Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
+"""Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
 a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...]
 Args:
 scores: logits distribution of shape (batch size, vocabulary size)
@@ -946,7 +954,7 @@ def top_k_top_p_filtering(
 filter_value: float = -float("Inf"),
 min_tokens_to_keep: int = 1,
 ) -> Tensor:
-""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
 Args:
 logits: logits distribution shape (batch size, vocabulary size)
 if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
...
@@ -36,7 +36,7 @@ logger = logging.get_logger(__name__)
 class ModelCard:
-r""" Structured Model Card class.
+r"""Structured Model Card class.
 Store model card as well as methods for loading/downloading/saving model cards.
 Please read the following paper for details and explanation on the sections:
@@ -73,8 +73,7 @@ class ModelCard:
 raise err
 def save_pretrained(self, save_directory_or_file):
-""" Save a model card object to the directory or file `save_directory_or_file`.
-"""
+"""Save a model card object to the directory or file `save_directory_or_file`."""
 if os.path.isdir(save_directory_or_file):
 # If we save using the predefined names, we can load using `from_pretrained`
 output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME)
@@ -86,7 +85,7 @@ class ModelCard:
 @classmethod
 def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
-r""" Instantiate a :class:`~transformers.ModelCard` from a pre-trained model model card.
+r"""Instantiate a :class:`~transformers.ModelCard` from a pre-trained model model card.
 Parameters:
 pretrained_model_name_or_path: either:
...
@@ -302,7 +302,10 @@ class AlbertLayer(nn.Module):
 attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
 ffn_output = apply_chunking_to_forward(
-    self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output[0],
+    self.ff_chunk,
+    self.chunk_size_feed_forward,
+    self.seq_len_dim,
+    attention_output[0],
 )
 hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])
@@ -397,7 +400,7 @@ class AlbertTransformer(nn.Module):
 class AlbertPreTrainedModel(PreTrainedModel):
-""" An abstract class to handle weights initialization and
+"""An abstract class to handle weights initialization and
 a simple interface for downloading and loading pretrained models.
 """
@@ -406,8 +409,7 @@ class AlbertPreTrainedModel(PreTrainedModel):
 authorized_missing_keys = [r"position_ids"]
 def _init_weights(self, module):
-""" Initialize the weights.
-"""
+"""Initialize the weights."""
 if isinstance(module, (nn.Linear, nn.Embedding)):
 # Slightly different from the TF version which uses truncated_normal for initialization
 # cf https://github.com/pytorch/pytorch/pull/5617
@@ -543,7 +545,7 @@ class AlbertModel(AlbertPreTrainedModel):
 return self.embeddings.word_embeddings
 def _prune_heads(self, heads_to_prune):
-""" Prunes heads of the model.
+"""Prunes heads of the model.
 heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
 ALBERT has a different architecture in that its layers are shared across groups, which then has inner groups.
 If an ALBERT model has 12 hidden layers and 2 hidden groups, with two inner groups, there
@@ -787,7 +789,8 @@ class AlbertSOPHead(nn.Module):
 @add_start_docstrings(
-    "Albert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING,
+    "Albert Model with a `language modeling` head on top.",
+    ALBERT_START_DOCSTRING,
 )
 class AlbertForMaskedLM(AlbertPreTrainedModel):
 def __init__(self, config):
@@ -952,7 +955,10 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
 return ((loss,) + output) if loss is not None else output
 return SequenceClassifierOutput(
-    loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
+    loss=loss,
+    logits=logits,
+    hidden_states=outputs.hidden_states,
+    attentions=outputs.attentions,
 )
@@ -1033,7 +1039,10 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
 return ((loss,) + output) if loss is not None else output
 return TokenClassifierOutput(
-    loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
+    loss=loss,
+    logits=logits,
+    hidden_states=outputs.hidden_states,
+    attentions=outputs.attentions,
 )
@@ -1215,5 +1224,8 @@ class AlbertForMultipleChoice(AlbertPreTrainedModel):
 return ((loss,) + output) if loss is not None else output
 return MultipleChoiceModelOutput(
-    loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
+    loss=loss,
+    logits=reshaped_logits,
+    hidden_states=outputs.hidden_states,
+    attentions=outputs.attentions,
 )
@@ -391,7 +391,7 @@ class AutoModel:
 @classmethod
 def from_config(cls, config):
-r""" Instantiates one of the base model classes of the library
+r"""Instantiates one of the base model classes of the library
 from a configuration.
 Note:
@@ -433,7 +433,7 @@ class AutoModel:
 @classmethod
 def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-r""" Instantiates one of the base model classes of the library
+r"""Instantiates one of the base model classes of the library
 from a pre-trained model configuration.
 The `from_pretrained()` method takes care of returning the correct model class instance
@@ -547,7 +547,7 @@ class AutoModelForPreTraining:
 @classmethod
 def from_config(cls, config):
-r""" Instantiates one of the base model classes of the library
+r"""Instantiates one of the base model classes of the library
 from a configuration.
 Note:
@@ -589,7 +589,7 @@ class AutoModelForPreTraining:
 @classmethod
 def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-r""" Instantiates one of the model classes of the library -with the architecture used for pretraining this model– from a pre-trained model configuration.
+r"""Instantiates one of the model classes of the library -with the architecture used for pretraining this model– from a pre-trained model configuration.
 The `from_pretrained()` method takes care of returning the correct model class instance
 based on the `model_type` property of the config object, or when it's missing,
@@ -697,7 +697,7 @@ class AutoModelWithLMHead:
 @classmethod
 def from_config(cls, config):
-r""" Instantiates one of the base model classes of the library
+r"""Instantiates one of the base model classes of the library
 from a configuration.
 Note:
@@ -743,7 +743,7 @@ class AutoModelWithLMHead:
 @classmethod
 def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-r""" Instantiates one of the language modeling model classes of the library
+r"""Instantiates one of the language modeling model classes of the library
 from a pre-trained model configuration.
 The `from_pretrained()` method takes care of returning the correct model class instance
@@ -856,7 +856,7 @@ class AutoModelForCausalLM:
 @classmethod
 def from_config(cls, config):
-r""" Instantiates one of the base model classes of the library
+r"""Instantiates one of the base model classes of the library
 from a configuration.
 Note:
@@ -893,7 +893,7 @@ class AutoModelForCausalLM:
 @classmethod
 def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-r""" Instantiates one of the language modeling model classes of the library
+r"""Instantiates one of the language modeling model classes of the library
 from a pre-trained model configuration.
 The `from_pretrained()` method takes care of returning the correct model class instance
@@ -993,7 +993,7 @@ class AutoModelForMaskedLM:
 @classmethod
 def from_config(cls, config):
-r""" Instantiates one of the base model classes of the library
+r"""Instantiates one of the base model classes of the library
 from a configuration.
 Note:
@@ -1033,7 +1033,7 @@ class AutoModelForMaskedLM:
 @classmethod
 def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-r""" Instantiates one of the language modeling model classes of the library
+r"""Instantiates one of the language modeling model classes of the library
 from a pre-trained model configuration.
 The `from_pretrained()` method takes care of returning the correct model class instance
@@ -1136,7 +1136,7 @@ class AutoModelForSeq2SeqLM:
 @classmethod
 def from_config(cls, config):
-r""" Instantiates one of the base model classes of the library
+r"""Instantiates one of the base model classes of the library
 from a configuration.
 Note:
@@ -1172,7 +1172,7 @@ class AutoModelForSeq2SeqLM:
 @classmethod
 def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-r""" Instantiates one of the language modeling model classes of the library
+r"""Instantiates one of the language modeling model classes of the library
 from a pre-trained model configuration.
 The `from_pretrained()` method takes care of returning the correct model class instance
@@ -1271,7 +1271,7 @@ class AutoModelForSequenceClassification:
 @classmethod
 def from_config(cls, config):
-r""" Instantiates one of the base model classes of the library
+r"""Instantiates one of the base model classes of the library
 from a configuration.
 Note:
@@ -1313,7 +1313,7 @@ class AutoModelForSequenceClassification:
 @classmethod
 def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-r""" Instantiates one of the sequence classification model classes of the library
+r"""Instantiates one of the sequence classification model classes of the library
 from a pre-trained model configuration.
 The `from_pretrained()` method takes care of returning the correct model class instance
@@ -1423,7 +1423,7 @@ class AutoModelForQuestionAnswering:
 @classmethod
 def from_config(cls, config):
-r""" Instantiates one of the base model classes of the library
+r"""Instantiates one of the base model classes of the library
 from a configuration.
 Note:
@@ -1462,7 +1462,7 @@ class AutoModelForQuestionAnswering:
 @classmethod
 def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-r""" Instantiates one of the question answering model classes of the library
+r"""Instantiates one of the question answering model classes of the library
 from a pre-trained model configuration.
 The `from_pretrained()` method takes care of returning the correct model class instance
@@ -1568,7 +1568,7 @@ class AutoModelForTokenClassification:
 @classmethod
 def from_config(cls, config):
-r""" Instantiates one of the base model classes of the library
+r"""Instantiates one of the base model classes of the library
 from a configuration.
 Note:
@@ -1611,7 +1611,7 @@ class AutoModelForTokenClassification:
 @classmethod
 def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-r""" Instantiates one of the question answering model classes of the library
+r"""Instantiates one of the question answering model classes of the library
 from a pre-trained model configuration.
 The `from_pretrained()` method takes care of returning the correct model class instance
...
@@ -223,7 +223,9 @@ class EncoderLayer(nn.Module):
 super().__init__()
 self.embed_dim = config.d_model
 self.self_attn = SelfAttention(
-    self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout,
+    self.embed_dim,
+    config.encoder_attention_heads,
+    dropout=config.attention_dropout,
 )
 self.normalize_before = config.normalize_before
 self.self_attn_layer_norm = LayerNorm(self.embed_dim)
@@ -297,7 +299,10 @@ class BartEncoder(nn.Module):
 )
 else:
 self.embed_positions = LearnedPositionalEmbedding(
-    config.max_position_embeddings, embed_dim, self.padding_idx, config.extra_pos_embeddings,
+    config.max_position_embeddings,
+    embed_dim,
+    self.padding_idx,
+    config.extra_pos_embeddings,
 )
 self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
 self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
@@ -370,7 +375,9 @@ class DecoderLayer(nn.Module):
 super().__init__()
 self.embed_dim = config.d_model
 self.self_attn = SelfAttention(
-    embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
+    embed_dim=self.embed_dim,
+    num_heads=config.decoder_attention_heads,
+    dropout=config.attention_dropout,
 )
 self.dropout = config.dropout
 self.activation_fn = ACT2FN[config.activation_function]
@@ -477,7 +484,10 @@ class BartDecoder(nn.Module):
 )
 else:
 self.embed_positions = LearnedPositionalEmbedding(
-    config.max_position_embeddings, config.d_model, self.padding_idx, config.extra_pos_embeddings,
+    config.max_position_embeddings,
+    config.d_model,
+    self.padding_idx,
+    config.extra_pos_embeddings,
 )
 self.layers = nn.ModuleList(
 [DecoderLayer(config) for _ in range(config.decoder_layers)]
@@ -695,7 +705,10 @@ class SelfAttention(nn.Module):
 # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
 if key_padding_mask is not None and key_padding_mask.dim() == 0:
 key_padding_mask = None
-assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len,)
+assert key_padding_mask is None or key_padding_mask.size()[:2] == (
+    bsz,
+    src_len,
+)
 if key_padding_mask is not None: # don't attend to padding symbols
 attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
@@ -703,7 +716,11 @@ class SelfAttention(nn.Module):
 attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
 attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 attn_weights = F.softmax(attn_weights, dim=-1)
-attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
+attn_probs = F.dropout(
+    attn_weights,
+    p=self.dropout,
+    training=self.training,
+)
 assert v is not None
 attn_output = torch.bmm(attn_probs, v)
@@ -754,7 +771,11 @@ class BartClassificationHead(nn.Module):
 # This can trivially be shared with RobertaClassificationHead
 def __init__(
-    self, input_dim, inner_dim, num_classes, pooler_dropout,
+    self,
+    input_dim,
+    inner_dim,
+    num_classes,
+    pooler_dropout,
 ):
 super().__init__()
 self.dense = nn.Linear(input_dim, inner_dim)
@@ -819,7 +840,8 @@ def _get_shape(t):
 @add_start_docstrings(
-    "The bare BART Model outputting raw hidden-states without any specific head on top.", BART_START_DOCSTRING,
+    "The bare BART Model outputting raw hidden-states without any specific head on top.",
+    BART_START_DOCSTRING,
 )
 class BartModel(PretrainedBartModel):
 def __init__(self, config: BartConfig):
@@ -1116,7 +1138,10 @@ class BartForSequenceClassification(PretrainedBartModel):
 super().__init__(config, **kwargs)
 self.model = BartModel(config)
 self.classification_head = BartClassificationHead(
-    config.d_model, config.d_model, config.num_labels, config.classif_dropout,
+    config.d_model,
+    config.d_model,
+    config.num_labels,
+    config.classif_dropout,
 )
 self.model._init_weights(self.classification_head.dense)
 self.model._init_weights(self.classification_head.out_proj)
@@ -1279,7 +1304,10 @@ class BartForQuestionAnswering(PretrainedBartModel):
 total_loss = (start_loss + end_loss) / 2
 if not return_dict:
-output = (start_logits, end_logits,) + outputs[1:]
+output = (
+    start_logits,
+    end_logits,
+) + outputs[1:]
 return ((total_loss,) + output) if total_loss is not None else output
 return Seq2SeqQuestionAnsweringModelOutput(
...
@@ -89,8 +89,7 @@ BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
 def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
-""" Load tf checkpoints in a pytorch model.
-"""
+"""Load tf checkpoints in a pytorch model."""
 try:
 import re
@@ -174,8 +173,7 @@ BertLayerNorm = torch.nn.LayerNorm
 class BertEmbeddings(nn.Module):
-"""Construct the embeddings from word, position and token_type embeddings.
-"""
+"""Construct the embeddings from word, position and token_type embeddings."""
 def __init__(self, config):
 super().__init__()
@@ -343,7 +341,12 @@ class BertAttention(nn.Module):
 output_attentions=False,
 ):
 self_outputs = self.self(
-    hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions,
+    hidden_states,
+    attention_mask,
+    head_mask,
+    encoder_hidden_states,
+    encoder_attention_mask,
+    output_attentions,
 )
 attention_output = self.output(self_outputs[0], hidden_states)
 outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
@@ -403,7 +406,10 @@ class BertLayer(nn.Module):
 output_attentions=False,
 ):
 self_attention_outputs = self.attention(
-    hidden_states, attention_mask, head_mask, output_attentions=output_attentions,
+    hidden_states,
+    attention_mask,
+    head_mask,
+    output_attentions=output_attentions,
 )
 attention_output = self_attention_outputs[0]
 outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
@@ -582,7 +588,7 @@ class BertPreTrainingHeads(nn.Module):
 class BertPreTrainedModel(PreTrainedModel):
-""" An abstract class to handle weights initialization and
+"""An abstract class to handle weights initialization and
 a simple interface for downloading and loading pretrained models.
 """
@@ -733,7 +739,7 @@ class BertModel(BertPreTrainedModel):
 self.embeddings.word_embeddings = value
 def _prune_heads(self, heads_to_prune):
-""" Prunes heads of the model.
+"""Prunes heads of the model.
 heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
 See base class PreTrainedModel
 """
@@ -1049,7 +1055,10 @@ class BertLMHeadModel(BertPreTrainedModel):
 return ((lm_loss,) + output) if lm_loss is not None else output
 return CausalLMOutput(
-    loss=lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
+    loss=lm_loss,
+    logits=prediction_scores,
+    hidden_states=outputs.hidden_states,
+    attentions=outputs.attentions,
 )
 def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
@@ -1173,7 +1182,8 @@ class BertForMaskedLM(BertPreTrainedModel):
 @add_start_docstrings(
-    """Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
+    """Bert Model with a `next sentence prediction (classification)` head on top. """,
+    BERT_START_DOCSTRING,
 )
 class BertForNextSentencePrediction(BertPreTrainedModel):
 def __init__(self, config):
@@ -1336,7 +1346,10 @@ class BertForSequenceClassification(BertPreTrainedModel):
 return ((loss,) + output) if loss is not None else output
 return SequenceClassifierOutput(
-    loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
+    loss=loss,
+    logits=logits,
+    hidden_states=outputs.hidden_states,
+    attentions=outputs.attentions,
 )
@@ -1422,7 +1435,10 @@ class BertForMultipleChoice(BertPreTrainedModel):
 return ((loss,) + output) if loss is not None else output
 return MultipleChoiceModelOutput(
-    loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
+    loss=loss,
+    logits=reshaped_logits,
+    hidden_states=outputs.hidden_states,
+    attentions=outputs.attentions,
 )
@@ -1505,7 +1521,10 @@ class BertForTokenClassification(BertPreTrainedModel):
 return ((loss,) + output) if loss is not None else output
 return TokenClassifierOutput(
-    loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
+    loss=loss,
+    logits=logits,
+    hidden_states=outputs.hidden_states,
+    attentions=outputs.attentions,
 )
...
@@ -68,7 +68,8 @@ class CamembertModel(RobertaModel):
 @add_start_docstrings(
-    """CamemBERT Model with a `language modeling` head on top. """, CAMEMBERT_START_DOCSTRING,
+    """CamemBERT Model with a `language modeling` head on top. """,
+    CAMEMBERT_START_DOCSTRING,
 )
 class CamembertForMaskedLM(RobertaForMaskedLM):
 """
...
@@ -212,7 +212,7 @@ class EncoderLayer(torch.nn.Module):
 class CTRLPreTrainedModel(PreTrainedModel):
-""" An abstract class to handle weights initialization and
+"""An abstract class to handle weights initialization and
 a simple interface for downloading and loading pretrained models.
 """
@@ -220,8 +220,7 @@ class CTRLPreTrainedModel(PreTrainedModel):
 base_model_prefix = "transformer"
 def _init_weights(self, module):
-""" Initialize the weights.
-"""
+"""Initialize the weights."""
 if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
 # Slightly different from the TF version which uses truncated_normal for initialization
 # cf https://github.com/pytorch/pytorch/pull/5617
@@ -331,7 +330,7 @@ class CTRLModel(CTRLPreTrainedModel):
 self.w = new_embeddings
 def _prune_heads(self, heads_to_prune):
-""" Prunes heads of the model.
+"""Prunes heads of the model.
 heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
 """
 for layer, heads in heads_to_prune.items():
...
@@ -261,7 +261,12 @@ class TransformerBlock(nn.Module):
 """
 # Self-Attention
 sa_output = self.attention(
-    query=x, key=x, value=x, mask=attn_mask, head_mask=head_mask, output_attentions=output_attentions,
+    query=x,
+    key=x,
+    value=x,
+    mask=attn_mask,
+    head_mask=head_mask,
+    output_attentions=output_attentions,
 )
 if output_attentions:
 sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
@@ -343,7 +348,7 @@ class Transformer(nn.Module):
 # INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
 class DistilBertPreTrainedModel(PreTrainedModel):
-""" An abstract class to handle weights initialization and
+"""An abstract class to handle weights initialization and
 a simple interface for downloading and loading pretrained models.
 """
@@ -352,8 +357,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
 base_model_prefix = "distilbert"
 def _init_weights(self, module):
-""" Initialize the weights.
-"""
+"""Initialize the weights."""
 if isinstance(module, nn.Embedding):
 if module.weight.requires_grad:
 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@@ -432,7 +436,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
 self.embeddings.word_embeddings = new_embeddings
 def _prune_heads(self, heads_to_prune):
-""" Prunes heads of the model.
+"""Prunes heads of the model.
 heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
 See base class PreTrainedModel
 """
@@ -493,7 +497,8 @@ class DistilBertModel(DistilBertPreTrainedModel):
 @add_start_docstrings(
-    """DistilBert Model with a `masked language modeling` head on top. """, DISTILBERT_START_DOCSTRING,
+    """DistilBert Model with a `masked language modeling` head on top. """,
+    DISTILBERT_START_DOCSTRING,
 )
 class DistilBertForMaskedLM(DistilBertPreTrainedModel):
 def __init__(self, config):
@@ -829,7 +834,10 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
 return ((loss,) + output) if loss is not None else output
 return TokenClassifierOutput(
-    loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
+    loss=loss,
+    logits=logits,
+    hidden_states=outputs.hidden_states,
+    attentions=outputs.attentions,
 )
@@ -930,5 +938,8 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
 return ((loss,) + output) if loss is not None else output
 return MultipleChoiceModelOutput(
-    loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
+    loss=loss,
+    logits=reshaped_logits,
+    hidden_states=outputs.hidden_states,
+    attentions=outputs.attentions,
 )
@@ -265,7 +265,7 @@ class DPRSpanPredictor(PreTrainedModel):
 class DPRPretrainedContextEncoder(PreTrainedModel):
-""" An abstract class to handle weights initialization and
+"""An abstract class to handle weights initialization and
 a simple interface for downloading and loading pretrained models.
 """
@@ -278,7 +278,7 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
 class DPRPretrainedQuestionEncoder(PreTrainedModel):
-""" An abstract class to handle weights initialization and
+"""An abstract class to handle weights initialization and
 a simple interface for downloading and loading pretrained models.
 """
@@ -291,7 +291,7 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
 class DPRPretrainedReader(PreTrainedModel):
-""" An abstract class to handle weights initialization and
+"""An abstract class to handle weights initialization and
 a simple interface for downloading and loading pretrained models.
 """
@@ -553,7 +553,8 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
 @add_start_docstrings(
-    "The bare DPRReader transformer outputting span predictions.", DPR_START_DOCSTRING,
+    "The bare DPRReader transformer outputting span predictions.",
+    DPR_START_DOCSTRING,
 )
 class DPRReader(DPRPretrainedReader):
 def __init__(self, config: DPRConfig):
...
@@ -46,8 +46,7 @@ ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
 def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_or_generator="discriminator"):
-""" Load tf checkpoints in a pytorch model.
-"""
+"""Load tf checkpoints in a pytorch model."""
 try:
 import re
@@ -179,7 +178,7 @@ class ElectraGeneratorPredictions(nn.Module):
 class ElectraPreTrainedModel(BertPreTrainedModel):
-""" An abstract class to handle weights initialization and
+"""An abstract class to handle weights initialization and
 a simple interface for downloading and loading pretrained models.
 """
@@ -311,7 +310,7 @@ class ElectraModel(ElectraPreTrainedModel):
 self.embeddings.word_embeddings = value
 def _prune_heads(self, heads_to_prune):
-""" Prunes heads of the model.
+"""Prunes heads of the model.
 heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
 See base class PreTrainedModel
 """
@@ -836,7 +835,10 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
 total_loss = (start_loss + end_loss) / 2
 if not return_dict:
-output = (start_logits, end_logits,) + discriminator_hidden_states[1:]
+output = (
+    start_logits,
+    end_logits,
+) + discriminator_hidden_states[1:]
 return ((total_loss,) + output) if total_loss is not None else output
 return QuestionAnsweringModelOutput(
...
@@ -103,7 +103,7 @@ class EncoderDecoderModel(PreTrainedModel):
 *model_args,
 **kwargs
 ) -> PreTrainedModel:
-r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints.
+r"""Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints.
 The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
...
@@ -240,7 +240,11 @@ class FlaubertModel(XLMModel):
 # self attention
 if not self.pre_norm:
 attn_outputs = self.attentions[i](
-    tensor, attn_mask, cache=cache, head_mask=head_mask[i], output_attentions=output_attentions,
+    tensor,
+    attn_mask,
+    cache=cache,
+    head_mask=head_mask[i],
+    output_attentions=output_attentions,
 )
 attn = attn_outputs[0]
 if output_attentions:
...
@@ -61,8 +61,7 @@ GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
 def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
-""" Load tf checkpoints in a pytorch model
-"""
+"""Load tf checkpoints in a pytorch model"""
 try:
 import re
@@ -324,7 +323,7 @@ class Block(nn.Module):
 class GPT2PreTrainedModel(PreTrainedModel):
-""" An abstract class to handle weights initialization and
+"""An abstract class to handle weights initialization and
 a simple interface for downloading and loading pretrained models.
 """
@@ -336,8 +335,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
 super().__init__(*inputs, **kwargs)
 def _init_weights(self, module):
-""" Initialize the weights.
-"""
+"""Initialize the weights."""
 if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
 # Slightly different from the TF version which uses truncated_normal for initialization
 # cf https://github.com/pytorch/pytorch/pull/5617
@@ -483,7 +481,7 @@ class GPT2Model(GPT2PreTrainedModel):
 self.wte = new_embeddings
 def _prune_heads(self, heads_to_prune):
-""" Prunes heads of the model.
+"""Prunes heads of the model.
 heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
 """
 for layer, heads in heads_to_prune.items():
...
...@@ -135,7 +135,10 @@ class LongformerSelfAttention(nn.Module): ...@@ -135,7 +135,10 @@ class LongformerSelfAttention(nn.Module):
self.one_sided_attn_window_size = attention_window // 2 self.one_sided_attn_window_size = attention_window // 2
def forward( def forward(
self, hidden_states, attention_mask=None, output_attentions=False, self,
hidden_states,
attention_mask=None,
output_attentions=False,
): ):
""" """
LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`. LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`.
...@@ -622,7 +625,10 @@ class LongformerSelfAttention(nn.Module): ...@@ -622,7 +625,10 @@ class LongformerSelfAttention(nn.Module):
is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], : is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], :
] = -10000.0 ] = -10000.0
global_attn_scores = global_attn_scores.masked_fill(is_index_masked[:, None, None, :], -10000.0,) global_attn_scores = global_attn_scores.masked_fill(
is_index_masked[:, None, None, :],
-10000.0,
)
global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len) global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
...@@ -676,9 +682,16 @@ class LongformerAttention(nn.Module): ...@@ -676,9 +682,16 @@ class LongformerAttention(nn.Module):
self.pruned_heads = self.pruned_heads.union(heads) self.pruned_heads = self.pruned_heads.union(heads)
def forward( def forward(
self, hidden_states, attention_mask=None, output_attentions=False, self,
hidden_states,
attention_mask=None,
output_attentions=False,
): ):
self_outputs = self.self(hidden_states, attention_mask, output_attentions,) self_outputs = self.self(
hidden_states,
attention_mask,
output_attentions,
)
attn_output = self.output(self_outputs[0], hidden_states) attn_output = self.output(self_outputs[0], hidden_states)
outputs = (attn_output,) + self_outputs[1:] # add attentions if we output them outputs = (attn_output,) + self_outputs[1:] # add attentions if we output them
return outputs return outputs
...@@ -694,9 +707,16 @@ class LongformerLayer(nn.Module): ...@@ -694,9 +707,16 @@ class LongformerLayer(nn.Module):
self.seq_len_dim = 1 self.seq_len_dim = 1
def forward( def forward(
self, hidden_states, attention_mask=None, output_attentions=False, self,
hidden_states,
attention_mask=None,
output_attentions=False,
): ):
self_attn_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions,) self_attn_outputs = self.attention(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
attn_output = self_attn_outputs[0] attn_output = self_attn_outputs[0]
outputs = self_attn_outputs[1:] # add self attentions if we output attention weights outputs = self_attn_outputs[1:] # add self attentions if we output attention weights
...@@ -741,10 +761,16 @@ class LongformerEncoder(nn.Module): ...@@ -741,10 +761,16 @@ class LongformerEncoder(nn.Module):
return custom_forward return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint( layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module), hidden_states, attention_mask, create_custom_forward(layer_module),
hidden_states,
attention_mask,
) )
else: else:
layer_outputs = layer_module(hidden_states, attention_mask, output_attentions,) layer_outputs = layer_module(
hidden_states,
attention_mask,
output_attentions,
)
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if output_attentions: if output_attentions:
...@@ -762,7 +788,7 @@ class LongformerEncoder(nn.Module): ...@@ -762,7 +788,7 @@ class LongformerEncoder(nn.Module):
class LongformerPreTrainedModel(PreTrainedModel): class LongformerPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained a simple interface for downloading and loading pretrained
models. models.
""" """
...@@ -896,7 +922,7 @@ class LongformerModel(LongformerPreTrainedModel): ...@@ -896,7 +922,7 @@ class LongformerModel(LongformerPreTrainedModel):
self.embeddings.word_embeddings = value self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model. """Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel See base class PreTrainedModel
""" """
...@@ -938,7 +964,9 @@ class LongformerModel(LongformerPreTrainedModel): ...@@ -938,7 +964,9 @@ class LongformerModel(LongformerPreTrainedModel):
position_ids = F.pad(position_ids, (0, padding_len), value=pad_token_id) position_ids = F.pad(position_ids, (0, padding_len), value=pad_token_id)
if inputs_embeds is not None: if inputs_embeds is not None:
input_ids_padding = inputs_embeds.new_full( input_ids_padding = inputs_embeds.new_full(
(batch_size, padding_len), self.config.pad_token_id, dtype=torch.long, (batch_size, padding_len),
self.config.pad_token_id,
dtype=torch.long,
) )
inputs_embeds_padding = self.embeddings(input_ids_padding) inputs_embeds_padding = self.embeddings(input_ids_padding)
inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2) inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)
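# Standalone sketch of the padding step above, with illustrative values: Longformer pads every
# input up to a multiple of its sliding attention window and masks the padded positions, so the
# windowed attention always sees complete windows. attention_window and pad_token_id below are
# assumptions, not values read from the diff.
import torch
import torch.nn.functional as F

attention_window = 512
pad_token_id = 1
input_ids = torch.randint(5, 100, (2, 1000))      # batch of two 1000-token sequences
attention_mask = torch.ones_like(input_ids)

seq_len = input_ids.shape[1]
padding_len = (attention_window - seq_len % attention_window) % attention_window
if padding_len > 0:
    input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id)
    attention_mask = F.pad(attention_mask, (0, padding_len), value=0)  # 0 == do not attend

print(input_ids.shape)  # torch.Size([2, 1024])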
...@@ -1252,7 +1280,10 @@ class LongformerForSequenceClassification(BertPreTrainedModel): ...@@ -1252,7 +1280,10 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput( return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
...@@ -1487,7 +1518,10 @@ class LongformerForTokenClassification(BertPreTrainedModel): ...@@ -1487,7 +1518,10 @@ class LongformerForTokenClassification(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput( return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
...@@ -1592,5 +1626,8 @@ class LongformerForMultipleChoice(BertPreTrainedModel): ...@@ -1592,5 +1626,8 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput( return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
...@@ -32,8 +32,7 @@ _CONFIG_FOR_DOC = "MMBTConfig" ...@@ -32,8 +32,7 @@ _CONFIG_FOR_DOC = "MMBTConfig"
class ModalEmbeddings(nn.Module): class ModalEmbeddings(nn.Module):
"""Generic Modal Embeddings which takes in an encoder, and a transformer embedding. """Generic Modal Embeddings which takes in an encoder, and a transformer embedding."""
"""
def __init__(self, config, encoder, embeddings): def __init__(self, config, encoder, embeddings):
super().__init__() super().__init__()
...@@ -154,7 +153,8 @@ MMBT_INPUTS_DOCSTRING = r""" Inputs: ...@@ -154,7 +153,8 @@ MMBT_INPUTS_DOCSTRING = r""" Inputs:
@add_start_docstrings( @add_start_docstrings(
"The bare MMBT Model outputting raw hidden-states without any specific head on top.", MMBT_START_DOCSTRING, "The bare MMBT Model outputting raw hidden-states without any specific head on top.",
MMBT_START_DOCSTRING,
) )
class MMBTModel(nn.Module, ModuleUtilsMixin): class MMBTModel(nn.Module, ModuleUtilsMixin):
def __init__(self, config, transformer, encoder): def __init__(self, config, transformer, encoder):
...@@ -378,5 +378,8 @@ class MMBTForClassification(nn.Module): ...@@ -378,5 +378,8 @@ class MMBTForClassification(nn.Module):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput( return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
...@@ -64,8 +64,7 @@ MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = ["google/mobilebert-uncased"] ...@@ -64,8 +64,7 @@ MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = ["google/mobilebert-uncased"]
def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path): def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model. """Load tf checkpoints in a pytorch model."""
"""
try: try:
import re import re
...@@ -161,8 +160,7 @@ NORM2FN = {"layer_norm": torch.nn.LayerNorm, "no_norm": NoNorm} ...@@ -161,8 +160,7 @@ NORM2FN = {"layer_norm": torch.nn.LayerNorm, "no_norm": NoNorm}
class MobileBertEmbeddings(nn.Module): class MobileBertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings. """Construct the embeddings from word, position and token_type embeddings."""
"""
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
...@@ -663,7 +661,7 @@ class MobileBertPreTrainingHeads(nn.Module): ...@@ -663,7 +661,7 @@ class MobileBertPreTrainingHeads(nn.Module):
class MobileBertPreTrainedModel(PreTrainedModel): class MobileBertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
...@@ -809,7 +807,7 @@ class MobileBertModel(MobileBertPreTrainedModel): ...@@ -809,7 +807,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
self.embeddings.word_embeddings = value self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model. """Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel See base class PreTrainedModel
""" """
...@@ -1308,7 +1306,10 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel): ...@@ -1308,7 +1306,10 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput( return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
...@@ -1491,7 +1492,10 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel): ...@@ -1491,7 +1492,10 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput( return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
...@@ -1574,5 +1578,8 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel): ...@@ -1574,5 +1578,8 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput( return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
...@@ -59,8 +59,7 @@ OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ ...@@ -59,8 +59,7 @@ OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path): def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here) """Load tf pre-trained weights in a pytorch model (from NumPy arrays here)"""
"""
import re import re
import numpy as np import numpy as np
...@@ -257,7 +256,10 @@ class Block(nn.Module): ...@@ -257,7 +256,10 @@ class Block(nn.Module):
def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False): def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
attn_outputs = self.attn( attn_outputs = self.attn(
x, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, x,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
) )
a = attn_outputs[0] a = attn_outputs[0]
...@@ -270,7 +272,7 @@ class Block(nn.Module): ...@@ -270,7 +272,7 @@ class Block(nn.Module):
class OpenAIGPTPreTrainedModel(PreTrainedModel): class OpenAIGPTPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
...@@ -280,8 +282,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel): ...@@ -280,8 +282,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
authorized_missing_keys = [r"position_ids"] authorized_missing_keys = [r"position_ids"]
def _init_weights(self, module): def _init_weights(self, module):
""" Initialize the weights. """Initialize the weights."""
"""
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)): if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
# Slightly different from the TF version which uses truncated_normal for initialization # Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617 # cf https://github.com/pytorch/pytorch/pull/5617
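# The hunk is truncated here; in this family of models the body that follows is typically a plain
# normal initialization with std = config.initializer_range (an approximation of TF's
# truncated_normal). The helper below is a hedged, self-contained sketch of that pattern, not the
# exact code behind this diff, and it omits the Conv1D case to stay dependency-free.
import torch.nn as nn

def init_weights_sketch(module, initializer_range=0.02):
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
model.apply(init_weights_sketch)  # .apply() visits every submodule recursively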
...@@ -408,7 +409,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -408,7 +409,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
self.tokens_embed = new_embeddings self.tokens_embed = new_embeddings
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model. """Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
""" """
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
...@@ -506,7 +507,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -506,7 +507,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput( return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions, last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
) )
......
...@@ -78,7 +78,8 @@ ReformerBackwardOutput = namedtuple( ...@@ -78,7 +78,8 @@ ReformerBackwardOutput = namedtuple(
"ReformerBackwardOutput", ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"] "ReformerBackwardOutput", ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"]
) )
ReformerEncoderOutput = namedtuple( ReformerEncoderOutput = namedtuple(
"ReformerEncoderOutput", ["hidden_states", "all_hidden_states", "all_attentions", "past_buckets_states"], "ReformerEncoderOutput",
["hidden_states", "all_hidden_states", "all_attentions", "past_buckets_states"],
) )
...@@ -192,7 +193,9 @@ class AxialPositionEmbeddings(nn.Module): ...@@ -192,7 +193,9 @@ class AxialPositionEmbeddings(nn.Module):
assert ( assert (
reduce(mul, self.axial_pos_shape) >= sequence_length reduce(mul, self.axial_pos_shape) >= sequence_length
), "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, least_common_mult_chunk_length): max({}, {})".format( ), "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, least_common_mult_chunk_length): max({}, {})".format(
self.axial_pos_shape, sequence_length, self.least_common_mult_chunk_length, self.axial_pos_shape,
sequence_length,
self.least_common_mult_chunk_length,
) )
# compute how many columns are needed # compute how many columns are needed
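# Worked example of the assertion above, with illustrative numbers: the product of the
# config.axial_pos_shape factors must cover the (already padded) sequence length, e.g.
# (64, 64) covers 64 * 64 = 4096 positions, so a 4096-token input passes while an
# 8192-token input would trip the assert.
from functools import reduce
from operator import mul

axial_pos_shape = (64, 64)   # assumed config value
sequence_length = 4096       # already padded to the least common chunk length
assert reduce(mul, axial_pos_shape) >= sequence_length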
...@@ -218,8 +221,7 @@ class AxialPositionEmbeddings(nn.Module): ...@@ -218,8 +221,7 @@ class AxialPositionEmbeddings(nn.Module):
class PositionEmbeddings(nn.Module): class PositionEmbeddings(nn.Module):
"""Constructs conventional position embeddings of shape `[max_pos_embeddings, hidden_size]`. """Constructs conventional position embeddings of shape `[max_pos_embeddings, hidden_size]`."""
"""
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
...@@ -233,8 +235,7 @@ class PositionEmbeddings(nn.Module): ...@@ -233,8 +235,7 @@ class PositionEmbeddings(nn.Module):
class ReformerEmbeddings(nn.Module): class ReformerEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings. """Construct the embeddings from word, position and token_type embeddings."""
"""
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
...@@ -285,7 +286,7 @@ class EfficientAttentionMixin: ...@@ -285,7 +286,7 @@ class EfficientAttentionMixin:
""" """
def _look_adjacent(self, vectors, num_chunks_before, num_chunks_after): def _look_adjacent(self, vectors, num_chunks_before, num_chunks_after):
""" Used to implement attention between consecutive chunks. """Used to implement attention between consecutive chunks.
Args: Args:
vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...] vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...]
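# Minimal sketch of "attention between consecutive chunks": each chunk is concatenated with its
# neighbouring chunks along the chunk-length axis so that keys/values from adjacent chunks become
# visible to it. Rolling along the chunk dimension is one simple way to express this; the real
# method generalises over the trailing dimensions, so treat the code below as illustrative only.
import torch

def look_adjacent_sketch(vectors, num_chunks_before=1, num_chunks_after=0):
    # vectors: [batch_size, num_attention_heads, n_chunks, chunk_len, head_size]
    slices = [
        torch.roll(vectors, shifts=offset, dims=2)
        for offset in range(num_chunks_before, -num_chunks_after - 1, -1)
    ]
    return torch.cat(slices, dim=3)  # chunk_len now also spans the neighbouring chunks

x = torch.randn(2, 4, 8, 16, 64)          # 8 chunks of length 16
print(look_adjacent_sketch(x).shape)      # torch.Size([2, 4, 8, 32, 64])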
...@@ -418,10 +419,18 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -418,10 +419,18 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# split key & value vectors by num hashes to apply # split key & value vectors by num hashes to apply
# self attention on each separately # self attention on each separately
query_key_vectors = self._split_seq_length_dim_to( query_key_vectors = self._split_seq_length_dim_to(
query_key_vectors, num_hashes, -1, self.num_attention_heads, self.attention_head_size, query_key_vectors,
num_hashes,
-1,
self.num_attention_heads,
self.attention_head_size,
) )
value_vectors = self._split_seq_length_dim_to( value_vectors = self._split_seq_length_dim_to(
value_vectors, num_hashes, -1, self.num_attention_heads, self.attention_head_size, value_vectors,
num_hashes,
-1,
self.num_attention_heads,
self.attention_head_size,
) )
# repeat query vectors across hash dimension # repeat query vectors across hash dimension
query_vectors = query_vectors.unsqueeze(2).repeat(1, 1, num_hashes, 1, 1) query_vectors = query_vectors.unsqueeze(2).repeat(1, 1, num_hashes, 1, 1)
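# Shape sketch for the _split_seq_length_dim_to calls above, with illustrative sizes: at this
# point the sequence axis stacks all hash rounds back to back, and the reshape makes the
# (num_hashes, seq_len) split explicit per attention head so each hash round can be attended
# to independently.
import torch

batch, heads, seq_len, head_size, num_hashes = 2, 4, 64, 32, 4
stacked = torch.randn(batch, heads, num_hashes * seq_len, head_size)

# rough equivalent of _split_seq_length_dim_to(stacked, num_hashes, -1, heads, head_size)
per_hash = stacked.reshape(batch, heads, num_hashes, -1, head_size)
print(per_hash.shape)  # torch.Size([2, 4, 4, 64, 32])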
...@@ -496,10 +505,18 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -496,10 +505,18 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx_per_hash, num_hashes) query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx_per_hash, num_hashes)
value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx_per_hash, num_hashes) value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx_per_hash, num_hashes)
query_key_vectors = self._split_seq_length_dim_to( query_key_vectors = self._split_seq_length_dim_to(
query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, query_key_vectors,
-1,
self.chunk_length,
self.num_attention_heads,
self.attention_head_size,
) )
value_vectors = self._split_seq_length_dim_to( value_vectors = self._split_seq_length_dim_to(
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, value_vectors,
-1,
self.chunk_length,
self.num_attention_heads,
self.attention_head_size,
) )
if self.chunk_length is None: if self.chunk_length is None:
...@@ -548,10 +565,18 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -548,10 +565,18 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# sum up all hash rounds # sum up all hash rounds
if num_hashes > 1: if num_hashes > 1:
out_vectors = self._split_seq_length_dim_to( out_vectors = self._split_seq_length_dim_to(
out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, out_vectors,
num_hashes,
sequence_length,
self.num_attention_heads,
self.attention_head_size,
) )
logits = self._split_seq_length_dim_to( logits = self._split_seq_length_dim_to(
logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size, logits,
num_hashes,
sequence_length,
self.num_attention_heads,
self.attention_head_size,
).unsqueeze(-1) ).unsqueeze(-1)
probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True)) probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True))
...@@ -697,7 +722,8 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -697,7 +722,8 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# factorize `num_buckets` if `num_buckets` becomes too large # factorize `num_buckets` if `num_buckets` becomes too large
num_buckets_limit = 2 * max( num_buckets_limit = 2 * max(
int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length, int((self.max_position_embeddings // self.chunk_length) ** (0.5)),
self.chunk_length,
) )
if num_buckets > num_buckets_limit: if num_buckets > num_buckets_limit:
num_buckets = [2 ** (num_buckets_pow_2 // 2), 2 ** (num_buckets_pow_2 - num_buckets_pow_2 // 2)] num_buckets = [2 ** (num_buckets_pow_2 // 2), 2 ** (num_buckets_pow_2 - num_buckets_pow_2 // 2)]
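# Worked example of the factorization above: a single bucket count of 2 ** num_buckets_pow_2 is
# split into two powers of two whose product is unchanged, e.g. 512 buckets (pow 9) become [16, 32].
num_buckets_pow_2 = 9
num_buckets = [2 ** (num_buckets_pow_2 // 2), 2 ** (num_buckets_pow_2 - num_buckets_pow_2 // 2)]
print(num_buckets, 2 ** num_buckets_pow_2)  # [16, 32] 512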
...@@ -1113,13 +1139,25 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -1113,13 +1139,25 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
# chunk vectors # chunk vectors
# B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size # B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size
query_vectors = self._split_seq_length_dim_to( query_vectors = self._split_seq_length_dim_to(
query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, query_vectors,
-1,
self.chunk_length,
self.num_attention_heads,
self.attention_head_size,
) )
key_vectors = self._split_seq_length_dim_to( key_vectors = self._split_seq_length_dim_to(
key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, key_vectors,
-1,
self.chunk_length,
self.num_attention_heads,
self.attention_head_size,
) )
value_vectors = self._split_seq_length_dim_to( value_vectors = self._split_seq_length_dim_to(
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, value_vectors,
-1,
self.chunk_length,
self.num_attention_heads,
self.attention_head_size,
) )
# chunk indices # chunk indices
...@@ -1179,7 +1217,12 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): ...@@ -1179,7 +1217,12 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
if not do_standard_self_attention: if not do_standard_self_attention:
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,) assert out_vectors.shape == (
batch_size,
self.num_attention_heads,
sequence_length,
self.attention_head_size,
)
out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size) out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size)
...@@ -1321,7 +1364,9 @@ class ReformerAttention(nn.Module): ...@@ -1321,7 +1364,9 @@ class ReformerAttention(nn.Module):
attention_output = self.output(self_attention_outputs.hidden_states) attention_output = self.output(self_attention_outputs.hidden_states)
return AttentionOutput( return AttentionOutput(
hidden_states=attention_output, attention_probs=self_attention_outputs.attention_probs, buckets=buckets, hidden_states=attention_output,
attention_probs=self_attention_outputs.attention_probs,
buckets=buckets,
) )
...@@ -1369,7 +1414,10 @@ class ChunkReformerFeedForward(nn.Module): ...@@ -1369,7 +1414,10 @@ class ChunkReformerFeedForward(nn.Module):
def forward(self, attention_output): def forward(self, attention_output):
return apply_chunking_to_forward( return apply_chunking_to_forward(
self.forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output, self.forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output,
) )
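# Sketch of what the chunked feed-forward above does: the input is split along the sequence
# dimension into chunks of chunk_size, the forward function runs on each chunk, and the pieces
# are concatenated again, which bounds the feed-forward's peak activation memory.
# chunked_forward is a toy stand-in for apply_chunking_to_forward, not the library implementation.
import torch

def chunked_forward(forward_fn, chunk_size, chunk_dim, input_tensor):
    if chunk_size == 0:                       # 0 disables chunking in this sketch
        return forward_fn(input_tensor)
    chunks = input_tensor.split(chunk_size, dim=chunk_dim)
    return torch.cat([forward_fn(c) for c in chunks], dim=chunk_dim)

dense = torch.nn.Linear(16, 16)
attention_output = torch.randn(2, 1024, 16)   # [batch, seq_len, hidden]
out = chunked_forward(dense, chunk_size=128, chunk_dim=1, input_tensor=attention_output)
print(out.shape)  # torch.Size([2, 1024, 16])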
def forward_chunk(self, hidden_states): def forward_chunk(self, hidden_states):
...@@ -1520,7 +1568,10 @@ class ReformerLayer(nn.Module): ...@@ -1520,7 +1568,10 @@ class ReformerLayer(nn.Module):
# f(X_2) # f(X_2)
# use cached buckets for backprop if buckets not None for LSHSelfAttention # use cached buckets for backprop if buckets not None for LSHSelfAttention
output = self.attention( output = self.attention(
hidden_states=hidden_states, head_mask=head_mask, attention_mask=attention_mask, buckets=buckets, hidden_states=hidden_states,
head_mask=head_mask,
attention_mask=attention_mask,
buckets=buckets,
).hidden_states ).hidden_states
output.backward(grad_attn_output, retain_graph=True) output.backward(grad_attn_output, retain_graph=True)
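# The backward step above relies on the reversible residual structure of a Reformer layer:
#   Y1 = X1 + Attn(X2)   and   Y2 = X2 + FeedForward(Y1)
# so the inputs can be recomputed from the outputs instead of being stored:
#   X2 = Y2 - FeedForward(Y1)   and   X1 = Y1 - Attn(X2)
# The gradient of the attention branch is then obtained by re-running attention (reusing the
# cached LSH buckets) and calling backward on it, as the code above does. Tiny numeric
# illustration with stand-in linear maps f (attention) and g (feed-forward):
import torch

f = torch.nn.Linear(4, 4)   # stands in for attention
g = torch.nn.Linear(4, 4)   # stands in for the feed-forward

x1, x2 = torch.randn(1, 4), torch.randn(1, 4)
y1 = x1 + f(x2)
y2 = x2 + g(y1)

# reconstruct the inputs from the outputs -- no intermediate activations had to be kept
x2_rec = y2 - g(y1)
x1_rec = y1 - f(x2_rec)
print(torch.allclose(x1, x1_rec, atol=1e-6), torch.allclose(x2, x2_rec, atol=1e-6))  # True True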
...@@ -1738,7 +1789,7 @@ class ReformerOnlyLMHead(nn.Module): ...@@ -1738,7 +1789,7 @@ class ReformerOnlyLMHead(nn.Module):
class ReformerPreTrainedModel(PreTrainedModel): class ReformerPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and """An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
...@@ -1947,7 +1998,7 @@ class ReformerModel(ReformerPreTrainedModel): ...@@ -1947,7 +1998,7 @@ class ReformerModel(ReformerPreTrainedModel):
self.embeddings.word_embeddings = value self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model. """Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel See base class PreTrainedModel
""" """
...@@ -2099,7 +2150,10 @@ class ReformerModel(ReformerPreTrainedModel): ...@@ -2099,7 +2150,10 @@ class ReformerModel(ReformerPreTrainedModel):
) )
padded_input_ids = torch.full( padded_input_ids = torch.full(
(input_shape[0], padding_length), self.config.pad_token_id, device=device, dtype=torch.long, (input_shape[0], padding_length),
self.config.pad_token_id,
device=device,
dtype=torch.long,
) )
# Extend `attention_mask` # Extend `attention_mask`
...@@ -2407,7 +2461,10 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel): ...@@ -2407,7 +2461,10 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput( return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
) )
......