Commit 45f4ee54 authored by Vijay Korthikanti

yttm + ByteLevelBPE + sentencepiece tokenizer support

parent 189e72a7
@@ -850,8 +850,13 @@ def _add_data_args(parser):
                       default=None,
                       choices=['BertWordPieceLowerCase',
                                'BertWordPieceCase',
                                'GPT2BPETokenizer',
                                'YTTMTokenizer',
                                'ByteLevelBPETokenizer',
                                'SentencePieceTokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--tokenizer-model', type=str, default=None,
                       help='YTTM tokenizer model.')
    group.add_argument('--data-impl', type=str, default='infer',
                       choices=['lazy', 'cached', 'mmap', 'infer'],
                       help='Implementation of indexed datasets.')
...
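For quick reference, a pretraining launch that exercises the new SentencePiece path might pass the flags below. This is a hedged sketch: the script name, model path, and the --vocab-extra-ids value are assumptions; only --tokenizer-type and --tokenizer-model are defined by this diff.

    python pretrain_t5.py \
        --tokenizer-type SentencePieceTokenizer \
        --tokenizer-model /path/to/spm.model \
        --vocab-extra-ids 100 \
        <other pretraining arguments>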
@@ -20,6 +20,9 @@ from abc import abstractmethod
from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .gpt2_tokenization import GPT2Tokenizer
import sentencepiece
import tokenizers
import youtokentome as yttm


def build_tokenizer(args):
@@ -41,6 +44,16 @@ def build_tokenizer(args):
    elif args.tokenizer_type == 'GPT2BPETokenizer':
        assert args.merge_file is not None
        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
    elif args.tokenizer_type == 'YTTMTokenizer':
        assert args.tokenizer_model is not None
        tokenizer = _YTTMTokenizer(args.tokenizer_model, vocab_extra_ids=args.vocab_extra_ids)
    elif args.tokenizer_type == 'ByteLevelBPETokenizer':
        assert args.vocab_file is not None
        assert args.merge_file is not None
        tokenizer = _ByteLevelBPETokenizer(args.vocab_file, args.merge_file, vocab_extra_ids=args.vocab_extra_ids)
    elif args.tokenizer_type == 'SentencePieceTokenizer':
        assert args.tokenizer_model is not None
        tokenizer = _SentencePieceTokenizer(args.tokenizer_model, vocab_extra_ids=args.vocab_extra_ids)
    else:
        raise NotImplementedError('{} tokenizer is not '
                                  'implemented.'.format(args.tokenizer_type))
@@ -289,3 +302,356 @@ class _GPT2BPETokenizer(AbstractTokenizer):
    @property
    def eod(self):
        return self.eod_id


class _YTTMTokenizer(AbstractTokenizer):
""" YTTM tokenizer."""
def __init__(self, model_path, vocab_extra_ids=0):
name = 'YTTM'
super().__init__(name)
self.bpe = yttm.BPE(model=model_path)
self.vocab_ = {}
self.inv_vocab_ = {}
self._additional_special_tokens = []
self._initalize(vocab_extra_ids)
def _initalize(self, vocab_extra_ids):
for subword in self.bpe.vocab():
self.add_token(subword)
self.add_token('<CLS>'); self.cls_id = self.vocab_['<CLS>']
self.add_token('<SEP>'); self.sep_id = self.vocab_['<SEP>']
self.add_token('<PAD>'); self.pad_id = self.vocab_['<PAD>']
self.add_token('<BOS>'); self.bos_id = self.vocab_['<BOS>']
self.add_token('<EOS>'); self.eos_id = self.vocab_['<EOS>']
self.add_token('<EOD>'); self.eod_id = self.vocab_['<EOD>']
self.add_token('<MASK>'); self.mask_id = self.vocab_['<MASK>']
self.special_token_ids = [self.cls_id, self.sep_id, self.pad_id,
self.bos_id, self.eos_id, self.eod_id,
self.mask_id]
self.add_additional_special_tokens([
"<extra_id_{}>".format(i) for i in range(vocab_extra_ids)
])
def add_token(self, token):
if token not in self.vocab:
self.inv_vocab[self.vocab_size] = token
self.vocab[token] = self.vocab_size
    def add_additional_special_tokens(self, tokens):
        for token in tokens:
            if token not in self.vocab:
                self._additional_special_tokens.append(token)
                # Register the token first so it receives an id, then record
                # that id (not the raw string) in the list of ids ignored
                # during detokenization.
                self.add_token(token)
                self.special_token_ids.append(self.vocab[token])
@property
def vocab_size(self):
return len(self.vocab_)
@property
def vocab(self):
return self.vocab_
@property
def inv_vocab(self):
return self.inv_vocab_
def tokenize(self, text):
return self.bpe.encode([text], output_type=yttm.OutputType.ID)[0]
def detokenize(self, token_ids):
return self.bpe.decode([token_ids], ignore_ids=self.special_token_ids)[0]
@property
def cls(self):
return self.cls_id
@property
def sep(self):
return self.sep_id
@property
def pad(self):
return self.pad_id
@property
def bos_token_id(self):
return self.bos_id
@property
def bos(self):
return self.bos_id
@property
def eod(self):
return self.eod_id
@property
def eos_token_id(self):
return self.eos_id
@property
def eos(self):
return self.eos_id
@property
def mask(self):
return self.mask_id
@property
def additional_special_tokens_ids(self):
return [self.vocab.get(token) for token in self._additional_special_tokens]
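

# Usage sketch (illustration only, not from the commit above): train a
# YouTokenToMe BPE model and round-trip text through the _YTTMTokenizer
# wrapper. The file names, vocab size, and sample text are hypothetical.
import youtokentome as yttm

yttm.BPE.train(data='corpus.txt', model='yttm.model', vocab_size=32000)

tok = _YTTMTokenizer('yttm.model', vocab_extra_ids=100)
ids = tok.tokenize('hello world')               # plain BPE token ids
text = tok.detokenize(ids)                      # special-token ids are ignored on decode
extra_ids = tok.additional_special_tokens_ids   # ids of <extra_id_0> ... <extra_id_99>

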
class _ByteLevelBPETokenizer(AbstractTokenizer):
"""ByteLevelBPETokenizer that can support T5 pretraining."""
def __init__(self, vocab_file, merges_file, vocab_extra_ids=0):
name = 'ByteLevelBPETokenizer'
super().__init__(name)
self._bpe = tokenizers.ByteLevelBPETokenizer(vocab=vocab_file, merges=merges_file)
self._inv_vocab = {}
self._additional_special_tokens = []
self._initalize(vocab_extra_ids)
def _initalize(self, vocab_extra_ids):
self._bpe.add_special_tokens(['<CLS>', '<SEP>', '<PAD>', '<BOS>', '<EOS>', '<EOD>', '<MASK>'])
self._cls_id = self.vocab['<CLS>']
self._sep_id = self.vocab['<SEP>']
self._pad_id = self.vocab['<PAD>']
self._bos_id = self.vocab['<BOS>']
self._eos_id = self.vocab['<EOS>']
self._eod_id = self.vocab['<EOD>']
self._mask_id = self.vocab['<MASK>']
t5_tokens = ["<extra_id_{}>".format(i) for i in range(vocab_extra_ids)]
self._bpe.add_special_tokens(t5_tokens)
self._additional_special_tokens = t5_tokens
@property
def vocab_size(self):
return self._bpe.get_vocab_size()
@property
def vocab(self):
return self._bpe.get_vocab()
@property
def inv_vocab(self):
vocab = self.vocab
if len(self._inv_vocab) != len(vocab):
self._inv_vocab = {}
for (k, v) in vocab.items():
self._inv_vocab[v] = k
return self._inv_vocab
def tokenize(self, text):
return self._bpe.encode(text).ids
def detokenize(self, token_ids):
return self._bpe.decode(token_ids)
@property
def cls(self):
return self._cls_id
@property
def sep(self):
return self._sep_id
@property
def pad(self):
return self._pad_id
@property
def bos_token_id(self):
return self._bos_id
@property
def bos(self):
return self._bos_id
@property
def eod(self):
return self._eod_id
@property
def eos_token_id(self):
return self._eos_id
@property
def eos(self):
return self._eos_id
@property
def mask(self):
return self._mask_id
@property
def additional_special_tokens_ids(self):
return [self.vocab.get(token) for token in self._additional_special_tokens]
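

# Usage sketch (illustration only, not from the commit above): train a
# byte-level BPE with the Hugging Face `tokenizers` library to produce the
# vocab/merges files that _ByteLevelBPETokenizer consumes. Paths, vocab size,
# and sample text are hypothetical.
import tokenizers

bpe = tokenizers.ByteLevelBPETokenizer()
bpe.train(files=['corpus.txt'], vocab_size=32000)
bpe.save_model('.')   # writes vocab.json and merges.txt to the current directory

tok = _ByteLevelBPETokenizer('vocab.json', 'merges.txt', vocab_extra_ids=100)
ids = tok.tokenize('hello world')
print(tok.detokenize(ids))

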
class _SentencePieceTokenizer(AbstractTokenizer):
"""SentencePieceTokenizer-Megatron wrapper"""
def __init__(self, model_file, vocab_extra_ids=0):
name = 'SentencePieceTokenizer'
super().__init__(name)
self._tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
self._initalize(vocab_extra_ids)
def _initalize(self, vocab_extra_ids):
self._vocab = {}
self._inv_vocab = {}
self._special_tokens = {}
self._inv_special_tokens = {}
self._t5_tokens = []
for i in range(len(self._tokenizer)):
t = self._tokenizer.id_to_piece(i)
self._inv_vocab[i] = t
self._vocab[t] = i
def _add_special_token(t):
if t not in self._vocab:
next_id = len(self._vocab)
self._vocab[t] = next_id
self._inv_vocab[next_id] = t
self._special_tokens[t] = self._vocab[t]
self._inv_special_tokens[self._vocab[t]] = t
_add_special_token('<CLS>'); self._cls_id = self._vocab['<CLS>']
_add_special_token('<SEP>'); self._sep_id = self._vocab['<SEP>']
_add_special_token('<EOD>'); self._eod_id = self._vocab['<EOD>']
_add_special_token('<MASK>'); self._mask_id = self._vocab['<MASK>']
pad_id = self._tokenizer.pad_id()
try:
pad_token = self._tokenizer.id_to_piece(pad_id)
except IndexError:
pad_token = '<PAD>'
_add_special_token(pad_token); self._pad_id = self._vocab[pad_token]
bos_id = self._tokenizer.bos_id()
try:
bos_token = self._tokenizer.id_to_piece(bos_id)
except IndexError:
bos_token = '<BOS>'
_add_special_token(bos_token); self._bos_id = self._vocab[bos_token]
eos_id = self._tokenizer.eos_id()
try:
eos_token = self._tokenizer.id_to_piece(eos_id)
except IndexError:
eos_token = '<EOS>'
_add_special_token(eos_token); self._eos_id = self._vocab[eos_token]
for i in range(vocab_extra_ids):
t = "<extra_id_{}>".format(i)
_add_special_token(t)
self._t5_tokens += [t]
@property
def vocab_size(self):
return len(self._vocab)
@property
def vocab(self):
return self._vocab
@property
def inv_vocab(self):
return self._inv_vocab
# From:
# https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L89
def tokenize(self, text):
ids = []
idx = 0
last_idx = 0
while 1:
indices = {}
for token in self._special_tokens:
try:
indices[token] = text[idx:].index(token)
except ValueError:
continue
if len(indices) == 0:
break
next_token = min(indices, key=indices.get)
next_idx = idx + indices[next_token]
ids.extend(self._tokenizer.encode_as_ids(text[idx:next_idx]))
ids.append(self._special_tokens[next_token])
idx = next_idx + len(next_token)
ids.extend(self._tokenizer.encode_as_ids(text[idx:]))
return ids
# From:
# https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L125
def detokenize(self, ids):
text = ""
last_i = 0
for i, id in enumerate(ids):
if id in self._inv_special_tokens:
text += self._tokenizer.decode_ids(ids[last_i:i]) + " "
text += self._inv_special_tokens[id] + " "
last_i = i + 1
text += self._tokenizer.decode_ids(ids[last_i:])
return text.strip()
@property
def cls(self):
return self._cls_id
@property
def sep(self):
return self._sep_id
@property
def pad(self):
return self._pad_id
@property
def bos_token_id(self):
return self._bos_id
@property
def bos(self):
return self._bos_id
@property
def eod(self):
return self._eod_id
@property
def eos_token_id(self):
return self._eos_id
@property
def eos(self):
return self._eos_id
@property
def mask(self):
return self._mask_id
@property
def additional_special_tokens_ids(self):
return [self.vocab[k] for k in self._t5_tokens]
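
A minimal usage sketch for the SentencePiece wrapper (illustration only, not from the commit; the file names, vocab size, and sample text are assumptions). It shows how tokenize() routes registered special tokens such as <extra_id_0> to their reserved ids instead of letting SentencePiece split them into sub-pieces:

import sentencepiece

sentencepiece.SentencePieceTrainer.train(
    input='corpus.txt', model_prefix='spm', vocab_size=32000)

tok = _SentencePieceTokenizer('spm.model', vocab_extra_ids=100)
ids = tok.tokenize('summarize: <extra_id_0> the quick brown fox')
assert tok.additional_special_tokens_ids[0] in ids
print(tok.detokenize(ids))   # decoded text with '<extra_id_0>' re-inserted as plain text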