Unverified commit 32dbb2d9 authored by Patrick von Platen, committed by GitHub

make style (#11442)

parent 04ab2ca6
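
The change is mechanical: the updated style check strips the padding spaces that some docstrings carried just inside their triple quotes, across every file touched below. A minimal before/after illustration (the function name `convert` is a hypothetical stand-in; the docstring text is taken from the diff itself):

    # Before: padded docstring delimiters.
    def convert(token):
        """ Converts a token (str) in an id using the vocab. """
        ...

    # After `make style`: no space after the opening or before the closing quotes.
    def convert(token):
        """Converts a token (str) in an id using the vocab."""
        ...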
@@ -71,7 +71,7 @@ ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
 def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
-    """ Load tf checkpoints in a pytorch model."""
+    """Load tf checkpoints in a pytorch model."""
     try:
         import re
...
@@ -189,7 +189,7 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
 class TFAlbertAttention(tf.keras.layers.Layer):
-    """ Contains the complete attention sublayer, including both dropouts and layer norm. """
+    """Contains the complete attention sublayer, including both dropouts and layer norm."""
 
     def __init__(self, config: AlbertConfig, **kwargs):
         super().__init__(**kwargs)
...
@@ -187,7 +187,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
         return outputs
 
     def _tokenize(self, text, sample=False):
-        """ Tokenize a string. """
+        """Tokenize a string."""
         text = self.preprocess_text(text)
 
         if not sample:
@@ -211,7 +211,7 @@ class AlbertTokenizer(PreTrainedTokenizer):
         return new_pieces
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.sp_model.PieceToId(token)
 
     def _convert_id_to_token(self, index):
...
@@ -223,7 +223,7 @@ class BarthezTokenizer(PreTrainedTokenizer):
         return self.sp_model.EncodeAsPieces(text)
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
        if token in self.fairseq_tokens_to_ids:
             return self.fairseq_tokens_to_ids[token]
         spm_id = self.sp_model.PieceToId(token)
...
@@ -703,7 +703,7 @@ class BertPreTrainedModel(PreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
...
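
The `_init_weights` hunks in this commit all carry the same note: PyTorch initializes with a plain normal distribution where the original TF code used `truncated_normal` (see pytorch/pytorch#5617). The diff is truncated before the actual assignment; a minimal sketch of what such an initializer typically does in this codebase (a hedged reconstruction, not part of the diff):

    import torch.nn as nn

    def _init_weights_sketch(module, initializer_range=0.02):
        # Plain normal init; TF's truncated_normal would instead resample
        # values falling beyond two standard deviations.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()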
@@ -233,7 +233,7 @@ class BertTokenizer(PreTrainedTokenizer):
         return split_tokens
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.vocab.get(token, self.vocab.get(self.unk_token))
 
     def _convert_id_to_token(self, index):
@@ -241,7 +241,7 @@ class BertTokenizer(PreTrainedTokenizer):
         return self.ids_to_tokens.get(index, self.unk_token)
 
     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         out_string = " ".join(tokens).replace(" ##", "").strip()
         return out_string
...
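
`BertTokenizer.convert_tokens_to_string` above rebuilds text from WordPiece tokens by joining on spaces and deleting the " ##" continuation markers. A self-contained sketch of that round trip (the token list is a made-up example):

    # WordPiece marks word-internal pieces with a leading "##";
    # joining on spaces and removing " ##" glues the pieces back together.
    tokens = ["un", "##afford", "##able", "housing"]
    out_string = " ".join(tokens).replace(" ##", "").strip()
    print(out_string)  # -> "unaffordable housing"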
@@ -177,7 +177,7 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
...
@@ -119,7 +119,7 @@ class BertGenerationTokenizer(PreTrainedTokenizer):
         return pieces
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.sp_model.piece_to_id(token)
 
     def _convert_id_to_token(self, index):
@@ -128,7 +128,7 @@ class BertGenerationTokenizer(PreTrainedTokenizer):
         return token
 
     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         out_string = self.sp_model.decode_pieces(tokens)
         return out_string
...
@@ -368,7 +368,7 @@ class BertweetTokenizer(PreTrainedTokenizer):
         return token
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.encoder.get(token, self.encoder.get(self.unk_token))
 
     def _convert_id_to_token(self, index):
@@ -376,7 +376,7 @@ class BertweetTokenizer(PreTrainedTokenizer):
         return self.decoder.get(index, self.unk_token)
 
     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         out_string = " ".join(tokens).replace("@@ ", "").strip()
         return out_string
...
@@ -484,7 +484,7 @@ class BigBirdBlockSparseAttention(nn.Module):
     @staticmethod
     def torch_bmm_nd(inp_1, inp_2, ndim=None):
-        """ Fast nd matrix multiplication """
+        """Fast nd matrix multiplication"""
         # faster replacement of torch.einsum ("bhqk,bhkd->bhqd")
         return torch.bmm(inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:])).view(
             inp_1.shape[: ndim - 2] + (inp_1.shape[ndim - 2], inp_2.shape[ndim - 1])
@@ -492,7 +492,7 @@ class BigBirdBlockSparseAttention(nn.Module):
     @staticmethod
     def torch_bmm_nd_transpose(inp_1, inp_2, ndim=None):
-        """ Fast nd matrix multiplication with transpose """
+        """Fast nd matrix multiplication with transpose"""
         # faster replacement of torch.einsum (bhqd,bhkd->bhqk)
         return torch.bmm(
             inp_1.reshape((-1,) + inp_1.shape[-2:]), inp_2.reshape((-1,) + inp_2.shape[-2:]).transpose(1, 2)
@@ -1743,7 +1743,7 @@ class BigBirdPreTrainedModel(PreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
...
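
The `torch_bmm_nd` hunks above describe batched `torch.bmm` over flattened leading dimensions as a faster replacement for `torch.einsum("bhqk,bhkd->bhqd")`. A minimal sketch verifying that equivalence on random tensors (all shapes are hypothetical):

    import torch

    # Hypothetical attention shapes: batch, heads, query len, key len, head dim.
    b, h, q, k, d = 2, 4, 8, 16, 32
    scores = torch.randn(b, h, q, k)
    values = torch.randn(b, h, k, d)

    # The einsum formulation named in the diff's comment.
    out_einsum = torch.einsum("bhqk,bhkd->bhqd", scores, values)

    # The bmm formulation used by torch_bmm_nd: flatten the leading
    # (batch, heads) dims, multiply, then restore the original shape.
    ndim = 4
    out_bmm = torch.bmm(
        scores.reshape((-1,) + scores.shape[-2:]), values.reshape((-1,) + values.shape[-2:])
    ).view(scores.shape[: ndim - 2] + (scores.shape[ndim - 2], values.shape[ndim - 1]))

    assert torch.allclose(out_einsum, out_bmm, atol=1e-5)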
@@ -149,7 +149,7 @@ class BigBirdTokenizer(PreTrainedTokenizer):
         return pieces
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.sp_model.piece_to_id(token)
 
     def _convert_id_to_token(self, index):
@@ -158,7 +158,7 @@ class BigBirdTokenizer(PreTrainedTokenizer):
         return token
 
     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         out_string = self.sp_model.decode_pieces(tokens)
         return out_string
...
@@ -183,7 +183,7 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer):
         return " ".join(words)
 
     def _tokenize(self, text: str) -> List[str]:
-        """ Split a string into tokens using BPE."""
+        """Split a string into tokens using BPE."""
         split_tokens = []
         words = re.findall(r"\S+\n?", text)
@@ -193,7 +193,7 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer):
         return split_tokens
 
     def _convert_token_to_id(self, token: str) -> int:
-        """ Converts a token to an id using the vocab. """
+        """Converts a token to an id using the vocab."""
         token = token.lower()
         return self.encoder.get(token, self.encoder.get(self.unk_token))
 
@@ -202,7 +202,7 @@ class BlenderbotSmallTokenizer(PreTrainedTokenizer):
         return self.decoder.get(index, self.unk_token)
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """ Converts a sequence of tokens in a single string. """
+        """Converts a sequence of tokens in a single string."""
         out_string = " ".join(tokens).replace("@@ ", "").strip()
         return out_string
...
@@ -222,7 +222,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
         return self.sp_model.EncodeAsPieces(text)
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         if token in self.fairseq_tokens_to_ids:
             return self.fairseq_tokens_to_ids[token]
         elif self.sp_model.PieceToId(token) == 0:
...
@@ -238,7 +238,7 @@ class ConvBertPreTrainedModel(PreTrainedModel):
     authorized_unexpected_keys = [r"convbert\.embeddings_project\.weight", r"convbert\.embeddings_project\.bias"]
 
     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
...
@@ -212,7 +212,7 @@ class CTRLTokenizer(PreTrainedTokenizer):
         return split_tokens
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.encoder.get(token, self.encoder.get(self.unk_token))
 
     def _convert_id_to_token(self, index):
@@ -220,7 +220,7 @@ class CTRLTokenizer(PreTrainedTokenizer):
         return self.decoder.get(index, self.unk_token)
 
     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         out_string = " ".join(tokens).replace("@@ ", "").strip()
         return out_string
...
@@ -134,7 +134,7 @@ class DebertaV2Tokenizer(PreTrainedTokenizer):
         return self._tokenizer.tokenize(text)
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self._tokenizer.spm.PieceToId(token)
 
     def _convert_id_to_token(self, index):
@@ -142,7 +142,7 @@ class DebertaV2Tokenizer(PreTrainedTokenizer):
         return self._tokenizer.spm.IdToPiece(index) if index < self.vocab_size else self.unk_token
 
     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
         return self._tokenizer.decode(tokens)
 
     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
...
@@ -386,7 +386,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
     base_model_prefix = "deit"
 
     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
...
@@ -167,11 +167,11 @@ class MultiHeadSelfAttention(nn.Module):
         mask_reshp = (bs, 1, 1, k_length)
 
         def shape(x):
-            """ separate heads """
+            """separate heads"""
             return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
 
         def unshape(x):
-            """ group heads """
+            """group heads"""
             return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
 
         q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
...
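
The nested `shape`/`unshape` helpers above split the hidden dimension into attention heads and merge it back; they are exact inverses of each other. A self-contained sketch of the reshaping (tensor sizes are hypothetical):

    import torch

    # Hypothetical sizes: batch, sequence length, heads, per-head dim.
    bs, seq_len, n_heads, dim_per_head = 2, 10, 4, 16
    x = torch.randn(bs, seq_len, n_heads * dim_per_head)

    # "separate heads": (bs, seq, n_heads * d) -> (bs, n_heads, seq, d)
    heads = x.view(bs, -1, n_heads, dim_per_head).transpose(1, 2)

    # "group heads": back to (bs, seq, n_heads * d)
    merged = heads.transpose(1, 2).contiguous().view(bs, -1, n_heads * dim_per_head)

    assert torch.equal(x, merged)  # round trip recovers the input exactly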
@@ -175,11 +175,11 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
         mask_reshape = [bs, 1, 1, k_length]
 
         def shape(x):
-            """ separate heads """
+            """separate heads"""
             return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3))
 
         def unshape(x):
-            """ group heads """
+            """group heads"""
             return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))
 
         q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
...
@@ -653,7 +653,7 @@ class ElectraPreTrainedModel(PreTrainedModel):
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, nn.Linear):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
...