Commit a75c64d8 authored by Lysandre

Black 20 release

parent e78c1103
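Nearly every hunk below follows from two formatting rules in the Black 20.x release: docstrings are normalized, dropping the stray space after the opening """, and the "magic trailing comma" forces any call or literal that already ends in a trailing comma to be exploded onto one line per argument. A minimal before/after sketch of the trailing-comma rule (the function name embed is hypothetical, for illustration only):

# Before: one line, but the trailing comma pins the call open
reps = embed(input_ids, attention_mask, checkpoint_batch_size=-1,)

# After Black 20.x: one argument per line
reps = embed(
    input_ids,
    attention_mask,
    checkpoint_batch_size=-1,
)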
......@@ -40,7 +40,7 @@ RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class RetriBertPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -73,7 +73,8 @@ RETRIBERT_START_DOCSTRING = r"""
@add_start_docstrings(
"""Bert Based model to embed queries or document for document retreival. """, RETRIBERT_START_DOCSTRING,
"""Bert Based model to embed queries or document for document retreival. """,
RETRIBERT_START_DOCSTRING,
)
class RetriBertModel(RetriBertPreTrainedModel):
def __init__(self, config):
......@@ -91,7 +92,11 @@ class RetriBertModel(RetriBertPreTrainedModel):
self.init_weights()
def embed_sentences_checkpointed(
self, input_ids, attention_mask, sent_encoder, checkpoint_batch_size=-1,
self,
input_ids,
attention_mask,
sent_encoder,
checkpoint_batch_size=-1,
):
# reproduces BERT forward pass with checkpointing
if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size:
......@@ -108,7 +113,11 @@ class RetriBertModel(RetriBertPreTrainedModel):
# define function for checkpointing
def partial_encode(*inputs):
encoder_outputs = sent_encoder.encoder(inputs[0], attention_mask=inputs[1], head_mask=head_mask,)
encoder_outputs = sent_encoder.encoder(
inputs[0],
attention_mask=inputs[1],
head_mask=head_mask,
)
sequence_output = encoder_outputs[0]
pooled_output = sent_encoder.pooler(sequence_output)
return pooled_output
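The partial_encode closure above packages the encoder and pooler so the pass can be wrapped in torch.utils.checkpoint.checkpoint, which frees intermediate activations and recomputes them on the backward pass. A minimal sketch of that slice-and-checkpoint pattern (assuming a positive checkpoint_batch_size; encode_fn stands in for partial_encode, and the real method also handles the -1 "no checkpointing" case):

import torch
import torch.utils.checkpoint as cp

def embed_checkpointed(encode_fn, input_ids, attention_mask, checkpoint_batch_size):
    pooled = []
    for b in range(0, input_ids.shape[0], checkpoint_batch_size):
        # checkpoint() reruns encode_fn during backward instead of storing activations
        pooled.append(
            cp.checkpoint(
                encode_fn,
                input_ids[b : b + checkpoint_batch_size],
                attention_mask[b : b + checkpoint_batch_size],
            )
        )
    return torch.cat(pooled, dim=0)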
......@@ -127,13 +136,24 @@ class RetriBertModel(RetriBertPreTrainedModel):
return torch.cat(pooled_output_list, dim=0)
def embed_questions(
self, input_ids, attention_mask=None, checkpoint_batch_size=-1,
self,
input_ids,
attention_mask=None,
checkpoint_batch_size=-1,
):
q_reps = self.embed_sentences_checkpointed(input_ids, attention_mask, self.bert_query, checkpoint_batch_size,)
q_reps = self.embed_sentences_checkpointed(
input_ids,
attention_mask,
self.bert_query,
checkpoint_batch_size,
)
return self.project_query(q_reps)
def embed_answers(
self, input_ids, attention_mask=None, checkpoint_batch_size=-1,
self,
input_ids,
attention_mask=None,
checkpoint_batch_size=-1,
):
a_reps = self.embed_sentences_checkpointed(
input_ids,
......
......@@ -83,7 +83,7 @@ class RobertaEmbeddings(BertEmbeddings):
)
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
""" We are provided embeddings directly. We cannot infer which are padded so just generate
"""We are provided embeddings directly. We cannot infer which are padded so just generate
sequential position ids.
:param torch.Tensor inputs_embeds:
......@@ -283,7 +283,10 @@ class RobertaForCausalLM(BertPreTrainedModel):
return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutput(
loss=lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=lm_loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
......@@ -493,7 +496,10 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -581,7 +587,10 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -667,7 +676,10 @@ class RobertaForTokenClassification(BertPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -791,7 +803,7 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
def create_position_ids_from_input_ids(input_ids, padding_idx):
""" Replace non-padding symbols with their position numbers. Position numbers begin at
"""Replace non-padding symbols with their position numbers. Position numbers begin at
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
`utils.make_positions`.
......
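The create_position_ids_from_input_ids docstring above describes a cumulative sum over the padding mask; a sketch of that fairseq-style behaviour (not necessarily the file's exact body):

import torch

def create_position_ids_from_input_ids(input_ids, padding_idx):
    # 1 for real tokens, 0 for padding
    mask = input_ids.ne(padding_idx).int()
    # running count of real tokens, zeroed back out at padding positions
    incremental_indices = torch.cumsum(mask, dim=1) * mask
    # shift so the first real position is padding_idx + 1
    return incremental_indices.long() + padding_idx

With padding_idx=1, input_ids of [[0, 5, 7, 1, 1]] yields [[2, 3, 4, 1, 1]]: real positions count up from 2, while padding positions keep the padding_idx value.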
......@@ -62,8 +62,7 @@ T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
####################################################
def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model.
"""
"""Load tf checkpoints in a pytorch model."""
try:
import re
......@@ -156,7 +155,7 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
class T5LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
""" Construct a layernorm module in the T5 style
"""Construct a layernorm module in the T5 style
No bias and no subtraction of mean.
"""
super().__init__()
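As the T5LayerNorm docstring says, the T5 variant rescales by the root mean square only; a sketch of that layer (assuming a learned scale and the module's eps default):

import torch
from torch import nn

class T5StyleLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        # RMS normalization: no mean subtraction and no bias, per the docstring
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)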
......@@ -569,7 +568,7 @@ class T5Block(nn.Module):
class T5PreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -913,7 +912,7 @@ class T5Model(T5PreTrainedModel):
return self.decoder
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
......
......@@ -74,8 +74,7 @@ TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
class TFAlbertEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings.
"""
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
......@@ -478,7 +477,7 @@ class TFAlbertTransformer(tf.keras.layers.Layer):
class TFAlbertPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -551,7 +550,7 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
raise NotImplementedError
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
......@@ -655,7 +654,10 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
pooled_output = self.pooler(sequence_output[:, 0])
if not return_dict:
return (sequence_output, pooled_output,) + encoder_outputs[1:]
return (
sequence_output,
pooled_output,
) + encoder_outputs[1:]
return TFBaseModelOutputWithPooling(
last_hidden_state=sequence_output,
......@@ -856,7 +858,9 @@ class TFAlbertSOPHead(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.classifier_dropout_prob)
self.classifier = tf.keras.layers.Dense(
config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier",
config.num_labels,
kernel_initializer=get_initializer(config.initializer_range),
name="classifier",
)
def call(self, pooled_output, training: bool):
......@@ -935,7 +939,10 @@ class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss)
return ((loss,) + output) if loss is not None else output
return TFMaskedLMOutput(
loss=loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1016,7 +1023,10 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1095,7 +1105,10 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
return ((loss,) + output) if loss is not None else output
return TFTokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1211,7 +1224,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
@property
def dummy_inputs(self):
""" Dummy inputs to build the network.
"""Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
......@@ -1316,5 +1329,8 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
return ((loss,) + output) if loss is not None else output
return TFMultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -342,7 +342,7 @@ class TFAutoModel(object):
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -381,7 +381,7 @@ class TFAutoModel(object):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
......@@ -493,7 +493,7 @@ class TFAutoModelForPreTraining(object):
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -532,7 +532,7 @@ class TFAutoModelForPreTraining(object):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the model classes of the library -with the architecture used for pretraining this model– from a pre-trained model configuration.
r"""Instantiates one of the model classes of the library -with the architecture used for pretraining this model– from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
based on the `model_type` property of the config object, or when it's missing,
......@@ -662,7 +662,7 @@ class TFAutoModelWithLMHead(object):
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -705,7 +705,7 @@ class TFAutoModelWithLMHead(object):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the language modeling model classes of the library
r"""Instantiates one of the language modeling model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
......@@ -831,7 +831,7 @@ class TFAutoModelForMultipleChoice:
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -864,7 +864,7 @@ class TFAutoModelForMultipleChoice:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the multiple choice model classes of the library
r"""Instantiates one of the multiple choice model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
......@@ -975,7 +975,7 @@ class TFAutoModelForCausalLM:
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -1011,7 +1011,7 @@ class TFAutoModelForCausalLM:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the language modeling model classes of the library
r"""Instantiates one of the language modeling model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
......@@ -1110,7 +1110,7 @@ class TFAutoModelForMaskedLM:
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -1149,7 +1149,7 @@ class TFAutoModelForMaskedLM:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the language modeling model classes of the library
r"""Instantiates one of the language modeling model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
......@@ -1252,7 +1252,7 @@ class TFAutoModelForSeq2SeqLM:
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -1285,7 +1285,7 @@ class TFAutoModelForSeq2SeqLM:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the language modeling model classes of the library
r"""Instantiates one of the language modeling model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
......@@ -1391,7 +1391,7 @@ class TFAutoModelForSequenceClassification(object):
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -1428,7 +1428,7 @@ class TFAutoModelForSequenceClassification(object):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the sequence classification model classes of the library
r"""Instantiates one of the sequence classification model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
......@@ -1553,7 +1553,7 @@ class TFAutoModelForQuestionAnswering(object):
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -1591,7 +1591,7 @@ class TFAutoModelForQuestionAnswering(object):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the question answering model classes of the library
r"""Instantiates one of the question answering model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
......@@ -1697,7 +1697,7 @@ class TFAutoModelForTokenClassification:
@classmethod
def from_config(cls, config):
r""" Instantiates one of the base model classes of the library
r"""Instantiates one of the base model classes of the library
from a configuration.
Note:
......@@ -1733,7 +1733,7 @@ class TFAutoModelForTokenClassification:
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
r""" Instantiates one of the question answering model classes of the library
r"""Instantiates one of the question answering model classes of the library
from a pre-trained model configuration.
The `from_pretrained()` method takes care of returning the correct model class instance
......
......@@ -89,7 +89,7 @@ TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
def gelu(x):
""" Gaussian Error Linear Unit.
"""Gaussian Error Linear Unit.
Original Implementation of the gelu activation function in Google Bert repo when initially created.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
......@@ -127,8 +127,7 @@ ACT2FN = {
class TFBertEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings.
"""
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
......@@ -551,7 +550,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
self.embeddings.vocab_size = value.shape[0]
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
......@@ -656,7 +655,10 @@ class TFBertMainLayer(tf.keras.layers.Layer):
pooled_output = self.pooler(sequence_output)
if not return_dict:
return (sequence_output, pooled_output,) + encoder_outputs[1:]
return (
sequence_output,
pooled_output,
) + encoder_outputs[1:]
return TFBaseModelOutputWithPooling(
last_hidden_state=sequence_output,
......@@ -667,7 +669,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
class TFBertPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -933,7 +935,10 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
return ((loss,) + output) if loss is not None else output
return TFMaskedLMOutput(
loss=loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1011,12 +1016,16 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
return ((loss,) + output) if loss is not None else output
return TFCausalLMOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@add_start_docstrings(
"""Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
"""Bert Model with a `next sentence prediction (classification)` head on top. """,
BERT_START_DOCSTRING,
)
class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
......@@ -1057,7 +1066,9 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
return (seq_relationship_score,) + outputs[2:]
return TFNextSentencePredictorOutput(
logits=seq_relationship_score, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
logits=seq_relationship_score,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1138,7 +1149,10 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1159,7 +1173,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
@property
def dummy_inputs(self):
""" Dummy inputs to build the network.
"""Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
......@@ -1261,7 +1275,10 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
return ((loss,) + output) if loss is not None else output
return TFMultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1340,7 +1357,10 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
return ((loss,) + output) if loss is not None else output
return TFTokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......
......@@ -77,7 +77,8 @@ class TFCamembertModel(TFRobertaModel):
@add_start_docstrings(
"""CamemBERT Model with a `language modeling` head on top. """, CAMEMBERT_START_DOCSTRING,
"""CamemBERT Model with a `language modeling` head on top. """,
CAMEMBERT_START_DOCSTRING,
)
class TFCamembertForMaskedLM(TFRobertaForMaskedLM):
"""
......
......@@ -245,7 +245,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
raise NotImplementedError
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
raise NotImplementedError
......@@ -426,7 +426,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
class TFCTRLPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......
......@@ -70,7 +70,7 @@ TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
def gelu(x):
""" Gaussian Error Linear Unit.
"""Gaussian Error Linear Unit.
Original Implementation of the gelu activation function in Google Bert repo when initially created.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
......@@ -518,7 +518,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class TFDistilBertPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -634,7 +634,8 @@ class TFDistilBertLMHead(tf.keras.layers.Layer):
@add_start_docstrings(
"""DistilBert Model with a `masked language modeling` head on top. """, DISTILBERT_START_DOCSTRING,
"""DistilBert Model with a `masked language modeling` head on top. """,
DISTILBERT_START_DOCSTRING,
)
class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModelingLoss):
def __init__(self, config, *inputs, **kwargs):
......@@ -875,7 +876,10 @@ class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenCla
return ((loss,) + output) if loss is not None else output
return TFTokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -902,7 +906,7 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
@property
def dummy_inputs(self):
""" Dummy inputs to build the network.
"""Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
......
......@@ -54,8 +54,7 @@ TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
class TFElectraEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings.
"""
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
......@@ -94,7 +93,13 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
super().build(input_shape)
def call(
self, input_ids, position_ids=None, token_type_ids=None, inputs_embeds=None, mode="embedding", training=False,
self,
input_ids,
position_ids=None,
token_type_ids=None,
inputs_embeds=None,
mode="embedding",
training=False,
):
"""Get token embeddings of inputs.
Args:
......@@ -250,7 +255,7 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
raise NotImplementedError
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
......@@ -729,7 +734,10 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -752,7 +760,7 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
@property
def dummy_inputs(self):
""" Dummy inputs to build the network.
"""Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
......@@ -853,7 +861,10 @@ class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss)
return ((loss,) + output) if loss is not None else output
return TFMultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1020,7 +1031,10 @@ class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnswerin
loss = self.compute_loss(labels, (start_logits, end_logits))
if not return_dict:
output = (start_logits, end_logits,) + discriminator_hidden_states[1:]
output = (
start_logits,
end_logits,
) + discriminator_hidden_states[1:]
return ((loss,) + output) if loss is not None else output
return TFQuestionAnsweringModelOutput(
......
......@@ -252,7 +252,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
self.wte.vocab_size = self.wte.weight.shape[0]
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
raise NotImplementedError
......@@ -417,7 +417,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
class TFGPT2PreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......
......@@ -72,11 +72,14 @@ def _compute_global_attention_mask(input_ids_shape, sep_token_indices, before_se
)
else:
# the last token is a separation token and should not be counted, and there are two separation tokens in the middle
attention_mask = tf.cast(
attention_mask = (
tf.cast(
tf.broadcast_to(attention_mask, input_ids_shape)
> tf.broadcast_to(question_end_index + 1, input_ids_shape),
tf.dtypes.int32,
) * tf.cast(tf.broadcast_to(attention_mask, input_ids_shape) < input_ids_shape[-1], tf.dtypes.int32)
)
* tf.cast(tf.broadcast_to(attention_mask, input_ids_shape) < input_ids_shape[-1], tf.dtypes.int32)
)
return attention_mask
......@@ -130,7 +133,9 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
self.one_sided_attn_window_size = attention_window // 2
def call(
self, inputs, training=False,
self,
inputs,
training=False,
):
"""
LongformerSelfAttention expects `len(hidden_states)` to be a multiple of `attention_window`.
......@@ -779,7 +784,8 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
tf.transpose(global_attn_output, (0, 2, 1, 3)), is_local_index_global_attn_nonzero
)
nonzero_global_attn_output = tf.reshape(
nonzero_global_attn_output, (shape_list(is_local_index_global_attn_nonzero)[0], -1),
nonzero_global_attn_output,
(shape_list(is_local_index_global_attn_nonzero)[0], -1),
)
# overwrite values with global attention
......@@ -910,7 +916,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
self.embeddings.vocab_size = value.shape[0]
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
......@@ -1021,7 +1027,10 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
sequence_output = sequence_output[:, :-padding_len]
if not return_dict:
return (sequence_output, pooled_output,) + encoder_outputs[1:]
return (
sequence_output,
pooled_output,
) + encoder_outputs[1:]
return TFBaseModelOutputWithPooling(
last_hidden_state=sequence_output,
......@@ -1031,7 +1040,13 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
)
def _pad_to_window_size(
self, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds, pad_token_id,
self,
input_ids,
attention_mask,
token_type_ids,
position_ids,
inputs_embeds,
pad_token_id,
):
"""A helper function to pad tokens and mask to work with implementation of Longformer selfattention."""
# padding
......@@ -1083,7 +1098,7 @@ class TFLongformerMainLayer(tf.keras.layers.Layer):
class TFLongformerPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -1286,7 +1301,10 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
return ((loss,) + output) if loss is not None else output
return TFMaskedLMOutput(
loss=loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......
......@@ -99,8 +99,7 @@ NORM2FN = {"layer_norm": TFLayerNorm, "no_norm": TFNoNorm}
class TFMobileBertEmbeddings(tf.keras.layers.Layer):
"""Construct the embeddings from word, position and token_type embeddings.
"""
"""Construct the embeddings from word, position and token_type embeddings."""
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
......@@ -696,7 +695,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
raise NotImplementedError
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
......@@ -799,7 +798,10 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
pooled_output = self.pooler(sequence_output)
if not return_dict:
return (sequence_output, pooled_output,) + encoder_outputs[1:]
return (
sequence_output,
pooled_output,
) + encoder_outputs[1:]
return TFBaseModelOutputWithPooling(
last_hidden_state=sequence_output,
......@@ -810,7 +812,7 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
class TFMobileBertPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -1069,7 +1071,10 @@ class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModel
return ((loss,) + output) if loss is not None else output
return TFMaskedLMOutput(
loss=loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1125,7 +1130,9 @@ class TFMobileBertForNextSentencePrediction(TFMobileBertPreTrainedModel):
return (seq_relationship_score,) + outputs[2:]
return TFNextSentencePredictorOutput(
logits=seq_relationship_score, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
logits=seq_relationship_score,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1206,7 +1213,10 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1323,7 +1333,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
@property
def dummy_inputs(self):
""" Dummy inputs to build the network.
"""Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
......@@ -1425,7 +1435,10 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
return ((loss,) + output) if loss is not None else output
return TFMultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -1504,5 +1517,8 @@ class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenCla
return ((loss,) + output) if loss is not None else output
return TFTokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -243,7 +243,7 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
self.tokens_embed.vocab_size = value.shape[0]
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
raise NotImplementedError
......@@ -373,12 +373,14 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions,
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......
......@@ -28,7 +28,7 @@ logger = logging.get_logger(__name__)
def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove=""):
""" Convert a TF 2.0 model variable name in a pytorch model weight name.
"""Convert a TF 2.0 model variable name in a pytorch model weight name.
Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
- '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
......@@ -72,8 +72,7 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="")
def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
""" Load pytorch checkpoints in a TF 2.0 model
"""
"""Load pytorch checkpoints in a TF 2.0 model"""
try:
import tensorflow as tf # noqa: F401
import torch # noqa: F401
......@@ -96,8 +95,7 @@ def load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path, tf_i
def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_missing_keys=False):
""" Load pytorch checkpoints in a TF 2.0 model
"""
"""Load pytorch checkpoints in a TF 2.0 model"""
pt_state_dict = pt_model.state_dict()
return load_pytorch_weights_in_tf2_model(
......@@ -106,8 +104,7 @@ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_mi
def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False):
""" Load pytorch state_dict in a TF 2.0 model.
"""
"""Load pytorch state_dict in a TF 2.0 model."""
try:
import tensorflow as tf # noqa: F401
import torch # noqa: F401
......@@ -230,7 +227,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False):
""" Load TF 2.0 HDF5 checkpoint in a PyTorch model
"""Load TF 2.0 HDF5 checkpoint in a PyTorch model
We use HDF5 to easily do transfer learning
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
"""
......@@ -265,16 +262,14 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
def load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=False):
""" Load TF 2.0 model in a pytorch model
"""
"""Load TF 2.0 model in a pytorch model"""
weights = tf_model.weights
return load_tf2_weights_in_pytorch_model(pt_model, weights, allow_missing_keys=allow_missing_keys)
def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False):
""" Load TF2.0 symbolic weights in a PyTorch model
"""
"""Load TF2.0 symbolic weights in a PyTorch model"""
try:
import tensorflow as tf # noqa: F401
import torch # noqa: F401
......
......@@ -73,7 +73,7 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
self.padding_idx = 1
def create_position_ids_from_input_ids(self, x):
""" Replace non-padding symbols with their position numbers. Position numbers begin at
"""Replace non-padding symbols with their position numbers. Position numbers begin at
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
`utils.make_positions`.
:param tf.Tensor x:
......@@ -84,7 +84,7 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
return incremental_indicies + self.padding_idx
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
""" We are provided embeddings directly. We cannot infer which are padded so just generate
"""We are provided embeddings directly. We cannot infer which are padded so just generate
sequential position ids.
:param tf.Tensor inputs_embeds:
:return tf.Tensor:
......@@ -120,7 +120,7 @@ class TFRobertaMainLayer(TFBertMainLayer):
class TFRobertaPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -330,7 +330,10 @@ class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLos
return ((loss,) + output) if loss is not None else output
return TFMaskedLMOutput(
loss=loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -431,7 +434,10 @@ class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceCla
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -452,7 +458,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
@property
def dummy_inputs(self):
""" Dummy inputs to build the network.
"""Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
......@@ -549,7 +555,10 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
return ((loss,) + output) if loss is not None else output
return TFMultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -628,7 +637,10 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
return ((loss,) + output) if loss is not None else output
return TFTokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......
......@@ -67,7 +67,7 @@ TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
class TFT5LayerNorm(tf.keras.layers.Layer):
def __init__(self, epsilon=1e-6, **kwargs):
""" Construct a layernorm module in the T5 style
"""Construct a layernorm module in the T5 style
No bias and no subtraction of mean.
"""
super().__init__(**kwargs)
......@@ -140,7 +140,9 @@ class TFT5Attention(tf.keras.layers.Layer):
if self.has_relative_attention_bias:
self.relative_attention_bias = tf.keras.layers.Embedding(
self.relative_attention_num_buckets, self.n_heads, name="relative_attention_bias",
self.relative_attention_num_buckets,
self.n_heads,
name="relative_attention_bias",
)
self.pruned_heads = set()
......@@ -199,7 +201,9 @@ class TFT5Attention(tf.keras.layers.Layer):
memory_position = tf.range(klen)[None, :]
relative_position = memory_position - context_position # shape (qlen, klen)
rp_bucket = self._relative_position_bucket(
relative_position, bidirectional=not self.is_decoder, num_buckets=self.relative_attention_num_buckets,
relative_position,
bidirectional=not self.is_decoder,
num_buckets=self.relative_attention_num_buckets,
)
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0) # shape (1, num_heads, qlen, klen)
......@@ -316,7 +320,9 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
def __init__(self, config, has_relative_attention_bias=False, **kwargs):
super().__init__(**kwargs)
self.SelfAttention = TFT5Attention(
config, has_relative_attention_bias=has_relative_attention_bias, name="SelfAttention",
config,
has_relative_attention_bias=has_relative_attention_bias,
name="SelfAttention",
)
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
......@@ -353,7 +359,9 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
def __init__(self, config, has_relative_attention_bias=False, **kwargs):
super().__init__(**kwargs)
self.EncDecAttention = TFT5Attention(
config, has_relative_attention_bias=has_relative_attention_bias, name="EncDecAttention",
config,
has_relative_attention_bias=has_relative_attention_bias,
name="EncDecAttention",
)
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
......@@ -396,12 +404,18 @@ class TFT5Block(tf.keras.layers.Layer):
self.is_decoder = config.is_decoder
self.layer = []
self.layer.append(
TFT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._0",)
TFT5LayerSelfAttention(
config,
has_relative_attention_bias=has_relative_attention_bias,
name="layer_._0",
)
)
if self.is_decoder:
self.layer.append(
TFT5LayerCrossAttention(
config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._1",
config,
has_relative_attention_bias=has_relative_attention_bias,
name="layer_._1",
)
)
......@@ -539,7 +553,11 @@ class TFT5MainLayer(tf.keras.layers.Layer):
self.num_hidden_layers = config.num_layers
self.block = [
TFT5Block(config, has_relative_attention_bias=bool(i == 0), name="block_._{}".format(i),)
TFT5Block(
config,
has_relative_attention_bias=bool(i == 0),
name="block_._{}".format(i),
)
for i in range(config.num_layers)
]
self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm")
......@@ -654,7 +672,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
if self.is_decoder:
seq_ids = tf.range(mask_seq_length)
causal_mask = tf.less_equal(
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), seq_ids[None, :, None],
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
seq_ids[None, :, None],
)
causal_mask = tf.cast(causal_mask, dtype=tf.float32)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
......@@ -765,7 +784,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
# pointers for your model.
####################################################
class TFT5PreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......
......@@ -628,7 +628,13 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
hids.append(core_out)
mems_i = None if mems is None else mems[i]
layer_outputs = layer(
core_out, pos_emb, dec_attn_mask, mems_i, head_mask[i], output_attentions, training=training,
core_out,
pos_emb,
dec_attn_mask,
mems_i,
head_mask[i],
output_attentions,
training=training,
)
core_out = layer_outputs[0]
if output_attentions:
......@@ -657,12 +663,15 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
return tuple(v for v in [core_out, new_mems, hids, attentions] if v is not None)
return TFTransfoXLModelOutput(
last_hidden_state=core_out, mems=new_mems, hidden_states=hids, attentions=attentions,
last_hidden_state=core_out,
mems=new_mems,
hidden_states=hids,
attentions=attentions,
)
class TFTransfoXLPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
......@@ -852,8 +861,7 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
)
def get_output_embeddings(self):
""" Double-check if you are using adaptive softmax.
"""
"""Double-check if you are using adaptive softmax."""
if len(self.crit.out_layers) > 0:
return self.crit.out_layers[-1]
return None
......
......@@ -64,7 +64,10 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
else:
self.out_projs.append(None)
weight = self.add_weight(
shape=(self.vocab_size, self.d_embed,),
shape=(
self.vocab_size,
self.d_embed,
),
initializer="zeros",
trainable=True,
name="out_layers_._{}_._weight".format(i),
......@@ -86,7 +89,10 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
)
self.out_projs.append(weight)
weight = self.add_weight(
shape=(r_idx - l_idx, d_emb_i,),
shape=(
r_idx - l_idx,
d_emb_i,
),
initializer="zeros",
trainable=True,
name="out_layers_._{}_._weight".format(i),
......
......@@ -212,8 +212,7 @@ class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
.. note::
Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
"""
"""
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
......
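The -100 convention mentioned in the TFMaskedLanguageModelingLoss docstring above is typically implemented by boolean-masking both labels and logits before the cross-entropy; a sketch under the assumption of (batch, seq_len) integer labels and (batch, seq_len, vocab_size) logits:

import tensorflow as tf

def masked_lm_loss(labels, logits):
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    # keep only the positions whose label is not the ignore index -100
    active = tf.reshape(labels, (-1,)) != -100
    reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, logits.shape[-1])), active)
    active_labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active)
    return loss_fn(active_labels, reduced_logits)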