Commit 67fabc55 authored by zihanl

update tokenizer.py

parent cf4be127
@@ -40,8 +40,7 @@ def build_tokenizer(args):
                                             vocab_extra_ids=args.vocab_extra_ids)
     elif args.tokenizer_type == 'GPT2BPETokenizer':
         assert args.merge_file is not None
-        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file, special_tokens=args.spec_toks)
-        # tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
+        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
     else:
         raise NotImplementedError('{} tokenizer is not '
                                   'implemented.'.format(args.tokenizer_type))
@@ -261,25 +260,14 @@ class _BertWordPieceTokenizer(AbstractTokenizer):

 class _GPT2BPETokenizer(AbstractTokenizer):
     """Original GPT2 BPE tokenizer."""

-    def __init__(self, vocab_file, merge_file, special_tokens=None):
+    def __init__(self, vocab_file, merge_file):
         name = 'GPT2 BPE'
         super().__init__(name)

-        if special_tokens is not None:
-            # special_tokens: "[SEP],[PAD]"
-            special_tokens = special_tokens.split(",")
-        else:
-            special_tokens = []
         self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace',
-                                       special_tokens=special_tokens, max_len=None)
+                                       special_tokens=[], max_len=None)
         self.eod_id = self.tokenizer.encoder['<|endoftext|>']
-        if special_tokens is not None and len(special_tokens) > 0:
-            if "[SEP]" in special_tokens:
-                self.sep_id = self.tokenizer.special_tokens['[SEP]']
-            if "[PAD]" in special_tokens:
-                self.pad_id = self.tokenizer.special_tokens['[PAD]']

     @property
     def vocab_size(self):
         return len(self.tokenizer.encoder)
...
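
For context, a minimal usage sketch (not part of the commit) of how the tokenizer is constructed after this change; the vocab and merge file paths below are hypothetical placeholders:

# A minimal sketch (assumption): with the special_tokens plumbing removed,
# _GPT2BPETokenizer takes only the vocab and merge files, and __init__ sets
# only eod_id. Paths are placeholders, not files referenced by the commit.
tokenizer = _GPT2BPETokenizer('gpt2-vocab.json', 'gpt2-merges.txt')
print(tokenizer.vocab_size)  # size of the BPE encoder
print(tokenizer.eod_id)      # id of the '<|endoftext|>' token

Note that sep_id and pad_id are no longer populated, so any caller that relied on args.spec_toks or the '[SEP]' / '[PAD]' special tokens would need to be updated.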