Commit 717c5274 authored by Jared Casper

Merge branch 'nextlm-merge-sentencepiece' into 'main'

Add support for GPTSentencePieceTokenizer and related fixes.

See merge request ADLR/megatron-lm!561
parents 018391a6 3207c19a
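
Note (editorial, not part of the merge request): the new GPTSentencePieceTokenizer wraps a raw SentencePiece model directly, without the BERT/T5-style special tokens that SentencePieceTokenizer injects, and reuses the model's EOS id as the end-of-document token. A minimal sketch of the underlying sentencepiece calls the wrapper delegates to; 'tokenizer.model' is a placeholder path:

# Sketch only (not from this MR): the sentencepiece calls that
# GPTSentencePieceTokenizer delegates to, as used in the diff below.
import sentencepiece

sp = sentencepiece.SentencePieceProcessor(model_file='tokenizer.model')  # placeholder path
ids = sp.encode_as_ids('Hello world')   # what tokenize() returns
text = sp.decode_ids(ids)               # what detokenize() returns
eod_id = sp.eos_id()                    # reused as the eod id
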
@@ -1066,7 +1066,8 @@ def _add_data_args(parser):
                        choices=['BertWordPieceLowerCase',
                                 'BertWordPieceCase',
                                 'GPT2BPETokenizer',
-                                'SentencePieceTokenizer'],
+                                'SentencePieceTokenizer',
+                                'GPTSentencePieceTokenizer'],
                        help='What type of tokenizer to use.')
     group.add_argument('--tokenizer-model', type=str, default=None,
                        help='Sentencepiece tokenizer model.')
...
@@ -89,7 +89,7 @@ def set_global_variables(args):
     set_args(args)
     _build_num_microbatches_calculator(args)
-    if args.vocab_file:
+    if args.vocab_file or args.tokenizer_model:
         _ = _build_tokenizer(args)
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)
...
@@ -6,7 +6,7 @@
 import torch

-from megatron import get_tokenizer
+from megatron import get_tokenizer, get_args
 from .communication import broadcast_int_list, broadcast_tensor
@@ -16,7 +16,8 @@ def detokenize_generations(tokens_gpu_tensor,
     """Detokenize the generated tokens."""
     tokenizer = get_tokenizer()
+    args = get_args()

     prompts_plus_generations = []
     if return_segments:
         prompts_plus_generations_segments = []
@@ -30,10 +30,13 @@ def detokenize_generations(tokens_gpu_tensor,
         if return_segments:
             words = []
             for token in sequence_tokens:
-                word = tokenizer.tokenizer.decoder[token]
-                word = bytearray(
-                    [tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
-                        'utf-8', errors='replace')
+                if args.tokenizer_type in ['SentencePieceTokenizer', 'GPTSentencePieceTokenizer']:
+                    word = tokenizer.decoder[token]
+                else:
+                    word = tokenizer.tokenizer.decoder[token]
+                    word = bytearray(
+                        [tokenizer.tokenizer.byte_decoder[c] for c in word]).decode(
+                            'utf-8', errors='replace')
                 words.append(word)
             prompts_plus_generations_segments.append(words)
...
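
Context for the hunk above (editorial note): GPT-2 BPE stores each vocabulary entry as a unicode-remapped string that must be pushed back through byte_decoder to recover UTF-8 bytes, whereas SentencePiece pieces are already plain strings, so the new branch can read the piece straight from the tokenizer's decoder mapping. A standalone restatement of that branching, for illustration only:

# Illustrative restatement of the branch above (not additional MR code).
def token_to_word(token_id, tokenizer, tokenizer_type):
    if tokenizer_type in ('SentencePieceTokenizer', 'GPTSentencePieceTokenizer'):
        # SentencePiece pieces are plain text; 'decoder' is the id -> piece
        # mapping added to _SentencePieceTokenizer in this MR.
        return tokenizer.decoder[token_id]
    # GPT-2 BPE path: map the unicode-remapped token string back to raw bytes.
    word = tokenizer.tokenizer.decoder[token_id]
    return bytearray(
        [tokenizer.tokenizer.byte_decoder[c] for c in word]
    ).decode('utf-8', errors='replace')
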
@@ -15,7 +15,7 @@ def build_tokenizer(args):
     print('> building {} tokenizer ...'.format(args.tokenizer_type),
           flush=True)

-    if args.tokenizer_type != 'SentencePieceTokenizer':
+    if args.tokenizer_type not in ['SentencePieceTokenizer', 'GPTSentencePieceTokenizer']:
         assert args.vocab_file is not None

     # Select and instantiate the tokenizer.
@@ -33,6 +33,9 @@ def build_tokenizer(args):
     elif args.tokenizer_type == 'SentencePieceTokenizer':
         assert args.tokenizer_model is not None
         tokenizer = _SentencePieceTokenizer(args.tokenizer_model, vocab_extra_ids=args.vocab_extra_ids)
+    elif args.tokenizer_type == 'GPTSentencePieceTokenizer':
+        assert args.tokenizer_model is not None
+        tokenizer = _GPTSentencePieceTokenizer(args.tokenizer_model)
     else:
         raise NotImplementedError('{} tokenizer is not '
                                   'implemented.'.format(args.tokenizer_type))
@@ -291,23 +294,25 @@ class _SentencePieceTokenizer(AbstractTokenizer):
         super().__init__(name)

         import sentencepiece
-        self._tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
+        self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
         self._initalize(vocab_extra_ids)

-    def _initalize(self, vocab_extra_ids):
+    def _populate_vocab(self):
         self._vocab = {}
         self._inv_vocab = {}

+        for i in range(len(self.tokenizer)):
+            t = self.tokenizer.id_to_piece(i)
+            self._inv_vocab[i] = t
+            self._vocab[t] = i
+
+    def _initalize(self, vocab_extra_ids):
+        self._populate_vocab()
         self._special_tokens = {}
         self._inv_special_tokens = {}

         self._t5_tokens = []

-        for i in range(len(self._tokenizer)):
-            t = self._tokenizer.id_to_piece(i)
-            self._inv_vocab[i] = t
-            self._vocab[t] = i
-
         def _add_special_token(t):
             if t not in self._vocab:
                 next_id = len(self._vocab)
@@ -325,25 +330,25 @@ class _SentencePieceTokenizer(AbstractTokenizer):
         _add_special_token('<MASK>')
         self._mask_id = self._vocab['<MASK>']

-        pad_id = self._tokenizer.pad_id()
+        pad_id = self.tokenizer.pad_id()
         try:
-            pad_token = self._tokenizer.id_to_piece(pad_id)
+            pad_token = self.tokenizer.id_to_piece(pad_id)
         except IndexError:
             pad_token = '<PAD>'
         _add_special_token(pad_token)
         self._pad_id = self._vocab[pad_token]

-        bos_id = self._tokenizer.bos_id()
+        bos_id = self.tokenizer.bos_id()
         try:
-            bos_token = self._tokenizer.id_to_piece(bos_id)
+            bos_token = self.tokenizer.id_to_piece(bos_id)
         except IndexError:
             bos_token = '<BOS>'
         _add_special_token(bos_token)
         self._bos_id = self._vocab[bos_token]

-        eos_id = self._tokenizer.eos_id()
+        eos_id = self.tokenizer.eos_id()
         try:
-            eos_token = self._tokenizer.id_to_piece(eos_id)
+            eos_token = self.tokenizer.id_to_piece(eos_id)
         except IndexError:
             eos_token = '<EOS>'
         _add_special_token(eos_token)
@@ -366,6 +371,14 @@ class _SentencePieceTokenizer(AbstractTokenizer):
     def inv_vocab(self):
         return self._inv_vocab

+    @property
+    def decoder(self):
+        return self._inv_vocab
+
+    @property
+    def encoder(self):
+        return self._vocab
+
     # From:
     # https://github.com/NVIDIA/NeMo/blob/c8fa217e811d60d11d014827c7f3845ff6c99ae7/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py#L89
     def tokenize(self, text):
@@ -385,11 +398,11 @@ class _SentencePieceTokenizer(AbstractTokenizer):
             next_token = min(indices, key=indices.get)
             next_idx = idx + indices[next_token]

-            ids.extend(self._tokenizer.encode_as_ids(text[idx:next_idx]))
+            ids.extend(self.tokenizer.encode_as_ids(text[idx:next_idx]))
             ids.append(self._special_tokens[next_token])
             idx = next_idx + len(next_token)

-        ids.extend(self._tokenizer.encode_as_ids(text[idx:]))
+        ids.extend(self.tokenizer.encode_as_ids(text[idx:]))
         return ids

     # From:
@@ -400,12 +413,12 @@ class _SentencePieceTokenizer(AbstractTokenizer):
         for i, id in enumerate(ids):
             if id in self._inv_special_tokens:
-                text += self._tokenizer.decode_ids(ids[last_i:i]) + " "
+                text += self.tokenizer.decode_ids(ids[last_i:i]) + " "
                 text += self._inv_special_tokens[id] + " "
                 last_i = i + 1

-        text += self._tokenizer.decode_ids(ids[last_i:])
-        return text.strip()
+        text += self.tokenizer.decode_ids(ids[last_i:])
+        return text

     @property
     def cls(self):
@@ -447,3 +460,42 @@ class _SentencePieceTokenizer(AbstractTokenizer):
     def additional_special_tokens_ids(self):
         return [self.vocab[k] for k in self._t5_tokens]
+
+
+class _GPTSentencePieceTokenizer(_SentencePieceTokenizer):
+    """SentencePieceTokenizer-Megatron wrapper"""
+
+    def __init__(self, model_file,):
+        super().__init__(model_file, vocab_extra_ids=0)
+
+    def _initalize(self, vocab_extra_ids):
+        self._populate_vocab()
+
+        self._pad_id = self.tokenizer.pad_id()
+        self._bos_id = self.tokenizer.bos_id()
+        self._eos_id = self.tokenizer.eos_id()
+
+    def tokenize(self, text):
+        return self.tokenizer.encode_as_ids(text)
+
+    def detokenize(self, ids):
+        return self.tokenizer.decode_ids(ids)
+
+    @property
+    def cls(self):
+        return -1
+
+    @property
+    def sep(self):
+        return -1
+
+    @property
+    def mask(self):
+        return -1
+
+    @property
+    def eod(self):
+        return self._eos_id
+
+    @property
+    def additional_special_tokens_ids(self):
+        return None
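
A minimal round-trip sketch for the new class (editorial, not part of the MR), assuming the class lives in megatron/tokenizer/tokenizer.py as in the Megatron-LM tree; 'tokenizer.model' is a hypothetical path to a trained SentencePiece model:

# Sketch only: exercising _GPTSentencePieceTokenizer directly.
from megatron.tokenizer.tokenizer import _GPTSentencePieceTokenizer

tok = _GPTSentencePieceTokenizer('tokenizer.model')  # hypothetical model path
ids = tok.tokenize('Hello world')          # sentencepiece encode_as_ids under the hood
text = tok.detokenize(ids)                 # sentencepiece decode_ids under the hood
assert tok.eod == tok.tokenizer.eos_id()   # eod is mapped to the model's EOS id
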
@@ -94,7 +94,7 @@ def get_args():
     group = parser.add_argument_group(title='tokenizer')
     group.add_argument('--tokenizer-type', type=str, required=True,
                        choices=['BertWordPieceLowerCase','BertWordPieceCase',
-                                'GPT2BPETokenizer', 'SentencePieceTokenizer'],
+                                'GPT2BPETokenizer', 'SentencePieceTokenizer', 'GPTSentencePieceTokenizer'],
                        help='What type of tokenizer to use.')
     group.add_argument('--vocab-file', type=str, default=None,
                        help='Path to the vocab file')
@@ -104,6 +104,8 @@ def get_args():
                        help='Append an <eod> token to the end of a document.')
     group.add_argument('--lang', type=str, default='english',
                        help='Language to use for NLTK-powered sentence splitting.')
+    group.add_argument('--tokenizer-model', type=str, default=None,
+                       help='Sentencepiece tokenizer model.')

     group = parser.add_argument_group(title='output data')
...
@@ -192,7 +192,7 @@ def get_args():
     group = parser.add_argument_group(title='tokenizer')
     group.add_argument('--tokenizer-type', type=str, required=True,
                        choices=['BertWordPieceLowerCase','BertWordPieceCase',
-                                'GPT2BPETokenizer', 'SentencePieceTokenizer'],
+                                'GPT2BPETokenizer', 'SentencePieceTokenizer', 'GPTSentencePieceTokenizer'],
                        help='What type of tokenizer to use.')
     group.add_argument('--tokenizer-model', type=str, default=None,
                        help='YTTM tokenizer model.')
@@ -326,6 +326,9 @@ def main():
     for p in processes:
         p.join()

+    if args.partitions == 1:
+        return
+
     # encode partition files in parallel
     processes = []
...