Unverified commit 32dbb2d9, authored by Patrick von Platen, committed by GitHub

make style (#11442)

parent 04ab2ca6
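
Note: this commit is a mechanical `make style` pass. Across the model and tokenizer files it normalizes one-line docstrings by dropping the stray spaces just inside the triple quotes, which is the form the repository's style tooling expects. A minimal before/after sketch of the pattern (the function name here is hypothetical, chosen only to mirror the diff below):

```python
# Before: stray spaces just inside the triple quotes.
def convert_token(token):
    """ Converts a token (str) in an id using the vocab. """
    ...

# After: the docstring text hugs the quotes.
def convert_token(token):
    """Converts a token (str) in an id using the vocab."""
    ...
```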
@@ -71,7 +71,7 @@ ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [

 def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
-    """ Load tf checkpoints in a pytorch model."""
+    """Load tf checkpoints in a pytorch model."""
     try:
         import re
@@ -189,7 +189,7 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):

 class TFAlbertAttention(tf.keras.layers.Layer):
-    """ Contains the complete attention sublayer, including both dropouts and layer norm. """
+    """Contains the complete attention sublayer, including both dropouts and layer norm."""

     def __init__(self, config: AlbertConfig, **kwargs):
         super().__init__(**kwargs)
@@ -187,7 +187,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
         return outputs

     def _tokenize(self, text, sample=False):
-        """ Tokenize a string. """
+        """Tokenize a string."""
         text = self.preprocess_text(text)

         if not sample:
@@ -211,7 +211,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
         return new_pieces

     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.sp_model.PieceToId(token)

     def _convert_id_to_token(self, index):
@@ -223,7 +223,7 @@ class BarthezTokenizer(PreTrainedTokenizer):
         return self.sp_model.EncodeAsPieces(text)

     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         if token in self.fairseq_tokens_to_ids:
             return self.fairseq_tokens_to_ids[token]
         spm_id = self.sp_model.PieceToId(token)
@@ -703,7 +703,7 @@ class BertPreTrainedModel(PreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
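
The `_init_weights` hunks in this diff all truncate the body right after the `nn.Linear` check. For context, here is a minimal sketch of the normal-distribution initialization the surrounding comment refers to; the body shown and the `initializer_range` value of 0.02 are assumptions (the usual config default), not part of this diff:

```python
import torch.nn as nn

initializer_range = 0.02  # assumed config default; not shown in this diff

module = nn.Linear(768, 768)
# The comment above notes this differs from TF, which uses truncated_normal.
module.weight.data.normal_(mean=0.0, std=initializer_range)
if module.bias is not None:
    module.bias.data.zero_()
```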
@@ -233,7 +233,7 @@ class BertTokenizer(PreTrainedTokenizer):
         return split_tokens

     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.vocab.get(token, self.vocab.get(self.unk_token))

     def _convert_id_to_token(self, index):
@@ -241,7 +241,7 @@ class BertTokenizer(PreTrainedTokenizer):
         return self.ids_to_tokens.get(index, self.unk_token)

     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         out_string = " ".join(tokens).replace(" ##", "").strip()
         return out_string
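
As a quick illustration of the `convert_tokens_to_string` logic above: WordPiece marks subword continuations with a leading `##`, so joining on spaces and deleting `" ##"` restores the original words. The sample tokens are hypothetical:

```python
# Hypothetical WordPiece tokens for "unaffable fellow".
tokens = ["un", "##affable", "fellow"]
out_string = " ".join(tokens).replace(" ##", "").strip()
assert out_string == "unaffable fellow"
```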
@@ -177,7 +177,7 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
@@ -119,7 +119,7 @@ class BertGenerationTokenizer(PreTrainedTokenizer):
         return pieces

     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.sp_model.piece_to_id(token)

     def _convert_id_to_token(self, index):
@@ -128,7 +128,7 @@ class BertGenerationTokenizer(PreTrainedTokenizer):
         return token

     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         out_string = self.sp_model.decode_pieces(tokens)
         return out_string
@@ -368,7 +368,7 @@ class BertweetTokenizer(PreTrainedTokenizer):
         return token

     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.encoder.get(token, self.encoder.get(self.unk_token))

     def _convert_id_to_token(self, index):
@@ -376,7 +376,7 @@ class BertweetTokenizer(PreTrainedTokenizer):
         return self.decoder.get(index, self.unk_token)

     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         out_string = " ".join(tokens).replace("@@ ", "").strip()
         return out_string
@@ -484,7 +484,7 @@ class BigBirdBlockSparseAttention(nn.Module):

     @staticmethod
     def torch_bmm_nd(inp_1, inp_2, ndim=None):
-        """ Fast nd matrix multiplication """
+        """Fast nd matrix multiplication"""
         # faster replacement of torch.einsum ("bhqk,bhkd->bhqd")
         return torch.bmm(inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:])).view(
             inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 1])
@@ -492,7 +492,7 @@ class BigBirdBlockSparseAttention(tf.keras.layers.Layer):

     @staticmethod
     def torch_bmm_nd_transpose(inp_1, inp_2, ndim=None):
-        """ Fast nd matrix multiplication with transpose """
+        """Fast nd matrix multiplication with transpose"""
         # faster replacement of torch.einsum (bhqd,bhkd->bhqk)
         return torch.bmm(
             inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:]).transpose(1, 2)
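
The two BigBird helpers above flatten every leading batch/head dimension into one so that `torch.bmm` can stand in for `torch.einsum`. A small self-contained check of that trick, with the 4-d shapes chosen arbitrarily for illustration:

```python
import torch

inp_1 = torch.randn(2, 4, 8, 16)   # (b, h, q, k)
inp_2 = torch.randn(2, 4, 16, 32)  # (b, h, k, d)
ndim = 4

# Same trick as torch_bmm_nd: collapse (b, h) into one batch dim, bmm, restore.
out = torch.bmm(
    inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:])
).view(inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 1]))

assert torch.allclose(out, torch.einsum("bhqk,bhkd->bhqd", inp_1, inp_2), atol=1e-5)
```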
@@ -1743,7 +1743,7 @@ class BigBirdPreTrainedModel(PreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
@@ -149,7 +149,7 @@ class BigBirdTokenizer(PreTrainedTokenizer):
         return pieces

     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.sp_model.piece_to_id(token)

     def _convert_id_to_token(self, index):
@@ -158,7 +158,7 @@ class BigBirdTokenizer(PreTrainedTokenizer):
         return token

     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         out_string = self.sp_model.decode_pieces(tokens)
         return out_string
@@ -183,7 +183,7 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer):
         return " ".join(words)

     def _tokenize(self, text: str) -> List[str]:
-        """ Split a string into tokens using BPE."""
+        """Split a string into tokens using BPE."""
         split_tokens = []

         words = re.findall(r"\S+\n?", text)
@@ -193,7 +193,7 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer):
         return split_tokens

     def _convert_token_to_id(self, token: str) -> int:
-        """ Converts a token to an id using the vocab. """
+        """Converts a token to an id using the vocab."""
         token = token.lower()
         return self.encoder.get(token, self.encoder.get(self.unk_token))
@@ -202,7 +202,7 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer):
         return self.decoder.get(index, self.unk_token)

     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """ Converts a sequence of tokens in a single string. """
+        """Converts a sequence of tokens in a single string."""
         out_string = " ".join(tokens).replace("@@ ", "").strip()
         return out_string
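
The `convert_tokens_to_string` variants above undo BPE segmentation: a trailing `@@` means the token glues onto the next one, so replacing `"@@ "` after the join merges the pieces back into words. The sample tokens are hypothetical:

```python
# Hypothetical BPE tokens for "hello there".
tokens = ["hel@@", "lo", "there"]
out_string = " ".join(tokens).replace("@@ ", "").strip()
assert out_string == "hello there"
```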
@@ -222,7 +222,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
         return self.sp_model.EncodeAsPieces(text)

     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         if token in self.fairseq_tokens_to_ids:
             return self.fairseq_tokens_to_ids[token]
         elif self.sp_model.PieceToId(token) == 0:
@@ -238,7 +238,7 @@ class ConvBertPreTrainedModel(PreTrainedModel):
     authorized_unexpected_keys = [r"convbert\.embeddings_project\.weight", r"convbert\.embeddings_project\.bias"]

     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
@@ -212,7 +212,7 @@ class CTRLTokenizer(PreTrainedTokenizer):
         return split_tokens

     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.encoder.get(token, self.encoder.get(self.unk_token))

     def _convert_id_to_token(self, index):
@@ -220,7 +220,7 @@ class CTRLTokenizer(PreTrainedTokenizer):
         return self.decoder.get(index, self.unk_token)

     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         out_string = " ".join(tokens).replace("@@ ", "").strip()
         return out_string
@@ -134,7 +134,7 @@ class DebertaV2Tokenizer(PreTrainedTokenizer):
         return self._tokenizer.tokenize(text)

     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self._tokenizer.spm.PieceToId(token)

     def _convert_id_to_token(self, index):
@@ -142,7 +142,7 @@ class DebertaV2Tokenizer(PreTrainedTokenizer):
         return self._tokenizer.spm.IdToPiece(index) if index < self.vocab_size else self.unk_token

     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         return self._tokenizer.decode(tokens)

     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
@@ -386,7 +386,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
     base_model_prefix = "deit"

     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
@@ -167,11 +167,11 @@ class MultiHeadSelfAttention(nn.Module):
         mask_reshp = (bs, 1, 1, k_length)

         def shape(x):
-            """ separate heads """
+            """separate heads"""
             return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

         def unshape(x):
-            """ group heads """
+            """group heads"""
             return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

         q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
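
`shape`/`unshape` above are the usual multi-head reshaping pair: split the hidden dimension into `(n_heads, dim_per_head)` and move heads ahead of sequence length, then invert the operation. A shape-level sketch with hypothetical sizes:

```python
import torch

bs, q_length, n_heads, dim_per_head = 2, 5, 12, 64
x = torch.randn(bs, q_length, n_heads * dim_per_head)

# "separate heads": (bs, q_length, dim) -> (bs, n_heads, q_length, dim_per_head)
split = x.view(bs, -1, n_heads, dim_per_head).transpose(1, 2)
assert split.shape == (bs, n_heads, q_length, dim_per_head)

# "group heads": invert the split exactly.
merged = split.transpose(1, 2).contiguous().view(bs, -1, n_heads * dim_per_head)
assert torch.equal(merged, x)  # unshape(shape(x)) round-trips
```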
@@ -175,11 +175,11 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
         mask_reshape = [bs, 1, 1, k_length]

         def shape(x):
-            """ separate heads """
+            """separate heads"""
             return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3))

         def unshape(x):
-            """ group heads """
+            """group heads"""
             return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))

         q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
@@ -653,7 +653,7 @@ class ElectraPreTrainedModel(PreTrainedModel):

     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617