Commit dedb2ef7 authored by Mohammad

removed building tokenizer from bert dataset

parent 1788c910
@@ -22,24 +22,19 @@
 import numpy as np
 import torch
 from torch.utils.data import Dataset
+from megatron import get_tokenizer
 from megatron import mpu
 from megatron.data import helpers
-from megatron.tokenizer.bert_tokenization import FullTokenizer as FullBertTokenizer
 from megatron.data.dataset_utils import build_training_sample
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron import print_rank_0


-def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
-                                    splits_string, train_valid_test_num_samples,
+def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+                                    train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
                                     short_seq_prob, seed, skip_warmup):

-    # Tokenizer is the same
-    tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)
-    print_rank_0(' > using full BERT tokenizer with vocabulary size: {}'.format(
-        tokenizer.vocab_size()))
-
     # Indexed dataset.
     indexed_dataset = get_indexed_dataset_(data_prefix,
                                            data_impl,

@@ -82,7 +77,6 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
         dataset = BertDataset(
             name=name,
             indexed_dataset=indexed_dataset,
-            tokenizer=tokenizer,
             data_prefix=data_prefix,
             num_epochs=None,
             max_num_samples=train_valid_test_num_samples[index],

@@ -107,7 +101,7 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
 class BertDataset(Dataset):

-    def __init__(self, name, indexed_dataset, tokenizer, data_prefix,
+    def __init__(self, name, indexed_dataset, data_prefix,
                  num_epochs, max_num_samples, masked_lm_prob,
                  max_seq_length, short_seq_prob, seed):

@@ -117,8 +111,7 @@ class BertDataset(Dataset):
         self.masked_lm_prob = masked_lm_prob
         self.max_seq_length = max_seq_length

-        # Tokenizer and dataset.
-        self.tokenizer = tokenizer
+        # Dataset.
         self.indexed_dataset = indexed_dataset

@@ -133,16 +126,13 @@ class BertDataset(Dataset):
                                            self.name)

         # Vocab stuff.
-        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
-        self.vocab_id_to_token_dict = self.tokenizer.inv_vocab
-        self.cls_id = self.tokenizer.vocab['[CLS]']
-        self.sep_id = self.tokenizer.vocab['[SEP]']
-        self.mask_id = self.tokenizer.vocab['[MASK]']
-        self.pad_id = self.tokenizer.vocab['[PAD]']
-
-    def num_tokens(self):
-        return self.tokenizer.vocab_size()
+        tokenizer = get_tokenizer()
+        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
+        self.vocab_id_to_token_dict = tokenizer.inv_vocab
+        self.cls_id = tokenizer.cls
+        self.sep_id = tokenizer.sep
+        self.mask_id = tokenizer.mask
+        self.pad_id = tokenizer.pad

     def __len__(self):
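The hunks above change BertDataset so that it no longer receives or constructs a tokenizer: it asks the framework for the already-built one via get_tokenizer() and reads the special-token ids from properties (cls, sep, mask, pad). Below is a minimal, self-contained sketch of that singleton pattern; build_tokenizer and ToyBertTokenizer, and the toy ids, are made up for illustration and are not Megatron's actual implementation.

# Illustrative sketch only: a process-wide tokenizer registered once and then
# fetched anywhere via get_tokenizer(), which is the pattern the dataset now
# relies on. build_tokenizer() and ToyBertTokenizer are hypothetical names.

_GLOBAL_TOKENIZER = None


def build_tokenizer(tokenizer):
    # Register the tokenizer once, e.g. during framework initialization.
    global _GLOBAL_TOKENIZER
    _GLOBAL_TOKENIZER = tokenizer
    return _GLOBAL_TOKENIZER


def get_tokenizer():
    # Return the previously registered tokenizer.
    assert _GLOBAL_TOKENIZER is not None, 'tokenizer has not been built'
    return _GLOBAL_TOKENIZER


class ToyBertTokenizer:
    # Exposes exactly the attributes BertDataset reads after this commit.
    def __init__(self):
        self._vocab = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}

    @property
    def vocab(self):
        return self._vocab

    @property
    def inv_vocab(self):
        return {i: t for t, i in self._vocab.items()}

    @property
    def cls(self):
        return self._vocab['[CLS]']

    @property
    def sep(self):
        return self._vocab['[SEP]']

    @property
    def mask(self):
        return self._vocab['[MASK]']

    @property
    def pad(self):
        return self._vocab['[PAD]']


build_tokenizer(ToyBertTokenizer())
tokenizer = get_tokenizer()
print(tokenizer.cls, tokenizer.sep, tokenizer.mask, tokenizer.pad)  # -> 1 2 3 0

With the tokenizer built once at startup, datasets no longer need it threaded through their constructor arguments. The remaining hunks extend the tokenizer abstraction so that any tokenizer can serve this role.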
@@ -75,6 +75,18 @@
     def vocab_size(self):
         pass

+    @property
+    @abstractmethod
+    def vocab(self):
+        """Dictionary from vocab text token to id token."""
+        pass
+
+    @property
+    @abstractmethod
+    def inv_vocab(self):
+        """Dictionary from vocab id token to text token."""
+        pass
+
     @abstractmethod
     def tokenize(self, text):
         pass

@@ -99,6 +111,11 @@
         raise NotImplementedError('EOD is not provided for {} '
                                   'tokenizer'.format(self.name))

+    @property
+    def mask(self):
+        raise NotImplementedError('MASK is not provided for {} '
+                                  'tokenizer'.format(self.name))
+

 class _BertWordPieceTokenizer(AbstractTokenizer):
     """Original BERT wordpiece tokenizer."""
@@ -113,11 +130,20 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
         self.cls_id = self.tokenizer.vocab['[CLS]']
         self.sep_id = self.tokenizer.vocab['[SEP]']
         self.pad_id = self.tokenizer.vocab['[PAD]']
+        self.mask_id = self.tokenizer.vocab['[MASK]']

     @property
     def vocab_size(self):
         return self.tokenizer.vocab_size()

+    @property
+    def vocab(self):
+        return self.tokenizer.vocab
+
+    @property
+    def inv_vocab(self):
+        return self.tokenizer.inv_vocab
+
     def tokenize(self, text):
         text_tokens = self.tokenizer.tokenize(text)
         return self.tokenizer.convert_tokens_to_ids(text_tokens)

@@ -134,6 +160,9 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
     def pad(self):
         return self.pad_id

+    @property
+    def mask(self):
+        return self.mask_id
+

 class _GPT2BPETokenizer(AbstractTokenizer):
     """Original GPT2 BPE tokenizer."""

@@ -150,6 +179,14 @@ class _GPT2BPETokenizer(AbstractTokenizer):
     def vocab_size(self):
         return len(self.tokenizer.encoder)

+    @property
+    def vocab(self):
+        return self.tokenizer.encoder
+
+    @property
+    def inv_vocab(self):
+        return self.tokenizer.decoder
+
     def tokenize(self, text):
         return self.tokenizer.encode(text)
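For the GPT-2 tokenizer the two new properties are thin aliases over the BPE tables: vocab returns the encoder mapping (text token to id) and inv_vocab the decoder mapping (id to text token). A toy round trip under that assumption follows; the two-entry tables are invented stand-ins for the ones normally loaded from the GPT-2 vocab and merges files.

# Made-up tables; the real encoder/decoder come from the GPT-2 vocab.json and
# merges.txt files when the BPE tokenizer is constructed.
encoder = {'Hello': 0, '\u0120world': 1}              # text token -> id
decoder = {idx: tok for tok, idx in encoder.items()}  # id -> text token


class ToyGPT2Vocab:
    # Mimics only the two properties added in this commit.
    @property
    def vocab(self):
        return encoder

    @property
    def inv_vocab(self):
        return decoder


tok = ToyGPT2Vocab()
ids = [tok.vocab[t] for t in ('Hello', '\u0120world')]
text = ''.join(tok.inv_vocab[i] for i in ids).replace('\u0120', ' ')
print(ids, repr(text))  # -> [0, 1] 'Hello world'

The final hunk updates the BERT pretraining data setup to match the new build_train_valid_test_datasets signature.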
@@ -137,7 +137,6 @@ def get_train_val_test_data():
     print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))

     train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
-        vocab_file=args.vocab_file,
         data_prefix=args.data_path,
         data_impl=args.data_impl,
         splits_string=args.split,
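Callers now omit vocab_file entirely and rely on the globally built tokenizer. A sketch of the updated call is below; the argument values are placeholders (in the training script they come from the parsed args), and the data prefix and split string here are invented for illustration.

# Placeholder values only; in the training script these come from `args`.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    data_prefix='my-bert_text_sentence',   # hypothetical preprocessed-data prefix
    data_impl='mmap',
    splits_string='949,50,1',
    train_valid_test_num_samples=[1000, 100, 100],
    max_seq_length=512,
    masked_lm_prob=0.15,
    short_seq_prob=0.1,
    seed=1234,
    skip_warmup=True)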