Commit dedb2ef7 authored by Mohammad

removed building tokenizer from bert dataset

parent 1788c910
@@ -22,24 +22,19 @@ import numpy as np
 import torch
 from torch.utils.data import Dataset
+from megatron import get_tokenizer
 from megatron import mpu
 from megatron.data import helpers
-from megatron.tokenizer.bert_tokenization import FullTokenizer as FullBertTokenizer
 from megatron.data.dataset_utils import build_training_sample
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron import print_rank_0
-def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
-                                    splits_string, train_valid_test_num_samples,
+def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
+                                    train_valid_test_num_samples,
                                     max_seq_length, masked_lm_prob,
                                     short_seq_prob, seed, skip_warmup):
-    # Tokenizer is the same
-    tokenizer = FullBertTokenizer(vocab_file, do_lower_case=True)
-    print_rank_0(' > using full BERT tokenizer with vocabulary size: {}'.format(
-        tokenizer.vocab_size()))
     # Indexed dataset.
     indexed_dataset = get_indexed_dataset_(data_prefix,
                                            data_impl,
@@ -82,7 +77,6 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
             dataset = BertDataset(
                 name=name,
                 indexed_dataset=indexed_dataset,
-                tokenizer=tokenizer,
                 data_prefix=data_prefix,
                 num_epochs=None,
                 max_num_samples=train_valid_test_num_samples[index],
@@ -107,7 +101,7 @@ def build_train_valid_test_datasets(vocab_file, data_prefix, data_impl,
 class BertDataset(Dataset):
-    def __init__(self, name, indexed_dataset, tokenizer, data_prefix,
+    def __init__(self, name, indexed_dataset, data_prefix,
                  num_epochs, max_num_samples, masked_lm_prob,
                  max_seq_length, short_seq_prob, seed):
@@ -117,8 +111,7 @@ class BertDataset(Dataset):
         self.masked_lm_prob = masked_lm_prob
         self.max_seq_length = max_seq_length
-        # Tokenizer and dataset.
-        self.tokenizer = tokenizer
+        # Dataset.
         self.indexed_dataset = indexed_dataset
@@ -133,16 +126,13 @@ class BertDataset(Dataset):
                                                    self.name)
         # Vocab stuff.
-        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
-        self.vocab_id_to_token_dict = self.tokenizer.inv_vocab
-        self.cls_id = self.tokenizer.vocab['[CLS]']
-        self.sep_id = self.tokenizer.vocab['[SEP]']
-        self.mask_id = self.tokenizer.vocab['[MASK]']
-        self.pad_id = self.tokenizer.vocab['[PAD]']
-    def num_tokens(self):
-        return self.tokenizer.vocab_size()
+        tokenizer = get_tokenizer()
+        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
+        self.vocab_id_to_token_dict = tokenizer.inv_vocab
+        self.cls_id = tokenizer.cls
+        self.sep_id = tokenizer.sep
+        self.mask_id = tokenizer.mask
+        self.pad_id = tokenizer.pad
     def __len__(self):
......
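For reference, a minimal sketch (not part of the commit) of how the dataset side now resolves its vocabulary and special-token ids from the shared Megatron tokenizer instead of a locally constructed FullBertTokenizer. It assumes the global tokenizer has already been built during Megatron initialization so that get_tokenizer() can return it; the helper name _vocab_info_from_global_tokenizer is purely illustrative.

# Illustrative only -- mirrors the new vocab setup in BertDataset.__init__ above.
# Assumes megatron's global tokenizer was built during initialization so that
# get_tokenizer() can return it.
from megatron import get_tokenizer

def _vocab_info_from_global_tokenizer():
    tokenizer = get_tokenizer()
    return {
        'vocab_id_list': list(tokenizer.inv_vocab.keys()),  # all token ids
        'vocab_id_to_token_dict': tokenizer.inv_vocab,      # id -> token text
        'cls_id': tokenizer.cls,    # special-token ids now come from properties
        'sep_id': tokenizer.sep,    # rather than vocab['[CLS]']-style lookups
        'mask_id': tokenizer.mask,
        'pad_id': tokenizer.pad,
    }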
@@ -75,6 +75,18 @@ class AbstractTokenizer(ABC):
     def vocab_size(self):
         pass
+    @property
+    @abstractmethod
+    def vocab(self):
+        """Dictionary from vocab text token to id token."""
+        pass
+    @property
+    @abstractmethod
+    def inv_vocab(self):
+        """Dictionary from vocab id token to text token."""
+        pass
     @abstractmethod
     def tokenize(self, text):
         pass
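As a sketch of what the expanded AbstractTokenizer contract now requires, here is a hypothetical minimal subclass (not from the repository) that satisfies the abstract vocab_size, vocab, inv_vocab, and tokenize members. It assumes the base constructor takes the tokenizer name (as the self.name usage in the error messages below suggests), that AbstractTokenizer lives in megatron.tokenizer.tokenizer, and that the remaining base-class members keep non-abstract defaults.

# Hypothetical example of a concrete tokenizer under the new contract.
# _WhitespaceTokenizer is illustrative and does not exist in the repo;
# the module path of AbstractTokenizer is assumed.
from megatron.tokenizer.tokenizer import AbstractTokenizer

class _WhitespaceTokenizer(AbstractTokenizer):
    """Toy whitespace tokenizer implementing the new abstract properties."""

    def __init__(self, vocab):
        super().__init__('whitespace')                      # name assumed by base
        self._vocab = vocab                                 # token -> id
        self._inv_vocab = {i: t for t, i in vocab.items()}  # id -> token

    @property
    def vocab_size(self):
        return len(self._vocab)

    @property
    def vocab(self):
        return self._vocab

    @property
    def inv_vocab(self):
        return self._inv_vocab

    def tokenize(self, text):
        return [self._vocab[token] for token in text.split()]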
@@ -99,6 +111,11 @@ class AbstractTokenizer(ABC):
         raise NotImplementedError('EOD is not provided for {} '
                                   'tokenizer'.format(self.name))
+    @property
+    def mask(self):
+        raise NotImplementedError('MASK is not provided for {} '
+                                  'tokenizer'.format(self.name))
 class _BertWordPieceTokenizer(AbstractTokenizer):
     """Original BERT wordpiece tokenizer."""
@@ -113,11 +130,20 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
         self.cls_id = self.tokenizer.vocab['[CLS]']
         self.sep_id = self.tokenizer.vocab['[SEP]']
         self.pad_id = self.tokenizer.vocab['[PAD]']
+        self.mask_id = self.tokenizer.vocab['[MASK]']
     @property
     def vocab_size(self):
         return self.tokenizer.vocab_size()
+    @property
+    def vocab(self):
+        return self.tokenizer.vocab
+    @property
+    def inv_vocab(self):
+        return self.tokenizer.inv_vocab
     def tokenize(self, text):
         text_tokens = self.tokenizer.tokenize(text)
         return self.tokenizer.convert_tokens_to_ids(text_tokens)
@@ -134,6 +160,9 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
     def pad(self):
         return self.pad_id
+    @property
+    def mask(self):
+        return self.mask_id
 class _GPT2BPETokenizer(AbstractTokenizer):
     """Original GPT2 BPE tokenizer."""
@@ -150,6 +179,14 @@ class _GPT2BPETokenizer(AbstractTokenizer):
     def vocab_size(self):
         return len(self.tokenizer.encoder)
+    @property
+    def vocab(self):
+        return self.tokenizer.encoder
+    @property
+    def inv_vocab(self):
+        return self.tokenizer.decoder
     def tokenize(self, text):
         return self.tokenizer.encode(text)
......
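A small usage sketch of what these additions buy: vocab and inv_vocab are now part of the common tokenizer interface for both concrete tokenizers, while mask is defined for the BERT wordpiece tokenizer and falls back to the NotImplementedError default for the GPT2 BPE tokenizer. The setup around the global tokenizer is assumed, not shown in this commit.

# Illustrative usage of the unified interface; assumes the global tokenizer has
# been built (its concrete type depends on how Megatron was configured).
from megatron import get_tokenizer

def describe_tokenizer():
    tokenizer = get_tokenizer()
    print('vocab size:', tokenizer.vocab_size)
    print('first token id:', next(iter(tokenizer.inv_vocab)))
    try:
        print('mask id:', tokenizer.mask)   # defined for _BertWordPieceTokenizer
    except NotImplementedError:
        print('no mask token')              # e.g. _GPT2BPETokenizer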
@@ -137,7 +137,6 @@ def get_train_val_test_data():
         print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))
         train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
-            vocab_file=args.vocab_file,
             data_prefix=args.data_path,
             data_impl=args.data_impl,
             splits_string=args.split,
......
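With vocab_file removed, the call site in get_train_val_test_data passes only dataset-related arguments. The diff truncates the call after splits_string, so the sketch below completes it under assumptions: the remaining argument names (seq_length, mask_prob, short_seq_prob, seed, and the warm-up flag) follow typical Megatron argument naming and are not copied from the commit.

# Sketch of the updated call with vocab_file removed; argument names after
# splits_string are assumed, since the diff truncates the call here.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    data_prefix=args.data_path,
    data_impl=args.data_impl,
    splits_string=args.split,
    train_valid_test_num_samples=train_val_test_num_samples,
    max_seq_length=args.seq_length,
    masked_lm_prob=args.mask_prob,
    short_seq_prob=args.short_seq_prob,
    seed=args.seed,
    skip_warmup=(not args.mmap_warmup))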