Commit 8b1da95a authored by Neel Kant

Build simple mlm examples

parent 81c71789
@@ -15,6 +15,8 @@
import collections
import itertools
import numpy as np
@@ -80,6 +82,33 @@ def build_training_sample(sample,
    return train_sample

def build_simple_training_sample(sample, target_seq_length, max_seq_length,
                                 vocab_id_list, vocab_id_to_token_dict,
                                 cls_id, sep_id, mask_id, pad_id,
                                 masked_lm_prob, np_rng):
    """Build a single-segment masked-LM sample (no next-sentence pair)."""
    # Flatten the sentences into one token list and truncate so that
    # [CLS] and [SEP] still fit.
    tokens = list(itertools.chain(*sample))[:max_seq_length - 2]
    tokens, tokentypes = create_single_tokens_and_tokentypes(tokens,
                                                             cls_id, sep_id)

    # Masking.
    max_predictions_per_seq = masked_lm_prob * max_seq_length
    (tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
        tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
        cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)

    # Padding.
    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
        = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                                   masked_labels, pad_id, max_seq_length)

    train_sample = {
        'text': tokens_np,
        'types': tokentypes_np,
        'labels': labels_np,
        'loss_mask': loss_mask_np,
        'padding_mask': padding_mask_np}
    return train_sample
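
As a sense check, a hedged illustration of the dict this returns (token ids invented, max_seq_length=8, pad_id=0; -1 is the ignore label and 103 stands in for [MASK]):

    # input sentence [7, 8, 9] with token 8 chosen for masking:
    # 'text':         [101, 7, 103, 9, 102, 0, 0, 0]   # [CLS] ... [SEP] + padding
    # 'types':        [0, 0, 0, 0, 0, 0, 0, 0]         # single segment, all zeros
    # 'labels':       [-1, -1, 8, -1, -1, -1, -1, -1]  # original id at the masked slot
    # 'loss_mask':    [0, 0, 1, 0, 0, 0, 0, 0]         # loss only on masked positions
    # 'padding_mask': [1, 1, 1, 1, 1, 0, 0, 0]         # 1 for real tokens
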
def get_a_and_b_segments(sample, np_rng):
    """Divide sample into a and b segments."""
@@ -132,6 +161,7 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
            tokens.pop()
    return True

def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
@@ -158,6 +188,15 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    return tokens, tokentypes

def create_single_tokens_and_tokentypes(_tokens, cls_id, sep_id):
    """Wrap a single segment in [CLS]/[SEP] and give every token type 0."""
    tokens = []
    tokens.append(cls_id)
    tokens.extend(list(_tokens))
    tokens.append(sep_id)
    tokentypes = [0] * len(tokens)
    return tokens, tokentypes
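
A minimal sketch of the helper's behavior (special-token ids assumed for illustration: cls_id=101, sep_id=102):

    tokens, tokentypes = create_single_tokens_and_tokentypes([7, 8, 9], 101, 102)
    # tokens     -> [101, 7, 8, 9, 102]
    # tokentypes -> [0, 0, 0, 0, 0]
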
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])
......
import itertools
import random
import os
import sys
import time
import numpy as np
@@ -27,14 +26,8 @@ class InverseClozeDataset(Dataset):
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)
        self.samples_mapping = self.get_samples_mapping(
            data_prefix, num_epochs, max_num_samples)

        tokenizer = get_tokenizer()
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = tokenizer.inv_vocab
@@ -97,82 +90,74 @@ class InverseClozeDataset(Dataset):
        token_types = [0] * self.max_seq_length
        return tokens, token_types, pad_mask

    def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
        if not num_epochs:
            if not max_num_samples:
                raise ValueError("Need to specify either max_num_samples "
                                 "or num_epochs")
            num_epochs = np.iinfo(np.int32).max - 1
        if not max_num_samples:
            max_num_samples = np.iinfo(np.int64).max - 1

        # Filename of the index mapping
        indexmap_filename = data_prefix
        indexmap_filename += '_{}_indexmap'.format(self.name)
        if num_epochs != (np.iinfo(np.int32).max - 1):
            indexmap_filename += '_{}ep'.format(num_epochs)
        if max_num_samples != (np.iinfo(np.int64).max - 1):
            indexmap_filename += '_{}mns'.format(max_num_samples)
        indexmap_filename += '_{}msl'.format(self.max_seq_length)
        indexmap_filename += '_{}s'.format(self.seed)
        indexmap_filename += '.npy'

        # Build the indexed mapping if it does not exist.
        if torch.distributed.get_rank() == 0 and \
           not os.path.isfile(indexmap_filename):
            print(' > WARNING: could not find index map file {}, building '
                  'the indices on rank 0 ...'.format(indexmap_filename))
            # Make sure the types match the helpers input types.
            assert self.context_dataset.doc_idx.dtype == np.int64
            assert self.context_dataset.sizes.dtype == np.int32
            # Build samples mapping
            verbose = torch.distributed.get_rank() == 0
            start_time = time.time()
            print_rank_0(' > building samples index mapping for {} ...'.format(
                self.name))
            samples_mapping = helpers.build_blocks_mapping(
                self.context_dataset.doc_idx,
                self.context_dataset.sizes,
                self.titles_dataset.sizes,
                num_epochs,
                max_num_samples,
                self.max_seq_length - 3,  # account for added tokens
                self.seed,
                verbose)
            print_rank_0(' > done building samples index mapping')
            np.save(indexmap_filename, samples_mapping, allow_pickle=True)
            print_rank_0(' > saved the index mapping in {}'.format(
                indexmap_filename))
            # Make sure all the ranks have built the mapping
            print_rank_0(' > elapsed time to build and save samples mapping '
                         '(seconds): {:4f}'.format(
                             time.time() - start_time))

        # This should be a barrier but nccl barrier assumes
        # device_index=rank which is not the case for model
        # parallel case
        counts = torch.cuda.LongTensor([1])
        torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
        assert counts[0].item() == torch.distributed.get_world_size(
            group=mpu.get_data_parallel_group())

        # Load indexed dataset.
        print_rank_0(' > loading indexed mapping from {}'.format(
            indexmap_filename))
        start_time = time.time()
        samples_mapping = np.load(indexmap_filename, allow_pickle=True)
        print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
            time.time() - start_time))
        print_rank_0('    total number of samples: {}'.format(
            samples_mapping.shape[0]))

        return samples_mapping
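
Under the naming scheme above, a cached mapping file looks like this (argument values assumed for illustration; the mns component is omitted because max_num_samples was left at its sentinel default):

    # data_prefix='my-corpus', name='ict', num_epochs=3,
    # max_seq_length=512, seed=1234
    # -> my-corpus_ict_indexmap_3ep_512msl_1234s.npy
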
import numpy as np
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron.data.bert_dataset import get_samples_mapping_
from megatron.data.dataset_utils import build_simple_training_sample
class RealmDataset(Dataset):
    """Dataset of indexed text blocks yielding simple masked-LM training samples."""
    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed):

        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length

        # Dataset.
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
                                                    data_prefix,
                                                    num_epochs,
                                                    max_num_samples,
                                                    self.max_seq_length,
                                                    short_seq_prob,
                                                    self.seed,
                                                    self.name)

        # Vocab stuff.
        tokenizer = get_tokenizer()
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = tokenizer.inv_vocab
        self.cls_id = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_idx, end_idx, seq_length = self.samples_mapping[idx]
        sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
        # Note that this rng state should be numpy and not python since
        # python randint is inclusive whereas the numpy one is exclusive.
        np_rng = np.random.RandomState(seed=(self.seed + idx))
        return build_simple_training_sample(sample, seq_length,
                                            self.max_seq_length,  # needed for padding
                                            self.vocab_id_list,
                                            self.vocab_id_to_token_dict,
                                            self.cls_id, self.sep_id,
                                            self.mask_id, self.pad_id,
                                            self.masked_lm_prob, np_rng)
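
A hedged usage sketch (the indexed_dataset construction and the argument values are assumed, not part of this commit):

    # dataset = RealmDataset('realm', indexed_dataset, 'my-corpus',
    #                        num_epochs=1, max_num_samples=None,
    #                        masked_lm_prob=0.15, max_seq_length=288,
    #                        short_seq_prob=0.1, seed=1234)
    # sample = dataset[0]
    # sample['text'] is a (288,) array of token ids; sample['loss_mask']
    # marks the positions that contribute to the masked-LM loss.
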
@@ -214,10 +214,14 @@ class BertModel(MegatronModule):
                state_dict[self._ict_head_key], strict=strict)


# REALMBertModel is just BertModel without the binary head,
# but it needs a different kind of dataset.
class ICTBertModel(MegatronModule):
    def __init__(self,
                 ict_head_size,
                 num_tokentypes=2,
                 parallel_output=True):
        super(ICTBertModel, self).__init__()
        bert_args = dict(
......
@@ -178,6 +178,14 @@ class FullTokenizer(object):
        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
        return convert_by_vocab(self.inv_vocab, ids)

    def vocab_size(self):
......
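
A quick round-trip sketch for the two conversion helpers (tokens invented; assumes they exist in the vocab):

    # ids = tokenizer.convert_tokens_to_ids(['hello', 'world'])
    # tokenizer.convert_ids_to_tokens(ids) -> ['hello', 'world']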