Commit 46ffb75c authored by Jared Casper's avatar Jared Casper
Browse files

Add support for GPTSentencePieceTokenizer and related fixes.

parent 4e891fe9
......@@ -1064,7 +1064,8 @@ def _add_data_args(parser):
choices=['BertWordPieceLowerCase',
'BertWordPieceCase',
'GPT2BPETokenizer',
'SentencePieceTokenizer'],
'SentencePieceTokenizer',
'GPTSentencePieceTokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--tokenizer-model', type=str, default=None,
help='Sentencepiece tokenizer model.')
......
......@@ -6,7 +6,7 @@
import torch
from megatron import get_tokenizer
from megatron import get_tokenizer, get_args
from .communication import broadcast_int_list, broadcast_tensor
......@@ -16,7 +16,7 @@ def detokenize_generations(tokens_gpu_tensor,
"""Detokenize the generated tokens."""
tokenizer = get_tokenizer()
args = get_args()
prompts_plus_generations = []
if return_segments:
prompts_plus_generations_segments = []
......@@ -30,6 +30,9 @@ def detokenize_generations(tokens_gpu_tensor,
if return_segments:
words = []
for token in sequence_tokens:
if args.tokenizer_type in ['SentencePieceTokenizer', 'GPTSentencePieceTokenizer']:
word = tokenizer.decoder[token]
else:
word = tokenizer.tokenizer.decoder[token]
word = bytearray(
[tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
......
......@@ -15,7 +15,7 @@ def build_tokenizer(args):
print('> building {} tokenizer ...'.format(args.tokenizer_type),
flush=True)
if args.tokenizer_type != 'SentencePieceTokenizer':
if args.tokenizer_type not in ['SentencePieceTokenizer', 'GPTSentencePieceTokenizer']:
assert args.vocab_file is not None
# Select and instantiate the tokenizer.
......@@ -33,6 +33,9 @@ def build_tokenizer(args):
elif args.tokenizer_type == 'SentencePieceTokenizer':
assert args.tokenizer_model is not None
tokenizer = _SentencePieceTokenizer(args.tokenizer_model, vocab_extra_ids=args.vocab_extra_ids)
elif args.tokenizer_type == 'GPTSentencePieceTokenizer':
assert args.tokenizer_model is not None
tokenizer = _GPTSentencePieceTokenizer(args.tokenizer_model)
else:
raise NotImplementedError('{} tokenizer is not '
'implemented.'.format(args.tokenizer_type))
......@@ -291,23 +294,25 @@ class _SentencePieceTokenizer(AbstractTokenizer):
super().__init__(name)
import sentencepiece
self._tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
self._initalize(vocab_extra_ids)
def _initalize(self, vocab_extra_ids):
def _populate_vocab(self):
self._vocab = {}
self._inv_vocab = {}
for i in range(len(self.tokenizer)):
t = self.tokenizer.id_to_piece(i)
self._inv_vocab[i] = t
self._vocab[t] = i
def _initalize(self, vocab_extra_ids):
self._populate_vocab()
self._special_tokens = {}
self._inv_special_tokens = {}
self._t5_tokens = []
for i in range(len(self._tokenizer)):
t = self._tokenizer.id_to_piece(i)
self._inv_vocab[i] = t
self._vocab[t] = i
def _add_special_token(t):
if t not in self._vocab:
next_id = len(self._vocab)
......@@ -325,25 +330,25 @@ class _SentencePieceTokenizer(AbstractTokenizer):
_add_special_token('<MASK>')
self._mask_id = self._vocab['<MASK>']
pad_id = self._tokenizer.pad_id()
pad_id = self.tokenizer.pad_id()
try:
pad_token = self._tokenizer.id_to_piece(pad_id)
pad_token = self.tokenizer.id_to_piece(pad_id)
except IndexError:
pad_token = '<PAD>'
_add_special_token(pad_token)
self._pad_id = self._vocab[pad_token]
bos_id = self._tokenizer.bos_id()
bos_id = self.tokenizer.bos_id()
try:
bos_token = self._tokenizer.id_to_piece(bos_id)
bos_token = self.tokenizer.id_to_piece(bos_id)
except IndexError:
bos_token = '<BOS>'
_add_special_token(bos_token)
self._bos_id = self._vocab[bos_token]
eos_id = self._tokenizer.eos_id()
eos_id = self.tokenizer.eos_id()
try:
eos_token = self._tokenizer.id_to_piece(eos_id)
eos_token = self.tokenizer.id_to_piece(eos_id)
except IndexError:
eos_token = '<EOS>'
_add_special_token(eos_token)
......@@ -366,6 +371,14 @@ class _SentencePieceTokenizer(AbstractTokenizer):
def inv_vocab(self):
return self._inv_vocab
@property
def decoder(self):
return self._inv_vocab
@property
def encoder(self):
return self._vocab
# From:
# https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L89
def tokenize(self, text):
......@@ -385,11 +398,11 @@ class _SentencePieceTokenizer(AbstractTokenizer):
next_token = min(indices, key=indices.get)
next_idx = idx + indices[next_token]
ids.extend(self._tokenizer.encode_as_ids(text[idx:next_idx]))
ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx]))
ids.append(self._special_tokens[next_token])
idx = next_idx + len(next_token)
ids.extend(self._tokenizer.encode_as_ids(text[idx:]))
ids.extend(self.tokenizer.encode_as_ids(text[idx:]))
return ids
# From:
......@@ -400,12 +413,12 @@ class _SentencePieceTokenizer(AbstractTokenizer):
for i, id in enumerate(ids):
if id in self._inv_special_tokens:
text += self._tokenizer.decode_ids(ids[last_i:i]) + " "
text += self.tokenizer.decode_ids(ids[last_i:i]) + " "
text += self._inv_special_tokens[id] + " "
last_i = i + 1
text += self._tokenizer.decode_ids(ids[last_i:])
return text.strip()
text += self.tokenizer.decode_ids(ids[last_i:])
return text
@property
def cls(self):
......@@ -447,3 +460,42 @@ class _SentencePieceTokenizer(AbstractTokenizer):
def additional_special_tokens_ids(self):
return [self.vocab[k] for k in self._t5_tokens]
class _GPTSentencePieceTokenizer(_SentencePieceTokenizer):
"""SentencePieceTokenizer-Megatron wrapper"""
def __init__(self, model_file,):
super().__init__(model_file, vocab_extra_ids=0)
def _initalize(self, vocab_extra_ids):
self._populate_vocab()
self._pad_id = self.tokenizer.pad_id()
self._bos_id = self.tokenizer.bos_id()
self._eos_id = self.tokenizer.eos_id()
def tokenize(self, text):
return self.tokenizer.encode_as_ids(text)
def detokenize(self, ids):
return self.tokenizer.decode_ids(ids)
@property
def cls(self):
return -1
@property
def sep(self):
return -1
@property
def mask(self):
return -1
@property
def eod(self):
return self._eos_id
@property
def additional_special_tokens_ids(self):
return None
......@@ -94,7 +94,7 @@ def get_args():
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase','BertWordPieceCase',
'GPT2BPETokenizer', 'SentencePieceTokenizer'],
'GPT2BPETokenizer', 'SentencePieceTokenizer', 'GPTSentencePieceTokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
......
......@@ -192,7 +192,7 @@ def get_args():
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase','BertWordPieceCase',
'GPT2BPETokenizer', 'SentencePieceTokenizer'],
'GPT2BPETokenizer', 'SentencePieceTokenizer', 'GPTSentencePieceTokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--tokenizer-model', type=str, default=None,
help='YTTM tokenizer model.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment