Unverified Commit 041eac2d authored by Anthony MOI

GPT2TokenizerFast

parent 3471ff0d
@@ -108,7 +108,7 @@ from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer
 from .tokenization_camembert import CamembertTokenizer
 from .tokenization_ctrl import CTRLTokenizer
 from .tokenization_distilbert import DistilBertTokenizer
-from .tokenization_gpt2 import GPT2Tokenizer
+from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
 from .tokenization_openai import OpenAIGPTTokenizer
 from .tokenization_roberta import RobertaTokenizer
 from .tokenization_t5 import T5Tokenizer
@@ -22,7 +22,7 @@ from functools import lru_cache
 import regex as re
-from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils import PreTrainedTokenizer, FastPreTrainedTokenizer
 logger = logging.getLogger(__name__)
@@ -246,3 +246,36 @@ class GPT2Tokenizer(PreTrainedTokenizer):
                 index += 1
         return vocab_file, merge_file
+
+
+class GPT2TokenizerFast(FastPreTrainedTokenizer):
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+
+    def __init__(self, vocab_file, merges_file, unk_token="<|endoftext|>", bos_token="<|endoftext|>",
+                 eos_token="<|endoftext|>", pad_to_max_length=False, add_prefix_space=False,
+                 max_length=None, stride=0, truncation_strategy='longest_first', **kwargs):
+        try:
+            from tokenizers import Tokenizer, models, pre_tokenizers, decoders
+
+            super(GPT2TokenizerFast, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
+
+            self._tokenizer = Tokenizer(models.BPE.from_files(vocab_file, merges_file))
+            self._update_special_tokens()
+            self._tokenizer.with_pre_tokenizer(pre_tokenizers.ByteLevel.new(add_prefix_space))
+            self._tokenizer.with_decoder(decoders.ByteLevel.new())
+            if max_length:
+                self._tokenizer.with_truncation(max_length, stride, truncation_strategy)
+            self._tokenizer.with_padding(
+                max_length if pad_to_max_length else None,
+                self.padding_side,
+                self.pad_token_id if self.pad_token_id is not None else 0,
+                self.pad_token_type_id,
+                self.pad_token if self.pad_token is not None else ""
+            )
+            self._decoder = decoders.ByteLevel.new()
+        except (AttributeError, ImportError) as e:
+            logger.error("Make sure you installed `tokenizers` with `pip install tokenizers==0.0.8`")
+            raise e
\ No newline at end of file
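
For reference, a minimal usage sketch of the new class (not part of the commit). It assumes GPT2TokenizerFast keeps the familiar encode/decode interface of PreTrainedTokenizer, and the vocab/merges paths below are placeholders for local GPT-2 BPE files. Note the design choice in the constructor: the `tokenizers` import is wrapped in a try/except so that a missing or incompatible install fails early with a pointer to the pinned `tokenizers==0.0.8` release.

from transformers import GPT2TokenizerFast

# Paths are hypothetical; point them at a local GPT-2 vocab.json / merges.txt pair.
tokenizer = GPT2TokenizerFast(
    vocab_file="gpt2-vocab.json",
    merges_file="gpt2-merges.txt",
    add_prefix_space=True,   # prefix a space so the first word is split like a mid-sentence word
    max_length=1024,         # per the diff, this enables truncation via with_truncation()
)

# Assumed to mirror the slow GPT2Tokenizer API exposed through FastPreTrainedTokenizer.
ids = tokenizer.encode("Hello, fast BPE tokenization!")
text = tokenizer.decode(ids)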