Unverified commit 32dbb2d9, authored by Patrick von Platen and committed by GitHub

make style (#11442)

parent 04ab2ca6
......@@ -342,11 +342,11 @@ class TFFlaubertMultiHeadAttention(tf.keras.layers.Layer):
mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen)
def shape(x):
""" projection """
"""projection"""
return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3))
def unshape(x):
""" compute context """
"""compute context"""
return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))
q = shape(self.q_lin(input)) # (bs, n_heads, qlen, dim_per_head)
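The `shape`/`unshape` helpers in this hunk are the usual multi-head split/merge round trip. A minimal NumPy sketch of the same reshapes, with made-up dimensions (bs=2, qlen=5, n_heads=4, dim_per_head=8), purely for illustration:

```python
import numpy as np

bs, qlen, n_heads, dim_per_head = 2, 5, 4, 8
x = np.random.randn(bs, qlen, n_heads * dim_per_head)

# "projection": split the hidden dimension into heads -> (bs, n_heads, qlen, dim_per_head)
split = x.reshape(bs, -1, n_heads, dim_per_head).transpose(0, 2, 1, 3)

# "compute context": merge the heads back -> (bs, qlen, n_heads * dim_per_head)
merged = split.transpose(0, 2, 1, 3).reshape(bs, -1, n_heads * dim_per_head)

assert np.allclose(x, merged)  # shape and unshape are exact inverses
```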
......
......@@ -374,7 +374,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
return split_tokens
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
"""Converts a token (str) in an id using the vocab."""
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, index):
......@@ -382,7 +382,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
return self.decoder.get(index, self.unk_token)
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
"""Converts a sequence of tokens (string) in a single string."""
# remove BPE
tokens = [t.replace(" ", "").replace("</w>", " ") for t in tokens]
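For context, the `</w>` replacement in `convert_tokens_to_string` is what turns word-final BPE pieces back into whitespace-separated words. A toy sketch with invented token values (not from any real FSMT vocabulary), showing only this step rather than the full method:

```python
bpe_tokens = ["Hel", "lo</w>", "wor", "ld</w>"]
# strip stray spaces inside pieces, turn the end-of-word marker into a space
pieces = [t.replace(" ", "").replace("</w>", " ") for t in bpe_tokens]
print("".join(pieces).strip())  # Hello world
```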
......
......@@ -188,7 +188,7 @@ class FunnelAttentionStructure(nn.Module):
self.pooling_mult = None
def init_attention_inputs(self, inputs_embeds, attention_mask=None, token_type_ids=None):
""" Returns the attention inputs associated to the inputs of the model. """
"""Returns the attention inputs associated to the inputs of the model."""
# inputs_embeds has shape batch_size x seq_len x d_model
# attention_mask and token_type_ids have shape batch_size x seq_len
self.pooling_mult = 1
......@@ -383,7 +383,7 @@ class FunnelAttentionStructure(nn.Module):
return tensor
def pre_attention_pooling(self, output, attention_inputs):
""" Pool `output` and the proper parts of `attention_inputs` before the attention layer. """
"""Pool `output` and the proper parts of `attention_inputs` before the attention layer."""
position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
if self.config.pool_q_only:
if self.config.attention_type == "factorized":
......@@ -403,7 +403,7 @@ class FunnelAttentionStructure(nn.Module):
return output, attention_inputs
def post_attention_pooling(self, attention_inputs):
""" Pool the proper parts of `attention_inputs` after the attention layer. """
"""Pool the proper parts of `attention_inputs` after the attention layer."""
position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
if self.config.pool_q_only:
self.pooling_mult *= 2
......@@ -457,7 +457,7 @@ class FunnelRelMultiheadAttention(nn.Module):
self.scale = 1.0 / (d_head ** 0.5)
def relative_positional_attention(self, position_embeds, q_head, context_len, cls_mask=None):
""" Relative attention score for the positional encodings """
"""Relative attention score for the positional encodings"""
# q_head has shape batch_size x seq_len x n_head x d_head
if self.config.attention_type == "factorized":
# Notations from the paper, appendix A.2.2, final formula (https://arxiv.org/abs/2006.03236)
......@@ -499,7 +499,7 @@ class FunnelRelMultiheadAttention(nn.Module):
return positional_attn
def relative_token_type_attention(self, token_type_mat, q_head, cls_mask=None):
""" Relative attention score for the token_type_ids """
"""Relative attention score for the token_type_ids"""
if token_type_mat is None:
return 0
batch_size, seq_len, context_len = token_type_mat.shape
......
......@@ -139,7 +139,7 @@ class TFFunnelAttentionStructure:
self.pooling_mult = None
def init_attention_inputs(self, inputs_embeds, attention_mask=None, token_type_ids=None, training=False):
""" Returns the attention inputs associated to the inputs of the model. """
"""Returns the attention inputs associated to the inputs of the model."""
# inputs_embeds has shape batch_size x seq_len x d_model
# attention_mask and token_type_ids have shape batch_size x seq_len
self.pooling_mult = 1
......@@ -328,7 +328,7 @@ class TFFunnelAttentionStructure:
return tf.squeeze(tensor, 2) if ndim == 2 else tensor
def pre_attention_pooling(self, output, attention_inputs):
""" Pool `output` and the proper parts of `attention_inputs` before the attention layer. """
"""Pool `output` and the proper parts of `attention_inputs` before the attention layer."""
position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
if self.pool_q_only:
if self.attention_type == "factorized":
......@@ -348,7 +348,7 @@ class TFFunnelAttentionStructure:
return output, attention_inputs
def post_attention_pooling(self, attention_inputs):
""" Pool the proper parts of `attention_inputs` after the attention layer. """
"""Pool the proper parts of `attention_inputs` after the attention layer."""
position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
if self.pool_q_only:
self.pooling_mult *= 2
......@@ -424,7 +424,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
super().build(input_shape)
def relative_positional_attention(self, position_embeds, q_head, context_len, cls_mask=None):
""" Relative attention score for the positional encodings """
"""Relative attention score for the positional encodings"""
# q_head has shape batch_size x seq_len x n_head x d_head
if self.attention_type == "factorized":
# Notations from the paper, appendix A.2.2, final formula (https://arxiv.org/abs/2006.03236)
......@@ -470,7 +470,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
return positional_attn
def relative_token_type_attention(self, token_type_mat, q_head, cls_mask=None):
""" Relative attention score for the token_type_ids """
"""Relative attention score for the token_type_ids"""
if token_type_mat is None:
return 0
batch_size, seq_len, context_len = shape_list(token_type_mat)
......@@ -723,7 +723,7 @@ class TFFunnelDecoder(tf.keras.layers.Layer):
@keras_serializable
class TFFunnelBaseLayer(tf.keras.layers.Layer):
""" Base model without decoder """
"""Base model without decoder"""
config_class = FunnelConfig
......@@ -807,7 +807,7 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):
@keras_serializable
class TFFunnelMainLayer(tf.keras.layers.Layer):
""" Base model with decoder """
"""Base model with decoder"""
config_class = FunnelConfig
......
......@@ -242,7 +242,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
return word
def _tokenize(self, text):
""" Tokenize a string. """
"""Tokenize a string."""
bpe_tokens = []
for token in re.findall(self.pat, text):
token = "".join(
......@@ -252,7 +252,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
return bpe_tokens
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
"""Converts a token (str) in an id using the vocab."""
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, index):
......@@ -260,7 +260,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
return self.decoder.get(index)
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
"""Converts a sequence of tokens (string) in a single string."""
text = "".join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
return text
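The GPT-2 hunk above joins byte-level BPE tokens and maps each character back to a raw byte before UTF-8 decoding. A self-contained toy version of that decode path; the identity `byte_encoder` below stands in for GPT-2's real `bytes_to_unicode()` table, which remaps unprintable bytes:

```python
byte_encoder = {b: chr(b) for b in range(256)}   # toy mapping, identity on code points
byte_decoder = {c: b for b, c in byte_encoder.items()}

tokens = ["Hel", "lo", " world"]                 # pretend BPE pieces
text = "".join(tokens)
decoded = bytearray(byte_decoder[c] for c in text).decode("utf-8", errors="replace")
print(decoded)  # Hello world
```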
......
......@@ -645,7 +645,7 @@ class IBertPreTrainedModel(PreTrainedModel):
base_model_prefix = "ibert"
def _init_weights(self, module):
""" Initialize the weights """
"""Initialize the weights"""
if isinstance(module, (QuantLinear, nn.Linear)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
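The many `_init_weights` hunks in this commit share the same pattern. A generic PyTorch sketch of that pattern (not the exact IBert implementation; `initializer_range` is the usual config attribute, hard-coded here):

```python
import torch.nn as nn

def init_weights(module, initializer_range=0.02):
    """Generic HF-style weight init: normal_ instead of TF's truncated_normal."""
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=initializer_range)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
```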
......
......@@ -611,7 +611,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
""" Initialize the weights """
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
......
......@@ -509,7 +509,7 @@ class LEDEncoderSelfAttention(nn.Module):
@staticmethod
def _get_global_attn_indices(is_index_global_attn):
""" compute global attn indices required throughout forward pass """
"""compute global attn indices required throughout forward pass"""
# helper variable
num_global_attn_indices = is_index_global_attn.long().sum(dim=1)
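`_get_global_attn_indices` is index bookkeeping: count the global-attention positions per sequence and collect their coordinates. A minimal PyTorch sketch of that bookkeeping with a made-up mask (variable names beyond the one shown above are illustrative, not the exact LED/Longformer internals):

```python
import torch

is_index_global_attn = torch.tensor([[1, 0, 1, 0],
                                     [0, 0, 1, 0]], dtype=torch.bool)

num_global_attn_indices = is_index_global_attn.long().sum(dim=1)   # tensor([2, 1])
max_num_global_attn_indices = int(num_global_attn_indices.max())   # 2
is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True)
print(num_global_attn_indices, max_num_global_attn_indices)
print(is_index_global_attn_nonzero)  # row and column indices of the global tokens
```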
......
......@@ -670,7 +670,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):
@staticmethod
def _get_global_attn_indices(is_index_global_attn):
""" compute global attn indices required throughout forward pass """
"""compute global attn indices required throughout forward pass"""
# helper variable
num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype)
......
......@@ -899,7 +899,7 @@ class LongformerSelfAttention(nn.Module):
@staticmethod
def _get_global_attn_indices(is_index_global_attn):
""" compute global attn indices required throughout forward pass """
"""compute global attn indices required throughout forward pass"""
# helper variable
num_global_attn_indices = is_index_global_attn.long().sum(dim=1)
......@@ -1363,7 +1363,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
""" Initialize the weights """
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
......
......@@ -1189,7 +1189,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):
@staticmethod
def _get_global_attn_indices(is_index_global_attn):
""" compute global attn indices required throughout forward pass """
"""compute global attn indices required throughout forward pass"""
# helper variable
num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype)
......
......@@ -783,7 +783,7 @@ class LxmertPreTrainedModel(PreTrainedModel):
base_model_prefix = "lxmert"
def _init_weights(self, module):
""" Initialize the weights """
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
......
......@@ -227,7 +227,7 @@ class MarianTokenizer(PreTrainedTokenizer):
return super().decode(token_ids, **kwargs)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""Uses source spm if _decode_use_source_tokenizer is True, and target spm otherwise """
"""Uses source spm if _decode_use_source_tokenizer is True, and target spm otherwise"""
if self._decode_use_source_tokenizer:
return self.spm_source.DecodePieces(tokens)
else:
......
......@@ -189,7 +189,7 @@ class MBart50Tokenizer(PreTrainedTokenizer):
return self.sp_model.EncodeAsPieces(text)
def _convert_token_to_id(self, token: str) -> int:
""" Converts a token (str) in an id using the vocab. """
"""Converts a token (str) in an id using the vocab."""
if token in self.fairseq_tokens_to_ids:
return self.fairseq_tokens_to_ids[token]
spm_id = self.sp_model.PieceToId(token)
......
......@@ -708,7 +708,7 @@ class MegatronBertPreTrainedModel(PreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
""" Initialize the weights """
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
......
......@@ -669,7 +669,7 @@ class MobileBertPreTrainedModel(PreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
""" Initialize the weights """
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
......
......@@ -56,7 +56,7 @@ class MPNetPreTrainedModel(PreTrainedModel):
base_model_prefix = "mpnet"
def _init_weights(self, module):
""" Initialize the weights """
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
......
......@@ -430,7 +430,7 @@ class TFMPNetEncoder(tf.keras.layers.Layer):
return ret
def compute_position_bias(self, x, position_ids=None):
""" Compute binned relative position bias """
"""Compute binned relative position bias"""
input_shape = shape_list(x)
qlen, klen = input_shape[1], input_shape[1]
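`compute_position_bias` builds a learned bias from binned relative positions. The bucketing itself is not shown in this hunk; the sketch below only illustrates the first step, the (qlen, klen) matrix of relative offsets that subsequently gets bucketed and embedded:

```python
import numpy as np

qlen = klen = 5
context_position = np.arange(qlen)[:, None]
memory_position = np.arange(klen)[None, :]
relative_position = memory_position - context_position  # shape (qlen, klen), offset key - query
print(relative_position)
```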
......
......@@ -210,7 +210,7 @@ class MPNetTokenizer(PreTrainedTokenizer):
return split_tokens
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
"""Converts a token (str) in an id using the vocab."""
return self.vocab.get(token, self.vocab.get(self.unk_token))
def _convert_id_to_token(self, index):
......@@ -218,7 +218,7 @@ class MPNetTokenizer(PreTrainedTokenizer):
return self.ids_to_tokens.get(index, self.unk_token)
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
"""Converts a sequence of tokens (string) in a single string."""
out_string = " ".join(tokens).replace(" ##", "").strip()
return out_string
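The MPNet detokenizer is the standard WordPiece join: pieces prefixed with `##` are glued to the previous piece. A toy example with invented tokens:

```python
tokens = ["un", "##believ", "##able", "results"]
out_string = " ".join(tokens).replace(" ##", "").strip()
print(out_string)  # unbelievable results
```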
......
......@@ -176,7 +176,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
return word
def _tokenize(self, text):
""" Tokenize a string. """
"""Tokenize a string."""
split_tokens = []
if self.fix_text is None:
# Using BERT's BasicTokenizer
......@@ -191,7 +191,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
return split_tokens
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
"""Converts a token (str) in an id using the vocab."""
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, index):
......@@ -199,7 +199,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
return self.decoder.get(index, self.unk_token)
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
"""Converts a sequence of tokens (string) in a single string."""
out_string = "".join(tokens).replace("</w>", " ").strip()
return out_string
......