Unverified Commit 31c56f2e authored by Anthony MOI

Fix style

parent 951ae99b
@@ -20,7 +20,7 @@ import logging
import os
import unicodedata

from .tokenization_utils import FastPreTrainedTokenizer, PreTrainedTokenizer


logger = logging.getLogger(__name__)
@@ -526,42 +526,64 @@ def _is_punctuation(char):
        return True
    return False


class BertTokenizerFast(FastPreTrainedTokenizer):
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        max_length=None,
        pad_to_max_length=False,
        stride=0,
        truncation_strategy="longest_first",
        add_special_tokens=True,
        **kwargs
    ):
        try:
            from tokenizers import Tokenizer, models, pre_tokenizers, decoders, processors

            super(BertTokenizerFast, self).__init__(
                unk_token=unk_token,
                sep_token=sep_token,
                pad_token=pad_token,
                cls_token=cls_token,
                mask_token=mask_token,
                **kwargs
            )

            self._tokenizer = Tokenizer(models.WordPiece.from_files(vocab_file, unk_token=unk_token))
            self._update_special_tokens()
            self._tokenizer.with_pre_tokenizer(
                pre_tokenizers.BertPreTokenizer.new(
                    do_basic_tokenize=do_basic_tokenize,
                    do_lower_case=do_lower_case,
                    tokenize_chinese_chars=tokenize_chinese_chars,
                    never_split=never_split if never_split is not None else [],
                )
            )
            self._tokenizer.with_decoder(decoders.WordPiece.new())

            if add_special_tokens:
                self._tokenizer.with_post_processor(
                    processors.BertProcessing.new(
                        (sep_token, self._tokenizer.token_to_id(sep_token)),
                        (cls_token, self._tokenizer.token_to_id(cls_token)),
                    )
                )

            if max_length is not None:
                self._tokenizer.with_truncation(max_length, stride, truncation_strategy)
            self._tokenizer.with_padding(
@@ -569,7 +591,7 @@ class BertTokenizerFast(FastPreTrainedTokenizer):
                self.padding_side,
                self.pad_token_id,
                self.pad_token_type_id,
                self.pad_token,
            )

            self._decoder = decoders.WordPiece.new()
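Not part of the commit, but for orientation: a minimal usage sketch of the BertTokenizerFast class reformatted above. The vocabulary path and the import path are assumptions, and only methods visible in this diff are called.

```python
# Hypothetical usage sketch of BertTokenizerFast as defined above (not from the commit).
# Assumes a local WordPiece vocabulary at ./vocab.txt and that the pre-release
# `tokenizers` package this branch targets is installed; the module path is assumed.
from transformers.tokenization_bert import BertTokenizerFast

tokenizer = BertTokenizerFast(
    vocab_file="./vocab.txt",  # hypothetical path to a BERT WordPiece vocab
    do_lower_case=True,
    max_length=128,            # forwarded to with_truncation(...) in the __init__ above
)

# tokenize() and encode_plus() come from FastPreTrainedTokenizer, shown later in this diff.
print(tokenizer.tokenize("Hello, how are you?"))
encoded = tokenizer.encode_plus("Hello, how are you?", "I am fine.")
print(encoded["input_ids"])
```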
@@ -22,7 +22,7 @@ from functools import lru_cache
import regex as re

from .tokenization_utils import FastPreTrainedTokenizer, PreTrainedTokenizer


logger = logging.getLogger(__name__)
@@ -247,19 +247,33 @@ class GPT2Tokenizer(PreTrainedTokenizer):
        return vocab_file, merge_file


class GPT2TokenizerFast(FastPreTrainedTokenizer):
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        merges_file,
        unk_token="<|endoftext|>",
        bos_token="<|endoftext|>",
        eos_token="<|endoftext|>",
        pad_to_max_length=False,
        add_prefix_space=False,
        max_length=None,
        stride=0,
        truncation_strategy="longest_first",
        **kwargs
    ):
        try:
            from tokenizers import Tokenizer, models, pre_tokenizers, decoders

            super(GPT2TokenizerFast, self).__init__(
                bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs
            )

            self._tokenizer = Tokenizer(models.BPE.from_files(vocab_file, merges_file))
            self._update_special_tokens()
@@ -272,7 +286,7 @@ class GPT2TokenizerFast(FastPreTrainedTokenizer):
                self.padding_side,
                self.pad_token_id if self.pad_token_id is not None else 0,
                self.pad_token_type_id,
                self.pad_token if self.pad_token is not None else "",
            )

            self._decoder = decoders.ByteLevel.new()
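Likewise, a hypothetical sketch (not from the commit) exercising the GPT2TokenizerFast class above; the vocab/merges paths and the import path are assumptions.

```python
# Hypothetical usage sketch of GPT2TokenizerFast as defined above (not from the commit).
# The byte-level BPE vocab/merges paths are placeholders; the module path is assumed.
from transformers.tokenization_gpt2 import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast(
    vocab_file="./gpt2-vocab.json",   # placeholder byte-level BPE vocabulary
    merges_file="./gpt2-merges.txt",  # placeholder BPE merges file
)

encoded = tokenizer.encode_plus("The quick brown fox jumps over the lazy dog")
print(encoded["input_ids"])

# decode() round-trips through the ByteLevel decoder wired up in the __init__ above.
print(tokenizer.decode(encoded["input_ids"]))
```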
@@ -1411,6 +1411,7 @@ class PreTrainedTokenizer(object):
        )
        return out_string


class FastPreTrainedTokenizer(PreTrainedTokenizer):
    def __init__(self, **kwargs):
        super(FastPreTrainedTokenizer, self).__init__(**kwargs)
@@ -1438,12 +1439,14 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer):
        self.tokenizer.add_special_tokens(self.all_special_tokens)

    @staticmethod
    def _convert_encoding(
        encoding,
        return_tensors=None,
        return_token_type_ids=True,
        return_attention_mask=True,
        return_overflowing_tokens=False,
        return_special_tokens_mask=False,
    ):
        encoding_dict = {
            "input_ids": encoding.ids,
        }
@@ -1458,14 +1461,14 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer):
            encoding_dict["special_tokens_mask"] = encoding.special_tokens_mask

        # Prepare inputs as tensors if asked
        if return_tensors == "tf" and is_tf_available():
            encoding_dict["input_ids"] = tf.constant([encoding_dict["input_ids"]])
            encoding_dict["token_type_ids"] = tf.constant([encoding_dict["token_type_ids"]])

            if "attention_mask" in encoding_dict:
                encoding_dict["attention_mask"] = tf.constant([encoding_dict["attention_mask"]])

        elif return_tensors == "pt" and is_torch_available():
            encoding_dict["input_ids"] = torch.tensor([encoding_dict["input_ids"]])
            encoding_dict["token_type_ids"] = torch.tensor([encoding_dict["token_type_ids"]])
@@ -1474,11 +1477,14 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer):
        elif return_tensors is not None:
            logger.warning(
                "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
                    return_tensors
                )
            )

        return encoding_dict

    def encode_plus(
        self,
        text,
        text_pair=None,
        return_tensors=None,
@@ -1486,14 +1492,17 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer):
        return_attention_mask=True,
        return_overflowing_tokens=False,
        return_special_tokens_mask=False,
        **kwargs
    ):
        encoding = self.tokenizer.encode(text, text_pair)
        return self._convert_encoding(
            encoding,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
        )

    def tokenize(self, text):
        return self.tokenizer.encode(text).tokens
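To make the control flow above concrete: encode_plus hands the raw tokenizers Encoding to _convert_encoding, which optionally wraps the lists in tf or torch tensors. A sketch under the same assumptions as the earlier examples (placeholder vocab path, torch installed for return_tensors="pt", import path assumed):

```python
# Sketch (not from the commit): encode_plus -> _convert_encoding with tensor output.
# If neither torch nor TensorFlow is available, plain lists come back and a warning
# is logged, per the branches above.
from transformers.tokenization_bert import BertTokenizerFast  # module path assumed

tokenizer = BertTokenizerFast(vocab_file="./vocab.txt", max_length=32)  # placeholder path
encoded = tokenizer.encode_plus(
    "First sequence.",
    "Second sequence.",
    return_tensors="pt",              # torch.tensor([...]) wrapping, batch dimension of 1
    return_special_tokens_mask=True,  # surfaces encoding.special_tokens_mask
)
for name, value in encoded.items():
    print(name, getattr(value, "shape", value))  # tensors report a shape, lists print as-is
```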
@@ -1510,19 +1519,26 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer):
    def add_tokens(self, new_tokens):
        self.tokenizer.add_tokens(new_tokens)

    def encode_batch(
        self,
        texts,
        return_tensors=None,
        return_token_type_ids=True,
        return_attention_mask=True,
        return_overflowing_tokens=False,
        return_special_tokens_mask=False,
    ):
        return [
            self._convert_encoding(
                encoding,
                return_tensors=return_tensors,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
            )
            for encoding in self.tokenizer.encode_batch(texts)
        ]

    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
        text = self.tokenizer.decode(token_ids, skip_special_tokens)
@@ -1534,6 +1550,7 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer):
        return text

    def decode_batch(self, ids_batch, skip_special_tokens=False, clear_up_tokenization_spaces=True):
        return [
            self.clean_up_tokenization(text) if clear_up_tokenization_spaces else text
            for text in self.tokenizer.decode_batch(ids_batch, skip_special_tokens)
        ]
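Finally, a hypothetical sketch (not from the commit) of the batch helpers defined above, again with a placeholder vocab path and assumed import path.

```python
# Hypothetical sketch of encode_batch / decode_batch as defined above (not from the commit).
from transformers.tokenization_bert import BertTokenizerFast  # module path assumed

tokenizer = BertTokenizerFast(vocab_file="./vocab.txt")  # placeholder vocab path

texts = ["A short sentence.", "Another, slightly longer sentence to encode."]
batch = tokenizer.encode_batch(texts)  # one dict per text, mirroring encode_plus
ids_batch = [example["input_ids"] for example in batch]

# decode_batch mirrors decode() over a list of id sequences.
print(tokenizer.decode_batch(ids_batch, skip_special_tokens=True))
```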