Unverified commit 32dbb2d9, authored by Patrick von Platen and committed by GitHub

make style (#11442)

parent 04ab2ca6
@@ -342,11 +342,11 @@ class TFFlaubertMultiHeadAttention(tf.keras.layers.Layer):
        mask_reshape = (bs, 1, qlen, klen) if len(shape_list(mask)) == 3 else (bs, 1, 1, klen)

        def shape(x):
-            """ projection """
+            """projection"""
            return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3))

        def unshape(x):
-            """ compute context """
+            """compute context"""
            return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))

        q = shape(self.q_lin(input))  # (bs, n_heads, qlen, dim_per_head)
...
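The hunk above only reformats the docstrings of the nested `shape`/`unshape` helpers; their behavior, splitting the hidden dimension into attention heads and merging it back, is unchanged. As a quick orientation aid, here is a minimal standalone NumPy sketch of that reshape-and-transpose round trip with made-up toy sizes (it mirrors the pattern, it is not the Flaubert code itself):

import numpy as np

bs, seq_len, n_heads, dim_per_head = 2, 5, 4, 8

def shape(x):
    # (bs, seq_len, n_heads * dim_per_head) -> (bs, n_heads, seq_len, dim_per_head)
    return np.transpose(x.reshape(bs, -1, n_heads, dim_per_head), (0, 2, 1, 3))

def unshape(x):
    # (bs, n_heads, seq_len, dim_per_head) -> (bs, seq_len, n_heads * dim_per_head)
    return np.transpose(x, (0, 2, 1, 3)).reshape(bs, -1, n_heads * dim_per_head)

hidden = np.random.rand(bs, seq_len, n_heads * dim_per_head)
assert unshape(shape(hidden)).shape == hidden.shape
assert np.allclose(unshape(shape(hidden)), hidden)  # the round trip is lossless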
@@ -374,7 +374,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
        return split_tokens

    def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
@@ -382,7 +382,7 @@ class FSMTTokenizer(PreTrainedTokenizer):
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
        # remove BPE
        tokens = [t.replace(" ", "").replace("</w>", " ") for t in tokens]
...
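For readers new to this tokenizer: the `</w>` marker in the hunk above is the BPE end-of-word symbol, and `convert_tokens_to_string` turns it back into spaces. A tiny standalone sketch of that list comprehension on invented tokens (not FSMT's real vocabulary):

# toy BPE tokens; "</w>" marks the end of a word
tokens = ["Mach", "ine</w>", "learn", "ing</w>"]
tokens = [t.replace(" ", "").replace("</w>", " ") for t in tokens]
text = "".join(tokens).strip()
print(text)  # -> "Machine learning"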
@@ -188,7 +188,7 @@ class FunnelAttentionStructure(nn.Module):
        self.pooling_mult = None

    def init_attention_inputs(self, inputs_embeds, attention_mask=None, token_type_ids=None):
-        """ Returns the attention inputs associated to the inputs of the model. """
+        """Returns the attention inputs associated to the inputs of the model."""
        # inputs_embeds has shape batch_size x seq_len x d_model
        # attention_mask and token_type_ids have shape batch_size x seq_len
        self.pooling_mult = 1
@@ -383,7 +383,7 @@ class FunnelAttentionStructure(nn.Module):
        return tensor

    def pre_attention_pooling(self, output, attention_inputs):
-        """ Pool `output` and the proper parts of `attention_inputs` before the attention layer. """
+        """Pool `output` and the proper parts of `attention_inputs` before the attention layer."""
        position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
        if self.config.pool_q_only:
            if self.config.attention_type == "factorized":
@@ -403,7 +403,7 @@ class FunnelAttentionStructure(nn.Module):
        return output, attention_inputs

    def post_attention_pooling(self, attention_inputs):
-        """ Pool the proper parts of `attention_inputs` after the attention layer. """
+        """Pool the proper parts of `attention_inputs` after the attention layer."""
        position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
        if self.config.pool_q_only:
            self.pooling_mult *= 2
@@ -457,7 +457,7 @@ class FunnelRelMultiheadAttention(nn.Module):
        self.scale = 1.0 / (d_head ** 0.5)

    def relative_positional_attention(self, position_embeds, q_head, context_len, cls_mask=None):
-        """ Relative attention score for the positional encodings """
+        """Relative attention score for the positional encodings"""
        # q_head has shape batch_size x sea_len x n_head x d_head
        if self.config.attention_type == "factorized":
            # Notations from the paper, appending A.2.2, final formula (https://arxiv.org/abs/2006.03236)
@@ -499,7 +499,7 @@ class FunnelRelMultiheadAttention(nn.Module):
        return positional_attn

    def relative_token_type_attention(self, token_type_mat, q_head, cls_mask=None):
-        """ Relative attention score for the token_type_ids """
+        """Relative attention score for the token_type_ids"""
        if token_type_mat is None:
            return 0
        batch_size, seq_len, context_len = token_type_mat.shape
...
@@ -139,7 +139,7 @@ class TFFunnelAttentionStructure:
        self.pooling_mult = None

    def init_attention_inputs(self, inputs_embeds, attention_mask=None, token_type_ids=None, training=False):
-        """ Returns the attention inputs associated to the inputs of the model. """
+        """Returns the attention inputs associated to the inputs of the model."""
        # inputs_embeds has shape batch_size x seq_len x d_model
        # attention_mask and token_type_ids have shape batch_size x seq_len
        self.pooling_mult = 1
@@ -328,7 +328,7 @@ class TFFunnelAttentionStructure:
        return tf.squeeze(tensor, 2) if ndim == 2 else tensor

    def pre_attention_pooling(self, output, attention_inputs):
-        """ Pool `output` and the proper parts of `attention_inputs` before the attention layer. """
+        """Pool `output` and the proper parts of `attention_inputs` before the attention layer."""
        position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
        if self.pool_q_only:
            if self.attention_type == "factorized":
@@ -348,7 +348,7 @@ class TFFunnelAttentionStructure:
        return output, attention_inputs

    def post_attention_pooling(self, attention_inputs):
-        """ Pool the proper parts of `attention_inputs` after the attention layer. """
+        """Pool the proper parts of `attention_inputs` after the attention layer."""
        position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
        if self.pool_q_only:
            self.pooling_mult *= 2
@@ -424,7 +424,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
        super().build(input_shape)

    def relative_positional_attention(self, position_embeds, q_head, context_len, cls_mask=None):
-        """ Relative attention score for the positional encodings """
+        """Relative attention score for the positional encodings"""
        # q_head has shape batch_size x sea_len x n_head x d_head
        if self.attention_type == "factorized":
            # Notations from the paper, appending A.2.2, final formula (https://arxiv.org/abs/2006.03236)
@@ -470,7 +470,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
        return positional_attn

    def relative_token_type_attention(self, token_type_mat, q_head, cls_mask=None):
-        """ Relative attention score for the token_type_ids """
+        """Relative attention score for the token_type_ids"""
        if token_type_mat is None:
            return 0
        batch_size, seq_len, context_len = shape_list(token_type_mat)
@@ -723,7 +723,7 @@ class TFFunnelDecoder(tf.keras.layers.Layer):

@keras_serializable
class TFFunnelBaseLayer(tf.keras.layers.Layer):
-    """ Base model without decoder """
+    """Base model without decoder"""

    config_class = FunnelConfig
@@ -807,7 +807,7 @@ class TFFunnelBaseLayer(tf.keras.layers.Layer):

@keras_serializable
class TFFunnelMainLayer(tf.keras.layers.Layer):
-    """ Base model with decoder """
+    """Base model with decoder"""

    config_class = FunnelConfig
...
@@ -242,7 +242,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
        return word

    def _tokenize(self, text):
-        """ Tokenize a string. """
+        """Tokenize a string."""
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = "".join(
@@ -252,7 +252,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
        return bpe_tokens

    def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
@@ -260,7 +260,7 @@ class GPT2Tokenizer(PreTrainedTokenizer):
        return self.decoder.get(index)

    def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
        text = "".join(tokens)
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
        return text
...
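The `_convert_token_to_id` change above is whitespace-only; the method itself is a plain dict lookup that falls back to the id of the unknown token. A small self-contained sketch with a made-up vocabulary (not the real GPT-2 encoder):

encoder = {"hello": 0, "world": 1, "<unk>": 2}  # toy vocab for illustration
unk_token = "<unk>"

def convert_token_to_id(token):
    # out-of-vocabulary tokens map to the unknown token's id
    return encoder.get(token, encoder.get(unk_token))

print(convert_token_to_id("hello"))    # 0
print(convert_token_to_id("goodbye"))  # 2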
@@ -645,7 +645,7 @@ class IBertPreTrainedModel(PreTrainedModel):
    base_model_prefix = "ibert"

    def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
        if isinstance(module, (QuantLinear, nn.Linear)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
...
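The body of `_init_weights` is truncated by the hunk above. For orientation only, a hedged sketch of the usual PyTorch pattern such methods follow (normal-distribution init with a configurable std, zeroed bias); this is an illustration under that assumption, not the exact IBert implementation:

import torch.nn as nn

def init_linear_(module, initializer_range=0.02):
    # illustrative only: normal init for linear weights, zero init for bias
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()

init_linear_(nn.Linear(16, 16))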
@@ -611,7 +611,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
...
@@ -509,7 +509,7 @@ class LEDEncoderSelfAttention(nn.Module):

    @staticmethod
    def _get_global_attn_indices(is_index_global_attn):
-        """ compute global attn indices required throughout forward pass """
+        """compute global attn indices required throughout forward pass"""
        # helper variable
        num_global_attn_indices = is_index_global_attn.long().sum(dim=1)
...
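For context on the helper above: `is_index_global_attn` is a boolean mask marking positions that receive global attention, and the shown line simply counts those positions per batch element. A small PyTorch sketch with a toy mask (toy data, not the LED forward pass):

import torch

# True where a token gets global attention (e.g. question tokens)
is_index_global_attn = torch.tensor([
    [True, False, False, True],
    [True, False, False, False],
])

num_global_attn_indices = is_index_global_attn.long().sum(dim=1)
print(num_global_attn_indices)  # tensor([2, 1])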
@@ -670,7 +670,7 @@ class TFLEDEncoderSelfAttention(tf.keras.layers.Layer):

    @staticmethod
    def _get_global_attn_indices(is_index_global_attn):
-        """ compute global attn indices required throughout forward pass """
+        """compute global attn indices required throughout forward pass"""
        # helper variable
        num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
        num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype)
...
@@ -899,7 +899,7 @@ class LongformerSelfAttention(nn.Module):

    @staticmethod
    def _get_global_attn_indices(is_index_global_attn):
-        """ compute global attn indices required throughout forward pass """
+        """compute global attn indices required throughout forward pass"""
        # helper variable
        num_global_attn_indices = is_index_global_attn.long().sum(dim=1)
@@ -1363,7 +1363,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
...
@@ -1189,7 +1189,7 @@ class TFLongformerSelfAttention(tf.keras.layers.Layer):

    @staticmethod
    def _get_global_attn_indices(is_index_global_attn):
-        """ compute global attn indices required throughout forward pass """
+        """compute global attn indices required throughout forward pass"""
        # helper variable
        num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
        num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype)
...
@@ -783,7 +783,7 @@ class LxmertPreTrainedModel(PreTrainedModel):
    base_model_prefix = "lxmert"

    def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
...
@@ -227,7 +227,7 @@ class MarianTokenizer(PreTrainedTokenizer):
        return super().decode(token_ids, **kwargs)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """Uses source spm if _decode_use_source_tokenizer is True, and target spm otherwise """
+        """Uses source spm if _decode_use_source_tokenizer is True, and target spm otherwise"""
        if self._decode_use_source_tokenizer:
            return self.spm_source.DecodePieces(tokens)
        else:
...
@@ -189,7 +189,7 @@ class MBart50Tokenizer(PreTrainedTokenizer):
        return self.sp_model.EncodeAsPieces(text)

    def _convert_token_to_id(self, token: str) -> int:
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
        if token in self.fairseq_tokens_to_ids:
            return self.fairseq_tokens_to_ids[token]
        spm_id = self.sp_model.PieceToId(token)
...
@@ -708,7 +708,7 @@ class MegatronBertPreTrainedModel(PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
...
@@ -669,7 +669,7 @@ class MobileBertPreTrainedModel(PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
...
@@ -56,7 +56,7 @@ class MPNetPreTrainedModel(PreTrainedModel):
    base_model_prefix = "mpnet"

    def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
...
@@ -430,7 +430,7 @@ class TFMPNetEncoder(tf.keras.layers.Layer):
        return ret

    def compute_position_bias(self, x, position_ids=None):
-        """ Compute binned relative position bias """
+        """Compute binned relative position bias"""
        input_shape = shape_list(x)
        qlen, klen = input_shape[1], input_shape[1]
...
@@ -210,7 +210,7 @@ class MPNetTokenizer(PreTrainedTokenizer):
        return split_tokens

    def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
@@ -218,7 +218,7 @@ class MPNetTokenizer(PreTrainedTokenizer):
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string
...
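The `convert_tokens_to_string` line above undoes WordPiece segmentation, where a leading `##` marks a subword that continues the previous token. A quick standalone sketch on invented tokens:

tokens = ["un", "##believ", "##able", "results"]  # toy WordPiece-style tokens
out_string = " ".join(tokens).replace(" ##", "").strip()
print(out_string)  # -> "unbelievable results"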
@@ -176,7 +176,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
        return word

    def _tokenize(self, text):
-        """ Tokenize a string. """
+        """Tokenize a string."""
        split_tokens = []
        if self.fix_text is None:
            # Using BERT's BasicTokenizer
@@ -191,7 +191,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
        return split_tokens

    def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
@@ -199,7 +199,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
+        """Converts a sequence of tokens (string) in a single string."""
        out_string = "".join(tokens).replace("</w>", " ").strip()
        return out_string
...