"INSTALL/vscode:/vscode.git/clone" did not exist on "ca62128f9bb4c349637be30cb8290919e03684f7"
Commit a75c64d8 authored by Lysandre

Black 20 release

parent e78c1103
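
The hunks below are mechanical reformatting from moving the project's formatter to the Black 20 release; two behaviours of that release appear to account for essentially every change: the "magic trailing comma", which treats an existing trailing comma as a signal to put each argument of a call on its own line, and docstring normalisation, which drops the space after the opening triple quotes and collapses short docstrings onto a single line. A schematic before/after sketch of the pattern (the function is hypothetical, not taken from the repository):

# Rendered by the previous Black release: the arguments share one
# continuation line ending with a trailing comma, and the docstring keeps
# a space after the opening quotes.
def pad_ids(ids, target_length, pad_value):
    """ Pad `ids` on the right until it reaches `target_length`.
    """
    return ids + [pad_value] * max(
        target_length - len(ids), 0,
    )

# Rendered by Black 20: the trailing comma forces one argument per line,
# and the docstring is normalised onto a single line.
def pad_ids(ids, target_length, pad_value):
    """Pad `ids` on the right until it reaches `target_length`."""
    return ids + [pad_value] * max(
        target_length - len(ids),
        0,
    )
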
......@@ -329,7 +329,13 @@ class TFGenerationMixin:
if self.config.is_encoder_decoder:
# create empty decoder_input_ids
input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id
input_ids = (
tf.ones(
(effective_batch_size * num_beams, 1),
dtype=tf.int32,
)
* decoder_start_token_id
)
cur_len = 1
assert (
......@@ -422,7 +428,7 @@ class TFGenerationMixin:
attention_mask,
use_cache,
):
""" Generate sequences for each example without beam search (num_beams == 1).
"""Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
"""
......@@ -587,8 +593,7 @@ class TFGenerationMixin:
attention_mask,
use_cache,
):
""" Generate sequences for each example with beam search.
"""
"""Generate sequences for each example with beam search."""
# generated hypotheses
generated_hyps = [
......@@ -960,7 +965,7 @@ def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
......@@ -1001,7 +1006,8 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
sorted_indices_to_remove = tf.concat(
[tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]], -1,
[tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]],
-1,
)
# scatter sorted tensors to original indexing
indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
......
......@@ -83,7 +83,11 @@ class GenerationMixin:
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0:
self.enforce_repetition_penalty_(
scores, batch_size, num_beams, input_ids, repetition_penalty,
scores,
batch_size,
num_beams,
input_ids,
repetition_penalty,
)
# set eos token prob to zero if min_length is not reached
......@@ -324,7 +328,10 @@ class GenerationMixin:
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
)
input_ids = torch.full(
(batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
(batch_size, 1),
bos_token_id,
dtype=torch.long,
device=next(self.parameters()).device,
)
else:
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
......@@ -514,7 +521,7 @@ class GenerationMixin:
use_cache,
model_specific_kwargs,
):
""" Generate sequences for each example without beam search (num_beams == 1).
"""Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
"""
# length of generated sentences / unfinished sentences
......@@ -619,8 +626,7 @@ class GenerationMixin:
use_cache,
model_specific_kwargs,
):
""" Generate sequences for each example with beam search.
"""
"""Generate sequences for each example with beam search."""
# generated hypotheses
generated_hyps = [
......@@ -749,7 +755,8 @@ class GenerationMixin:
if is_beam_token_worse_than_top_num_beams:
continue
generated_hyps[batch_idx].add(
input_ids[effective_beam_id].clone(), beam_token_score.item(),
input_ids[effective_beam_id].clone(),
beam_token_score.item(),
)
else:
# add next predicted token since it is not eos_token
......@@ -806,7 +813,8 @@ class GenerationMixin:
assert torch.all(
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
next_scores[:, :num_beams][batch_idx],
beam_scores.view(batch_size, num_beams)[batch_idx],
)
# need to add best num_beams hypotheses to generated hyps
......@@ -916,7 +924,7 @@ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iter
def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
""" Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
"""Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...]
Args:
scores: logits distribution of shape (batch size, vocabulary size)
......@@ -946,7 +954,7 @@ def top_k_top_p_filtering(
filter_value: float = -float("Inf"),
min_tokens_to_keep: int = 1,
) -> Tensor:
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
......
......@@ -36,7 +36,7 @@ logger = logging.get_logger(__name__)
class ModelCard:
r""" Structured Model Card class.
r"""Structured Model Card class.
Store model card as well as methods for loading/downloading/saving model cards.
Please read the following paper for details and explanation on the sections:
......@@ -73,8 +73,7 @@ class ModelCard:
raise err
def save_pretrained(self, save_directory_or_file):
""" Save a model card object to the directory or file `save_directory_or_file`.
"""
"""Save a model card object to the directory or file `save_directory_or_file`."""
if os.path.isdir(save_directory_or_file):
# If we save using the predefined names, we can load using `from_pretrained`
output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME)
......@@ -86,7 +85,7 @@ class ModelCard:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r""" Instantiate a :class:`~transformers.ModelCard` from a pre-trained model model card.
r"""Instantiate a :class:`~transformers.ModelCard` from a pre-trained model model card.
Parameters:
pretrained_model_name_or_path: either:
......
......@@ -302,7 +302,10 @@ class AlbertLayer(nn.Module):
attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
ffn_output = apply_chunking_to_forward(
self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output[0],
self.ff_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output[0],
)
hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])
......@@ -397,7 +400,7 @@ class AlbertTransformer(nn.Module):
class AlbertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -406,8 +409,7 @@ class AlbertPreTrainedModel(PreTrainedModel):
authorized_missing_keys = [r"position_ids"]
def _init_weights(self, module):
""" Initialize the weights.
"""
"""Initialize the weights."""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
......@@ -543,7 +545,7 @@ class AlbertModel(AlbertPreTrainedModel):
return self.embeddings.word_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
ALBERT has a different architecture in that its layers are shared across groups, which then has inner groups.
If an ALBERT model has 12 hidden layers and 2 hidden groups, with two inner groups, there
......@@ -787,7 +789,8 @@ class AlbertSOPHead(nn.Module):
@add_start_docstrings(
"Albert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING,
"Albert Model with a `language modeling` head on top.",
ALBERT_START_DOCSTRING,
)
class AlbertForMaskedLM(AlbertPreTrainedModel):
def __init__(self, config):
......@@ -952,7 +955,10 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1033,7 +1039,10 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1215,5 +1224,8 @@ class AlbertForMultipleChoice(AlbertPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -391,7 +391,7 @@ class AutoModel:
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -433,7 +433,7 @@ class AutoModel:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
......@@ -547,7 +547,7 @@ class AutoModelForPreTraining:
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -589,7 +589,7 @@ class AutoModelForPreTraining:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the model classes of the library -with the architecture used for pretraining this model– from a pre-trained model configuration.
r"""Instantiates one of the model classes of the library -with the architecture used for pretraining this model– from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
based on the `model_type` property of the config object, or when it's missing,
......@@ -697,7 +697,7 @@ class AutoModelWithLMHead:
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -743,7 +743,7 @@ class AutoModelWithLMHead:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the language modeling model classes of the library
r"""Instantiates one of the language modeling model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
......@@ -856,7 +856,7 @@ class AutoModelForCausalLM:
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -893,7 +893,7 @@ class AutoModelForCausalLM:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the language modeling model classes of the library
r"""Instantiates one of the language modeling model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
......@@ -993,7 +993,7 @@ class AutoModelForMaskedLM:
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -1033,7 +1033,7 @@ class AutoModelForMaskedLM:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the language modeling model classes of the library
r"""Instantiates one of the language modeling model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
......@@ -1136,7 +1136,7 @@ class AutoModelForSeq2SeqLM:
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -1172,7 +1172,7 @@ class AutoModelForSeq2SeqLM:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the language modeling model classes of the library
r"""Instantiates one of the language modeling model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
......@@ -1271,7 +1271,7 @@ class AutoModelForSequenceClassification:
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -1313,7 +1313,7 @@ class AutoModelForSequenceClassification:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the sequence classification model classes of the library
r"""Instantiates one of the sequence classification model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
......@@ -1423,7 +1423,7 @@ class AutoModelForQuestionAnswering:
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -1462,7 +1462,7 @@ class AutoModelForQuestionAnswering:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the question answering model classes of the library
r"""Instantiates one of the question answering model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
......@@ -1568,7 +1568,7 @@ class AutoModelForTokenClassification:
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -1611,7 +1611,7 @@ class AutoModelForTokenClassification:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the question answering model classes of the library
r"""Instantiates one of the question answering model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
......
......@@ -223,7 +223,9 @@ class EncoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = SelfAttention(
self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout,
self.embed_dim,
config.encoder_attention_heads,
dropout=config.attention_dropout,
)
self.normalize_before = config.normalize_before
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
......@@ -297,7 +299,10 @@ class BartEncoder(nn.Module):
)
else:
self.embed_positions = LearnedPositionalEmbedding(
config.max_position_embeddings, embed_dim, self.padding_idx, config.extra_pos_embeddings,
config.max_position_embeddings,
embed_dim,
self.padding_idx,
config.extra_pos_embeddings,
)
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
......@@ -370,7 +375,9 @@ class DecoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = SelfAttention(
embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout,
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
......@@ -477,7 +484,10 @@ class BartDecoder(nn.Module):
)
else:
self.embed_positions = LearnedPositionalEmbedding(
config.max_position_embeddings, config.d_model, self.padding_idx, config.extra_pos_embeddings,
config.max_position_embeddings,
config.d_model,
self.padding_idx,
config.extra_pos_embeddings,
)
self.layers = nn.ModuleList(
[DecoderLayer(config) for _ in range(config.decoder_layers)]
......@@ -695,7 +705,10 @@ class SelfAttention(nn.Module):
# This is part of a workaround to get around fork/join parallelism not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len,)
assert key_padding_mask is None or key_padding_mask.size()[:2] == (
bsz,
src_len,
)
if key_padding_mask is not None: # don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
......@@ -703,7 +716,11 @@ class SelfAttention(nn.Module):
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
attn_probs = F.dropout(
attn_weights,
p=self.dropout,
training=self.training,
)
assert v is not None
attn_output = torch.bmm(attn_probs, v)
......@@ -754,7 +771,11 @@ class BartClassificationHead(nn.Module):
# This can trivially be shared with RobertaClassificationHead
def __init__(
self, input_dim, inner_dim, num_classes, pooler_dropout,
self,
input_dim,
inner_dim,
num_classes,
pooler_dropout,
):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
......@@ -819,7 +840,8 @@ def _get_shape(t):
@add_start_docstrings(
"The bare BART Model outputting raw hidden-states without any specific head on top.", BART_START_DOCSTRING,
"The bare BART Model outputting raw hidden-states without any specific head on top.",
BART_START_DOCSTRING,
)
class BartModel(PretrainedBartModel):
def __init__(self, config: BartConfig):
......@@ -1116,7 +1138,10 @@ class BartForSequenceClassification(PretrainedBartModel):
super().__init__(config, **kwargs)
self.model = BartModel(config)
self.classification_head = BartClassificationHead(
config.d_model, config.d_model, config.num_labels, config.classif_dropout,
config.d_model,
config.d_model,
config.num_labels,
config.classif_dropout,
)
self.model._init_weights(self.classification_head.dense)
self.model._init_weights(self.classification_head.out_proj)
......@@ -1279,7 +1304,10 @@ class BartForQuestionAnswering(PretrainedBartModel):
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits,) + outputs[1:]
output = (
start_logits,
end_logits,
) + outputs[1:]
return ((total_loss,) + output) if total_loss is not None else output
return Seq2SeqQuestionAnsweringModelOutput(
......
......@@ -89,8 +89,7 @@ BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model.
"""
"""Load tf checkpoints in a pytorch model."""
try:
import re
......@@ -174,8 +173,7 @@ BertLayerNorm = torch.nn.LayerNorm
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config):
super().__init__()
......@@ -343,7 +341,12 @@ class BertAttention(nn.Module):
output_attentions=False,
):
self_outputs = self.self(
hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions,
hidden_states,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions,
)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
......@@ -403,7 +406,10 @@ class BertLayer(nn.Module):
output_attentions=False,
):
self_attention_outputs = self.attention(
hidden_states, attention_mask, head_mask, output_attentions=output_attentions,
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
......@@ -582,7 +588,7 @@ class BertPreTrainingHeads(nn.Module):
class BertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -733,7 +739,7 @@ class BertModel(BertPreTrainedModel):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
......@@ -1049,7 +1055,10 @@ class BertLMHeadModel(BertPreTrainedModel):
return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutput(
loss=lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
......@@ -1173,7 +1182,8 @@ class BertForMaskedLM(BertPreTrainedModel):
@add_start_docstrings(
"""Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
"""Bert Model with a `next sentence prediction (classification)` head on top. """,
BERT_START_DOCSTRING,
)
class BertForNextSentencePrediction(BertPreTrainedModel):
def __init__(self, config):
......@@ -1336,7 +1346,10 @@ class BertForSequenceClassification(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1422,7 +1435,10 @@ class BertForMultipleChoice(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1505,7 +1521,10 @@ class BertForTokenClassification(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......
......@@ -68,7 +68,8 @@ class CamembertModel(RobertaModel):
@add_start_docstrings(
"""CamemBERT Model with a `language modeling` head on top. """, CAMEMBERT_START_DOCSTRING,
"""CamemBERT Model with a `language modeling` head on top. """,
CAMEMBERT_START_DOCSTRING,
)
class CamembertForMaskedLM(RobertaForMaskedLM):
"""
......
......@@ -212,7 +212,7 @@ class EncoderLayer(torch.nn.Module):
class CTRLPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -220,8 +220,7 @@ class CTRLPreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
def _init_weights(self, module):
""" Initialize the weights.
"""
"""Initialize the weights."""
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
......@@ -331,7 +330,7 @@ class CTRLModel(CTRLPreTrainedModel):
self.w = new_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for layer, heads in heads_to_prune.items():
......
......@@ -261,7 +261,12 @@ class TransformerBlock(nn.Module):
"""
# Self-Attention
sa_output = self.attention(
query=x, key=x, value=x, mask=attn_mask, head_mask=head_mask, output_attentions=output_attentions,
query=x,
key=x,
value=x,
mask=attn_mask,
head_mask=head_mask,
output_attentions=output_attentions,
)
if output_attentions:
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
......@@ -343,7 +348,7 @@ class Transformer(nn.Module):
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class DistilBertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -352,8 +357,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
base_model_prefix = "distilbert"
def _init_weights(self, module):
""" Initialize the weights.
"""
"""Initialize the weights."""
if isinstance(module, nn.Embedding):
if module.weight.requires_grad:
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
......@@ -432,7 +436,7 @@ class DistilBertModel(DistilBertPreTrainedModel):
self.embeddings.word_embeddings = new_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
......@@ -493,7 +497,8 @@ class DistilBertModel(DistilBertPreTrainedModel):
@add_start_docstrings(
"""DistilBert Model with a `masked language modeling` head on top. """, DISTILBERT_START_DOCSTRING,
"""DistilBert Model with a `masked language modeling` head on top. """,
DISTILBERT_START_DOCSTRING,
)
class DistilBertForMaskedLM(DistilBertPreTrainedModel):
def __init__(self, config):
......@@ -829,7 +834,10 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -930,5 +938,8 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -265,7 +265,7 @@ class DPRSpanPredictor(PreTrainedModel):
class DPRPretrainedContextEncoder(PreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -278,7 +278,7 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
class DPRPretrainedQuestionEncoder(PreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -291,7 +291,7 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
class DPRPretrainedReader(PreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -553,7 +553,8 @@ class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
@add_start_docstrings(
"The bare DPRReader transformer outputting span predictions.", DPR_START_DOCSTRING,
"The bare DPRReader transformer outputting span predictions.",
DPR_START_DOCSTRING,
)
class DPRReader(DPRPretrainedReader):
def __init__(self, config: DPRConfig):
......
......@@ -46,8 +46,7 @@ ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_or_generator="discriminator"):
""" Load tf checkpoints in a pytorch model.
"""
"""Load tf checkpoints in a pytorch model."""
try:
import re
......@@ -179,7 +178,7 @@ class ElectraGeneratorPredictions(nn.Module):
class ElectraPreTrainedModel(BertPreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -311,7 +310,7 @@ class ElectraModel(ElectraPreTrainedModel):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
......@@ -836,7 +835,10 @@ class ElectraForQuestionAnswering(ElectraPreTrainedModel):
total_loss = (start_loss + end_loss) / 2
if not return_dict:
output = (start_logits, end_logits,) + discriminator_hidden_states[1:]
output = (
start_logits,
end_logits,
) + discriminator_hidden_states[1:]
return ((total_loss,) + output) if total_loss is not None else output
return QuestionAnsweringModelOutput(
......
......@@ -103,7 +103,7 @@ class EncoderDecoderModel(PreTrainedModel):
*model_args,
**kwargs
) -> PreTrainedModel:
r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints.
r"""Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
......
......@@ -240,7 +240,11 @@ class FlaubertModel(XLMModel):
# self attention
if not self.pre_norm:
attn_outputs = self.attentions[i](
tensor, attn_mask, cache=cache, head_mask=head_mask[i], output_attentions=output_attentions,
tensor,
attn_mask,
cache=cache,
head_mask=head_mask[i],
output_attentions=output_attentions,
)
attn = attn_outputs[0]
if output_attentions:
......
......@@ -61,8 +61,7 @@ GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
""" Load tf checkpoints in a pytorch model
"""
"""Load tf checkpoints in a pytorch model"""
try:
import re
......@@ -324,7 +323,7 @@ class Block(nn.Module):
class GPT2PreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -336,8 +335,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
super().__init__(*inputs, **kwargs)
def _init_weights(self, module):
""" Initialize the weights.
"""
"""Initialize the weights."""
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
......@@ -483,7 +481,7 @@ class GPT2Model(GPT2PreTrainedModel):
self.wte = new_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for layer, heads in heads_to_prune.items():
......
......@@ -135,7 +135,10 @@ class LongformerSelfAttention(nn.Module):
self.one_sided_attn_window_size = attention_window // 2
def forward(
self, hidden_states, attention_mask=None, output_attentions=False,
self,
hidden_states,
attention_mask=None,
output_attentions=False,
):
"""
LongformerSelfAttention expects `len(hidden_states)` to be multiple of `attention_window`.
......@@ -622,7 +625,10 @@ class LongformerSelfAttention(nn.Module):
is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], :
] = -10000.0
global_attn_scores = global_attn_scores.masked_fill(is_index_masked[:, None, None, :], -10000.0,)
global_attn_scores = global_attn_scores.masked_fill(
is_index_masked[:, None, None, :],
-10000.0,
)
global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
......@@ -676,9 +682,16 @@ class LongformerAttention(nn.Module):
self.pruned_heads = self.pruned_heads.union(heads)
def forward(
self, hidden_states, attention_mask=None, output_attentions=False,
self,
hidden_states,
attention_mask=None,
output_attentions=False,
):
self_outputs = self.self(hidden_states, attention_mask, output_attentions,)
self_outputs = self.self(
hidden_states,
attention_mask,
output_attentions,
)
attn_output = self.output(self_outputs[0], hidden_states)
outputs = (attn_output,) + self_outputs[1:] # add attentions if we output them
return outputs
......@@ -694,9 +707,16 @@ class LongformerLayer(nn.Module):
self.seq_len_dim = 1
def forward(
self, hidden_states, attention_mask=None, output_attentions=False,
self,
hidden_states,
attention_mask=None,
output_attentions=False,
):
self_attn_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions,)
self_attn_outputs = self.attention(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
attn_output = self_attn_outputs[0]
outputs = self_attn_outputs[1:] # add self attentions if we output attention weights
......@@ -741,10 +761,16 @@ class LongformerEncoder(nn.Module):
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module), hidden_states, attention_mask,
create_custom_forward(layer_module),
hidden_states,
attention_mask,
)
else:
layer_outputs = layer_module(hidden_states, attention_mask, output_attentions,)
layer_outputs = layer_module(
hidden_states,
attention_mask,
output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
......@@ -762,7 +788,7 @@ class LongformerEncoder(nn.Module):
class LongformerPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained
models.
"""
......@@ -896,7 +922,7 @@ class LongformerModel(LongformerPreTrainedModel):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
......@@ -938,7 +964,9 @@ class LongformerModel(LongformerPreTrainedModel):
position_ids = F.pad(position_ids, (0, padding_len), value=pad_token_id)
if inputs_embeds is not None:
input_ids_padding = inputs_embeds.new_full(
(batch_size, padding_len), self.config.pad_token_id, dtype=torch.long,
(batch_size, padding_len),
self.config.pad_token_id,
dtype=torch.long,
)
inputs_embeds_padding = self.embeddings(input_ids_padding)
inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)
......@@ -1252,7 +1280,10 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1487,7 +1518,10 @@ class LongformerForTokenClassification(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1592,5 +1626,8 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -32,8 +32,7 @@ _CONFIG_FOR_DOC = "MMBTConfig"
class ModalEmbeddings(nn.Module):
"""Generic Modal Embeddings which takes in an encoder, and a transformer embedding.
"""
"""Generic Modal Embeddings which takes in an encoder, and a transformer embedding."""
def __init__(self, config, encoder, embeddings):
super().__init__()
......@@ -154,7 +153,8 @@ MMBT_INPUTS_DOCSTRING = r""" Inputs:
@add_start_docstrings(
"The bare MMBT Model outputting raw hidden-states without any specific head on top.", MMBT_START_DOCSTRING,
"The bare MMBT Model outputting raw hidden-states without any specific head on top.",
MMBT_START_DOCSTRING,
)
class MMBTModel(nn.Module, ModuleUtilsMixin):
def __init__(self, config, transformer, encoder):
......@@ -378,5 +378,8 @@ class MMBTForClassification(nn.Module):
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -64,8 +64,7 @@ MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = ["google/mobilebert-uncased"]
def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model.
"""
"""Load tf checkpoints in a pytorch model."""
try:
import re
......@@ -161,8 +160,7 @@ NORM2FN = {"layer_norm": torch.nn.LayerNorm, "no_norm": NoNorm}
class MobileBertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config):
super().__init__()
......@@ -663,7 +661,7 @@ class MobileBertPreTrainingHeads(nn.Module):
class MobileBertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -809,7 +807,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
......@@ -1308,7 +1306,10 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1491,7 +1492,10 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1574,5 +1578,8 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -59,8 +59,7 @@ OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
""" Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
"""
"""Load tf pre-trained weights in a pytorch model (from NumPy arrays here)"""
import re
import numpy as np
......@@ -257,7 +256,10 @@ class Block(nn.Module):
def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
attn_outputs = self.attn(
x, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions,
x,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
)
a = attn_outputs[0]
......@@ -270,7 +272,7 @@ class Block(nn.Module):
class OpenAIGPTPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -280,8 +282,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
authorized_missing_keys = [r"position_ids"]
def _init_weights(self, module):
""" Initialize the weights.
"""
"""Initialize the weights."""
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
......@@ -408,7 +409,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
self.tokens_embed = new_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for layer, heads in heads_to_prune.items():
......@@ -506,7 +507,9 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions,
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
......
......@@ -78,7 +78,8 @@ ReformerBackwardOutput = namedtuple(
"ReformerBackwardOutput", ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"]
)
ReformerEncoderOutput = namedtuple(
"ReformerEncoderOutput", ["hidden_states", "all_hidden_states", "all_attentions", "past_buckets_states"],
"ReformerEncoderOutput",
["hidden_states", "all_hidden_states", "all_attentions", "past_buckets_states"],
)
......@@ -192,7 +193,9 @@ class AxialPositionEmbeddings(nn.Module):
assert (
reduce(mul, self.axial_pos_shape) >= sequence_length
), "Make sure that config.axial_pos_shape factors: {} multiply at least to max(sequence_length, least_common_mult_chunk_length): max({}, {})".format(
self.axial_pos_shape, sequence_length, self.least_common_mult_chunk_length,
self.axial_pos_shape,
sequence_length,
self.least_common_mult_chunk_length,
)
# compute how many columns are needed
......@@ -218,8 +221,7 @@ class AxialPositionEmbeddings(nn.Module):
class PositionEmbeddings(nn.Module):
"""Constructs conventional position embeddings of shape `[max_pos_embeddings, hidden_size]`.
"""
"""Constructs conventional position embeddings of shape `[max_pos_embeddings, hidden_size]`."""
def __init__(self, config):
super().__init__()
......@@ -233,8 +235,7 @@ class PositionEmbeddings(nn.Module):
class ReformerEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config):
super().__init__()
......@@ -285,7 +286,7 @@ class EfficientAttentionMixin:
"""
def _look_adjacent(self, vectors, num_chunks_before, num_chunks_after):
""" Used to implement attention between consecutive chunks.
"""Used to implement attention between consecutive chunks.
Args:
vectors: array of shape [batch_size, num_attention_heads, n_chunks, chunk_len, ...]
......@@ -418,10 +419,18 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# split key & value vectors by num hashes to apply
# self attention on each separately
query_key_vectors = self._split_seq_length_dim_to(
query_key_vectors, num_hashes, -1, self.num_attention_heads, self.attention_head_size,
query_key_vectors,
num_hashes,
-1,
self.num_attention_heads,
self.attention_head_size,
)
value_vectors = self._split_seq_length_dim_to(
value_vectors, num_hashes, -1, self.num_attention_heads, self.attention_head_size,
value_vectors,
num_hashes,
-1,
self.num_attention_heads,
self.attention_head_size,
)
# repeat query vectors across hash dimension
query_vectors = query_vectors.unsqueeze(2).repeat(1, 1, num_hashes, 1, 1)
......@@ -496,10 +505,18 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx_per_hash, num_hashes)
value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx_per_hash, num_hashes)
query_key_vectors = self._split_seq_length_dim_to(
query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
query_key_vectors,
-1,
self.chunk_length,
self.num_attention_heads,
self.attention_head_size,
)
value_vectors = self._split_seq_length_dim_to(
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
value_vectors,
-1,
self.chunk_length,
self.num_attention_heads,
self.attention_head_size,
)
if self.chunk_length is None:
......@@ -548,10 +565,18 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# sum up all hash rounds
if num_hashes > 1:
out_vectors = self._split_seq_length_dim_to(
out_vectors, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size,
out_vectors,
num_hashes,
sequence_length,
self.num_attention_heads,
self.attention_head_size,
)
logits = self._split_seq_length_dim_to(
logits, num_hashes, sequence_length, self.num_attention_heads, self.attention_head_size,
logits,
num_hashes,
sequence_length,
self.num_attention_heads,
self.attention_head_size,
).unsqueeze(-1)
probs_vectors = torch.exp(logits - torch.logsumexp(logits, dim=2, keepdim=True))
......@@ -697,7 +722,8 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# factorize `num_buckets` if `num_buckets` becomes too large
num_buckets_limit = 2 * max(
int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length,
int((self.max_position_embeddings // self.chunk_length) ** (0.5)),
self.chunk_length,
)
if num_buckets > num_buckets_limit:
num_buckets = [2 ** (num_buckets_pow_2 // 2), 2 ** (num_buckets_pow_2 - num_buckets_pow_2 // 2)]
......@@ -1113,13 +1139,25 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
# chunk vectors
# B x Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size
query_vectors = self._split_seq_length_dim_to(
query_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
query_vectors,
-1,
self.chunk_length,
self.num_attention_heads,
self.attention_head_size,
)
key_vectors = self._split_seq_length_dim_to(
key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
key_vectors,
-1,
self.chunk_length,
self.num_attention_heads,
self.attention_head_size,
)
value_vectors = self._split_seq_length_dim_to(
value_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size,
value_vectors,
-1,
self.chunk_length,
self.num_attention_heads,
self.attention_head_size,
)
# chunk indices
......@@ -1179,7 +1217,12 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
if not do_standard_self_attention:
out_vectors = out_vectors.flatten(start_dim=2, end_dim=3)
assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,)
assert out_vectors.shape == (
batch_size,
self.num_attention_heads,
sequence_length,
self.attention_head_size,
)
out_vectors = self._merge_hidden_size_dims(out_vectors, self.num_attention_heads, self.attention_head_size)
......@@ -1321,7 +1364,9 @@ class ReformerAttention(nn.Module):
attention_output = self.output(self_attention_outputs.hidden_states)
return AttentionOutput(
hidden_states=attention_output, attention_probs=self_attention_outputs.attention_probs, buckets=buckets,
hidden_states=attention_output,
attention_probs=self_attention_outputs.attention_probs,
buckets=buckets,
)
......@@ -1369,7 +1414,10 @@ class ChunkReformerFeedForward(nn.Module):
def forward(self, attention_output):
return apply_chunking_to_forward(
self.forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output,
self.forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output,
)
def forward_chunk(self, hidden_states):
......@@ -1520,7 +1568,10 @@ class ReformerLayer(nn.Module):
# f(X_2)
# use cached buckets for backprob if buckets not None for LSHSelfAttention
output = self.attention(
hidden_states=hidden_states, head_mask=head_mask, attention_mask=attention_mask, buckets=buckets,
hidden_states=hidden_states,
head_mask=head_mask,
attention_mask=attention_mask,
buckets=buckets,
).hidden_states
output.backward(grad_attn_output, retain_graph=True)
......@@ -1738,7 +1789,7 @@ class ReformerOnlyLMHead(nn.Module):
class ReformerPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -1947,7 +1998,7 @@ class ReformerModel(ReformerPreTrainedModel):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
......@@ -2099,7 +2150,10 @@ class ReformerModel(ReformerPreTrainedModel):
)
padded_input_ids = torch.full(
(input_shape[0], padding_length), self.config.pad_token_id, device=device, dtype=torch.long,
(input_shape[0], padding_length),
self.config.pad_token_id,
device=device,
dtype=torch.long,
)
# Extend `attention_mask`
......@@ -2407,7 +2461,10 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......