Unverified commit 32dbb2d9 authored by Patrick von Platen, committed by GitHub

make style (#11442)

parent 04ab2ca6
@@ -175,7 +175,7 @@ class PegasusTokenizer(PreTrainedTokenizer):
         return pieces
 
     def _convert_token_to_id(self, token: str) -> int:
-        """ Converts a token (str) to an id using the vocab. """
+        """Converts a token (str) to an id using the vocab."""
         if token in self.decoder:
             return self.decoder[token]
         elif token in self.added_tokens_decoder:
@@ -194,7 +194,7 @@ class PegasusTokenizer(PreTrainedTokenizer):
         return token
 
     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         out_string = self.sp_model.decode_pieces(tokens)
         return out_string
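For context, the `_convert_token_to_id` methods restyled in these hunks share one lookup-with-fallback pattern. A minimal sketch with a toy vocabulary (the names `toy_encoder`/`toy_decoder` are illustrative, not the real Pegasus SentencePiece model):

```python
# Toy vocab standing in for the tokenizer's encoder/decoder tables.
toy_encoder = {"<pad>": 0, "<unk>": 1, "hello": 2, "world": 3}
toy_decoder = {i: t for t, i in toy_encoder.items()}

def convert_token_to_id(token: str) -> int:
    # Fall back to the unknown token when the vocab has no entry.
    return toy_encoder.get(token, toy_encoder["<unk>"])

assert convert_token_to_id("hello") == 2
assert convert_token_to_id("goodbye") == toy_encoder["<unk>"]
```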
@@ -295,7 +295,7 @@ class PhobertTokenizer(PreTrainedTokenizer):
         return split_tokens
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.encoder.get(token, self.encoder.get(self.unk_token))
 
     def _convert_id_to_token(self, index):
@@ -303,7 +303,7 @@ class PhobertTokenizer(PreTrainedTokenizer):
         return self.decoder.get(index, self.unk_token)
 
     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         out_string = " ".join(tokens).replace("@@ ", "").strip()
         return out_string
@@ -172,7 +172,7 @@ class ProphetNetTokenizer(PreTrainedTokenizer):
         return split_tokens
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.vocab.get(token, self.vocab.get(self.unk_token))
 
     def _convert_id_to_token(self, index):
@@ -180,7 +180,7 @@ class ProphetNetTokenizer(PreTrainedTokenizer):
         return self.ids_to_tokens.get(index, self.unk_token)
 
     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         out_string = " ".join(tokens).replace(" ##", "").strip()
         return out_string
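The `" ##"` replacement above is WordPiece detokenization: continuation pieces carry a leading `##` and are glued back onto the preceding token. A minimal sketch:

```python
# WordPiece detokenization, as in the ProphetNet/Tapas hunks.
tokens = ["un", "##affable", "person"]
out_string = " ".join(tokens).replace(" ##", "").strip()
assert out_string == "unaffable person"
```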
@@ -1779,7 +1779,7 @@ class ReformerPreTrainedModel(PreTrainedModel):
         return dummy_inputs
 
     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
        if isinstance(module, AxialPositionEmbeddings):
             for weight in module.weights:
                 torch.nn.init.normal_(weight, std=self.config.axial_norm_std)
@@ -115,7 +115,7 @@ class ReformerTokenizer(PreTrainedTokenizer):
         return pieces
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.sp_model.piece_to_id(token)
 
     def _convert_id_to_token(self, index):
@@ -125,7 +125,7 @@ class ReformerTokenizer(PreTrainedTokenizer):
         return token
 
     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         out_string = self.sp_model.decode_pieces(tokens)
         return out_string
@@ -50,7 +50,7 @@ class RetriBertPreTrainedModel(PreTrainedModel):
     base_model_prefix = "retribert"
 
     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, nn.Linear):
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
@@ -574,7 +574,7 @@ class RobertaPreTrainedModel(PreTrainedModel):
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
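The many `_init_weights` hunks in this commit all restyle the same BERT-style initializer docstring. A minimal, self-contained sketch of that initializer pattern (the default `initializer_range=0.02` is assumed here for illustration):

```python
import torch.nn as nn

# Linear weights drawn from N(0, initializer_range); biases zeroed.
initializer_range = 0.02
module = nn.Linear(16, 16)
module.weight.data.normal_(mean=0.0, std=initializer_range)
if module.bias is not None:
    module.bias.data.zero_()
```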
@@ -431,7 +431,7 @@ class SqueezeBertPreTrainedModel(PreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv1d)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
@@ -389,7 +389,7 @@ class T5Attention(nn.Module):
         return relative_buckets
 
     def compute_bias(self, query_length, key_length):
-        """ Compute binned relative position bias """
+        """Compute binned relative position bias"""
         context_position = torch.arange(query_length, dtype=torch.long)[:, None]
         memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
         relative_position = memory_position - context_position  # shape (query_length, key_length)
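A sketch of the relative position matrix that `compute_bias` bins into buckets, using the same broadcasting as the hunk above: entry `[i, j]` is `j - i` (small lengths chosen for illustration):

```python
import torch

query_length, key_length = 3, 4
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = memory_position - context_position  # shape (query_length, key_length)
print(relative_position)
# tensor([[ 0,  1,  2,  3],
#         [-1,  0,  1,  2],
#         [-2, -1,  0,  1]])
```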
@@ -436,15 +436,15 @@ class T5Attention(nn.Module):
         key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
 
         def shape(states):
-            """ projection """
+            """projection"""
             return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
 
         def unshape(states):
-            """ reshape """
+            """reshape"""
             return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
 
         def project(hidden_states, proj_layer, key_value_states, past_key_value):
-            """ projects hidden states correctly to key/query states """
+            """projects hidden states correctly to key/query states"""
             if key_value_states is None:
                 # self-attn
                 # (batch_size, n_heads, seq_length, dim_per_head)
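The `shape`/`unshape` pair in this hunk splits the hidden dimension into `(n_heads, dim_per_head)` for attention and then merges it back. A sketch of the round trip with illustrative shapes (the real values come from the model config):

```python
import torch

batch_size, seq_length, n_heads, dim_per_head = 2, 5, 4, 8
inner_dim = n_heads * dim_per_head
states = torch.randn(batch_size, seq_length, inner_dim)

# shape: (batch, seq, inner) -> (batch, heads, seq, dim_per_head)
shaped = states.view(batch_size, -1, n_heads, dim_per_head).transpose(1, 2)
assert shaped.shape == (batch_size, n_heads, seq_length, dim_per_head)

# unshape: the inverse; the round trip restores the original tensor.
unshaped = shaped.transpose(1, 2).contiguous().view(batch_size, -1, inner_dim)
assert torch.equal(unshaped, states)
```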
@@ -718,7 +718,7 @@ class T5PreTrainedModel(PreTrainedModel):
         return dummy_inputs
 
     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         factor = self.config.initializer_factor  # Used for testing weights initialization
         if isinstance(module, T5LayerNorm):
             module.weight.data.fill_(factor * 1.0)
@@ -80,7 +80,7 @@ class TFT5LayerNorm(tf.keras.layers.Layer):
         self.variance_epsilon = epsilon
 
     def build(self, input_shape):
-        """Build shared word embedding layer """
+        """Build shared word embedding layer"""
         self.weight = self.add_weight("weight", shape=(input_shape[-1],), initializer="ones")
         super().build(input_shape)
@@ -230,7 +230,7 @@ class TFT5Attention(tf.keras.layers.Layer):
         return relative_buckets
 
     def compute_bias(self, query_length, key_length):
-        """ Compute binned relative position bias """
+        """Compute binned relative position bias"""
         context_position = tf.range(query_length)[:, None]
         memory_position = tf.range(key_length)[None, :]
         relative_position = memory_position - context_position  # shape (query_length, key_length)
@@ -279,17 +279,17 @@ class TFT5Attention(tf.keras.layers.Layer):
         key_length = real_seq_length if key_value_states is None else shape_list(key_value_states)[1]
 
         def shape(hidden_states):
-            """ projection """
+            """projection"""
             return tf.transpose(
                 tf.reshape(hidden_states, (batch_size, -1, self.n_heads, self.key_value_proj_dim)), perm=(0, 2, 1, 3)
             )
 
         def unshape(hidden_states):
-            """ compute context """
+            """compute context"""
             return tf.reshape(tf.transpose(hidden_states, perm=(0, 2, 1, 3)), (batch_size, -1, self.inner_dim))
 
         def project(hidden_states, proj_layer, key_value_states, past_key_value):
-            """ projects hidden states correctly to key/query states """
+            """projects hidden states correctly to key/query states"""
             if key_value_states is None:
                 # self-attn
                 # (batch_size, n_heads, seq_length, dim_per_head)
@@ -243,7 +243,7 @@ class T5Tokenizer(PreTrainedTokenizer):
         return pieces
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         if token.startswith("<extra_id_"):
             match = re.match(r"<extra_id_(\d+)>", token)
             num = int(match.group(1))
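The sentinel handling in this hunk maps T5's `<extra_id_N>` tokens to the top of the vocabulary, so the id is derived from `N` rather than looked up. A sketch using T5's default vocabulary size of 32100 (32000 SentencePiece pieces plus 100 extra ids):

```python
import re

vocab_size = 32100  # T5 default, including the 100 extra ids
token = "<extra_id_7>"
match = re.match(r"<extra_id_(\d+)>", token)
num = int(match.group(1))
token_id = vocab_size - num - 1  # sentinels occupy the highest ids
assert token_id == 32092
```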
@@ -259,7 +259,7 @@ class T5Tokenizer(PreTrainedTokenizer):
         return token
 
     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         current_sub_tokens = []
         out_string = ""
         for token in tokens:
@@ -699,7 +699,7 @@ class TapasPreTrainedModel(PreTrainedModel):
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
@@ -374,7 +374,7 @@ class TapasTokenizer(PreTrainedTokenizer):
         return split_tokens
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.vocab.get(token, self.vocab.get(self.unk_token))
 
     def _convert_id_to_token(self, index):
@@ -382,7 +382,7 @@ class TapasTokenizer(PreTrainedTokenizer):
         return self.ids_to_tokens.get(index, self.unk_token)
 
     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         out_string = " ".join(tokens).replace(" ##", "").strip()
         return out_string
@@ -434,7 +434,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         return self.idx2sym[idx]
 
     def _convert_token_to_id(self, sym):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         if sym in self.sym2idx:
             return self.sym2idx[sym]
         else:
@@ -372,7 +372,7 @@ class ViTPreTrainedModel(PreTrainedModel):
     base_model_prefix = "vit"
 
     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
@@ -680,7 +680,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
@@ -151,11 +151,11 @@ class TFXLMMultiHeadAttention(tf.keras.layers.Layer):
         mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen)
 
         def shape(x):
-            """ projection """
+            """projection"""
             return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3))
 
         def unshape(x):
-            """ compute context """
+            """compute context"""
             return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))
 
         q = shape(self.q_lin(input))  # (bs, n_heads, qlen, dim_per_head)
@@ -159,11 +159,11 @@ class MultiHeadAttention(nn.Module):
         mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)
 
         def shape(x):
-            """ projection """
+            """projection"""
             return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
 
         def unshape(x):
-            """ compute context """
+            """compute context"""
             return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
 
         q = shape(self.q_lin(input))  # (bs, n_heads, qlen, dim_per_head)
@@ -251,7 +251,7 @@ class XLMPreTrainedModel(PreTrainedModel):
         return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
 
     def _init_weights(self, module):
-        """ Initialize the weights. """
+        """Initialize the weights."""
         if isinstance(module, nn.Embedding):
             if self.config is not None and self.config.embed_init_std is not None:
                 nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std)
@@ -847,7 +847,7 @@ class XLMTokenizer(PreTrainedTokenizer):
         return split_tokens
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.encoder.get(token, self.encoder.get(self.unk_token))
 
     def _convert_id_to_token(self, index):
@@ -855,7 +855,7 @@ class XLMTokenizer(PreTrainedTokenizer):
         return self.decoder.get(index, self.unk_token)
 
     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         out_string = "".join(tokens).replace("</w>", " ").strip()
         return out_string
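The `"</w>"` replacement in this hunk is XLM-style BPE detokenization: `</w>` marks the end of a word, so joining the pieces and mapping `</w>` to a space recovers the text. A minimal sketch:

```python
# BPE detokenization with end-of-word markers, as in the XLM hunk above.
tokens = ["hel", "lo</w>", "wor", "ld</w>"]
out_string = "".join(tokens).replace("</w>", " ").strip()
assert out_string == "hello world"
```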
@@ -245,7 +245,7 @@ class XLMProphetNetTokenizer(PreTrainedTokenizer):
         return self.sp_model.EncodeAsPieces(text)
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         if token in self.fairseq_tokens_to_ids:
             return self.fairseq_tokens_to_ids[token]
         spm_id = self.sp_model.PieceToId(token)