Commit 8b1da95a authored by Neel Kant

Build simple mlm examples

parent 81c71789
@@ -15,6 +15,8 @@
import collections
+import itertools
import numpy as np
@@ -80,6 +82,33 @@ def build_training_sample(sample,
    return train_sample

+
+def build_simple_training_sample(sample, target_seq_length, max_seq_length,
+                                 vocab_id_list, vocab_id_to_token_dict,
+                                 cls_id, sep_id, mask_id, pad_id,
+                                 masked_lm_prob, np_rng):
+    tokens = list(itertools.chain(*sample))[:max_seq_length - 2]
+    tokens, tokentypes = create_single_tokens_and_tokentypes(tokens, cls_id, sep_id)
+
+    max_predictions_per_seq = masked_lm_prob * max_seq_length
+    (tokens, masked_positions, masked_labels, _) = create_masked_lm_predictions(
+        tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
+        cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
+
+    tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
+        = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
+                                   masked_labels, pad_id, max_seq_length)
+
+    train_sample = {
+        'text': tokens_np,
+        'types': tokentypes_np,
+        'labels': labels_np,
+        'loss_mask': loss_mask_np,
+        'padding_mask': padding_mask_np}
+    return train_sample
def get_a_and_b_segments(sample, np_rng):
    """Divide sample into a and b segments."""
@@ -132,6 +161,7 @@ def truncate_segments(tokens_a, tokens_b, len_a, len_b, max_num_tokens, np_rng):
            tokens.pop()
    return True

+
def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    """Merge segments A and B, add [CLS] and [SEP] and build tokentypes."""
@@ -158,6 +188,15 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id):
    return tokens, tokentypes
+
+
+def create_single_tokens_and_tokentypes(_tokens, cls_id, sep_id):
+    tokens = []
+    tokens.append(cls_id)
+    tokens.extend(list(_tokens))
+    tokens.append(sep_id)
+    tokentypes = [0] * len(tokens)
+    return tokens, tokentypes

MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])
...
import itertools
import random
import os
-import sys
import time
import numpy as np
@@ -27,14 +26,8 @@ class InverseClozeDataset(Dataset):
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)

-        self.samples_mapping = get_samples_mapping(self.context_dataset,
-                                                   self.titles_dataset,
-                                                   data_prefix,
-                                                   num_epochs,
-                                                   max_num_samples,
-                                                   self.max_seq_length,
-                                                   self.seed,
-                                                   self.name)
+        self.samples_mapping = self.get_samples_mapping(
+            data_prefix, num_epochs, max_num_samples)

        tokenizer = get_tokenizer()
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = tokenizer.inv_vocab
@@ -97,15 +90,7 @@ class InverseClozeDataset(Dataset):
        token_types = [0] * self.max_seq_length
        return tokens, token_types, pad_mask

-def get_samples_mapping(context_dataset,
-                        titles_dataset,
-                        data_prefix,
-                        num_epochs,
-                        max_num_samples,
-                        max_seq_length,
-                        seed,
-                        name):
+    def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
        if not num_epochs:
            if not max_num_samples:
                raise ValueError("Need to specify either max_num_samples "
@@ -116,13 +101,13 @@ def get_samples_mapping(context_dataset,
        # Filename of the index mapping
        indexmap_filename = data_prefix
-        indexmap_filename += '_{}_indexmap'.format(name)
+        indexmap_filename += '_{}_indexmap'.format(self.name)
        if num_epochs != (np.iinfo(np.int32).max - 1):
            indexmap_filename += '_{}ep'.format(num_epochs)
        if max_num_samples != (np.iinfo(np.int64).max - 1):
            indexmap_filename += '_{}mns'.format(max_num_samples)
-        indexmap_filename += '_{}msl'.format(max_seq_length)
-        indexmap_filename += '_{}s'.format(seed)
+        indexmap_filename += '_{}msl'.format(self.max_seq_length)
+        indexmap_filename += '_{}s'.format(self.seed)
        indexmap_filename += '.npy'

        # Build the indexed mapping if not exist.
@@ -132,22 +117,22 @@
                         'the indices on rank 0 ...'.format(indexmap_filename))
            # Make sure the types match the helpers input types.
-            assert context_dataset.doc_idx.dtype == np.int64
-            assert context_dataset.sizes.dtype == np.int32
+            assert self.context_dataset.doc_idx.dtype == np.int64
+            assert self.context_dataset.sizes.dtype == np.int32
            # Build samples mapping
            verbose = torch.distributed.get_rank() == 0
            start_time = time.time()
            print_rank_0(' > building samples index mapping for {} ...'.format(
-                name))
+                self.name))
            samples_mapping = helpers.build_blocks_mapping(
-                context_dataset.doc_idx,
-                context_dataset.sizes,
-                titles_dataset.sizes,
+                self.context_dataset.doc_idx,
+                self.context_dataset.sizes,
+                self.titles_dataset.sizes,
                num_epochs,
                max_num_samples,
-                max_seq_length-3,  # account for added tokens
-                seed,
+                self.max_seq_length-3,  # account for added tokens
+                self.seed,
                verbose)
            print_rank_0(' > done building samples index mapping')
            np.save(indexmap_filename, samples_mapping, allow_pickle=True)
...
import numpy as np
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron.data.bert_dataset import get_samples_mapping_
from megatron.data.dataset_utils import build_simple_training_sample


class RealmDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""

    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed):
        # Params to store.
        self.name = name
        self.seed = seed
        self.masked_lm_prob = masked_lm_prob
        self.max_seq_length = max_seq_length

        # Dataset.
        self.indexed_dataset = indexed_dataset

        # Build the samples mapping.
        self.samples_mapping = get_samples_mapping_(self.indexed_dataset,
                                                    data_prefix,
                                                    num_epochs,
                                                    max_num_samples,
                                                    self.max_seq_length,
                                                    short_seq_prob,
                                                    self.seed,
                                                    self.name)

        # Vocab stuff.
        tokenizer = get_tokenizer()
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = tokenizer.inv_vocab
        self.cls_id = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_idx, end_idx, seq_length = self.samples_mapping[idx]
        sample = [self.indexed_dataset[i] for i in range(start_idx, end_idx)]
        # Note that this rng state should be numpy and not python since
        # python randint is inclusive whereas the numpy one is exclusive.
        np_rng = np.random.RandomState(seed=(self.seed + idx))
        return build_simple_training_sample(sample, seq_length,
                                            self.max_seq_length,  # needed for padding
                                            self.vocab_id_list,
                                            self.vocab_id_to_token_dict,
                                            self.cls_id, self.sep_id,
                                            self.mask_id, self.pad_id,
                                            self.masked_lm_prob, np_rng)
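
A hedged usage sketch of what RealmDataset.__getitem__ yields (the inspect_sample helper and the ds name below are illustrative assumptions, not part of the commit): every field comes back from pad_and_convert_to_numpy already padded to max_seq_length, with loss_mask marking the masked positions and padding_mask marking the real tokens.

# Hypothetical consumer of a sample dict produced by RealmDataset (ds built elsewhere).
def inspect_sample(sample, max_seq_length):
    assert sample['text'].shape == (max_seq_length,)
    # Single-segment samples: every non-padding position keeps token type 0.
    assert (sample['types'][sample['padding_mask'] == 1] == 0).all()
    n_masked = int(sample['loss_mask'].sum())   # positions contributing to the MLM loss
    n_real = int(sample['padding_mask'].sum())  # non-padding tokens, including [CLS]/[SEP]
    print('{} masked positions out of {} real tokens'.format(n_masked, n_real))

# inspect_sample(ds[0], ds.max_seq_length)
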
@@ -214,10 +214,14 @@ class BertModel(MegatronModule):
                state_dict[self._ict_head_key], strict=strict)

+# REALMBertModel is just BertModel without binary head.
+# needs a different kind of dataset though
class ICTBertModel(MegatronModule):
    def __init__(self,
                 ict_head_size,
-                 num_tokentypes=0,
+                 num_tokentypes=2,
                 parallel_output=True):
        super(ICTBertModel, self).__init__()
        bert_args = dict(
...
@@ -178,6 +178,14 @@ class FullTokenizer(object):
        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
+        train_sample = {
+            'text': tokens_np,
+            'types': tokentypes_np,
+            'labels': labels_np,
+            'is_random': int(is_next_random),
+            'loss_mask': loss_mask_np,
+            'padding_mask': padding_mask_np,
+            'truncated': int(truncated)}
        return convert_by_vocab(self.inv_vocab, ids)

    def vocab_size(self):
...