"vscode:/vscode.git/clone" did not exist on "2f85d781496b43bb04b58f2b82a96e5f516bf253"
Commit 4abd7ce2 authored by Neel Kant

Refactor before merge request

parent f1ad8c94
import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.ict_dataset import InverseClozeDataset
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.training import get_model
from pretrain_bert_ict import get_batch, model_provider
def main():
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    args = get_args()

    model = load_checkpoint()
    model.eval()
    dataset = get_dataset()
    data_iter = iter(get_dataloader(dataset))

    all_input_tokens = []
    all_input_logits = []
    all_block_tokens = []
    all_block_logits = []
    for _ in range(100):
        input_tokens, input_types, input_pad_mask, block_tokens, block_token_types, block_pad_mask = get_batch(data_iter)
        # The refactored ICTBertModel.forward takes (tokens, attention_mask, types)
        # for the query and for the block, and returns the two sets of logits.
        input_logits, block_logits = model.module.module.forward(
            input_tokens, input_pad_mask, input_types,
            block_tokens, block_pad_mask, block_token_types)

        all_input_tokens.append(input_tokens.detach().cpu().numpy())
        all_input_logits.append(input_logits.detach().cpu().numpy())
        all_block_tokens.append(block_tokens.detach().cpu().numpy())
        all_block_logits.append(block_logits.detach().cpu().numpy())

    all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
    all_input_logits = np.array(all_input_logits).reshape(-1, 128)
    all_block_tokens = np.array(all_block_tokens).reshape(-1, args.seq_length)
    all_block_logits = np.array(all_block_logits).reshape(-1, 128)

    np.save('input_tokens.npy', all_input_tokens)
    np.save('input_logits.npy', all_input_logits)
    np.save('block_tokens.npy', all_block_tokens)
    np.save('doc_logits.npy', all_block_logits)
def load_checkpoint():
    args = get_args()
    model = get_model(model_provider)

    if isinstance(model, torchDDP):
        model = model.module

    tracker_filename = get_checkpoint_tracker_filename(args.load)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())
    assert iteration > 0

    checkpoint_name = get_checkpoint_name(args.load, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model
def get_dataset():
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.data_path + '-titles', 'mmap', True)

    doc_idx_ptr = block_dataset.get_doc_idx()
    total_num_documents = block_dataset.doc_idx.shape[0] - 1
    block_dataset.set_doc_idx(doc_idx_ptr[0:total_num_documents])

    # Keyword names follow the refactored InverseClozeDataset constructor below.
    kwargs = dict(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=None,
        max_num_samples=total_num_documents * 3,
        max_seq_length=288,  # doesn't matter
        short_seq_prob=0.0001,  # doesn't matter
        seed=1
    )
    dataset = InverseClozeDataset(**kwargs)
    return dataset
def get_dataloader(dataset):
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


if __name__ == "__main__":
    main()
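For a quick offline sanity check, the dumped arrays can be reloaded with NumPy and scored the same way the training code scores them. This is a minimal sketch, assuming the script above has already been run and the .npy files sit in the working directory; the top-1 check relies on query i and block i being written out in the same order:

import numpy as np

# Logits dumped by main() above: one 128-d row per query / block.
input_logits = np.load('input_logits.npy')
block_logits = np.load('doc_logits.npy')

# Same similarity as forward_step: [batch x h] * [h x batch].
retrieval_scores = input_logits @ block_logits.T

# Query i was paired with block i, so the diagonal should dominate.
top1 = retrieval_scores.argmax(axis=1)
print('top-1 retrieval accuracy:', (top1 == np.arange(len(top1))).mean())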
@@ -42,9 +42,9 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
skip_warmup)
if ict_dataset:
titles_dataset = get_indexed_dataset_(data_prefix + '-titles',
data_impl,
skip_warmup)
title_dataset = get_indexed_dataset_(data_prefix + '-titles',
data_impl,
skip_warmup)
# Get start and end indices of train/valid/test into doc-idx
# Note that doc-idx is designed to be num-docs + 1 so we can
@@ -54,6 +54,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
# Print stats about the splits.
print_rank_0(' > dataset split:')
def print_split_stats(name, index):
print_rank_0(' {}:'.format(name))
print_rank_0(' document indices in [{}, {}) total of {} '
@@ -82,7 +83,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
# Build the dataset accordingly.
kwargs = dict(
name=name,
context_dataset=indexed_dataset,
data_prefix=data_prefix,
num_epochs=None,
max_num_samples=train_valid_test_num_samples[index],
@@ -92,9 +92,17 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
)
if ict_dataset:
dataset = InverseClozeDataset(titles_dataset=titles_dataset, **kwargs)
dataset = InverseClozeDataset(
block_dataset=indexed_dataset,
title_dataset=title_dataset,
**kwargs
)
else:
dataset = BertDataset(masked_lm_prob=masked_lm_prob, **kwargs)
dataset = BertDataset(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
**kwargs
)
# Set the original pointer so dataset remains the main dataset.
indexed_dataset.set_doc_idx(doc_idx_ptr)
# Checks.
......
@@ -452,6 +452,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
// Current map index.
uint64_t map_index = 0;
int32_t block_id = 0;
// For each epoch:
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
@@ -514,14 +515,16 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
// Populate the map.
if (second) {
const auto map_index_0 = 3 * map_index;
const auto map_index_0 = 4 * map_index;
maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
}
// Update indices / counters.
++map_index;
++block_id;
prev_start_index = sent_index + 1;
seq_len = 0;
num_sent = 0;
@@ -529,6 +532,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
block_id = 0;
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) {
@@ -538,7 +542,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
}
assert(maps == NULL);
assert(num_samples < 0);
maps = new DocIdx[3*map_index];
maps = new DocIdx[4*map_index];
num_samples = static_cast<int64_t>(map_index);
}
@@ -550,12 +554,13 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
std::mt19937_64 rand64_gen(seed + 1);
for (auto i=(num_samples - 1); i > 0; --i) {
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
const auto i0 = 3 * i;
const auto j0 = 3 * j;
const auto i0 = 4 * i;
const auto j0 = 4 * j;
// Swap values.
swap(maps[i0], maps[j0]);
swap(maps[i0 + 1], maps[j0 + 1]);
swap(maps[i0 + 2], maps[j0 + 2]);
swap(maps[i0 + 3], maps[j0 + 3]);
}
// Method to deallocate memory.
@@ -566,8 +571,8 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
// Return the numpy array.
const auto byte_size = sizeof(DocIdx);
return py::array(std::vector<int64_t>{num_samples, 3}, // shape
{3*byte_size, byte_size}, // C-style contiguous strides
return py::array(std::vector<int64_t>{num_samples, 4}, // shape
{4*byte_size, byte_size}, // C-style contiguous strides
maps, // the data pointer
free_when_done); // numpy array references
......
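With the extra block id, each row of the mapping returned by build_blocks_mapping now has four columns instead of three: start sentence index, end (exclusive) sentence index, document index, and block id. A minimal sketch of how one row is consumed on the Python side (samples_mapping and the two datasets are assumed to be the objects used by InverseClozeDataset below):

# samples_mapping: the [num_samples, 4] array returned by helpers.build_blocks_mapping.
start_idx, end_idx, doc_idx, block_idx = samples_mapping[idx]
# Sentences [start_idx, end_idx) of document doc_idx make up block number block_idx.
block = [list(block_dataset[i]) for i in range(start_idx, end_idx)]
title = list(title_dataset[int(doc_idx)])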
import itertools
import random
import os
import sys
import time
import numpy as np
@@ -16,19 +15,19 @@ from megatron.data import helpers
class InverseClozeDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, context_dataset, titles_dataset, data_prefix,
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length,
short_seq_prob, seed):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
self.context_dataset = context_dataset
self.titles_dataset = titles_dataset
self.block_dataset = block_dataset
self.title_dataset = title_dataset
self.short_seq_prob = short_seq_prob
self.rng = random.Random(self.seed)
self.samples_mapping = get_samples_mapping(self.context_dataset,
self.titles_dataset,
self.samples_mapping = get_samples_mapping(self.block_dataset,
self.title_dataset,
data_prefix,
num_epochs,
max_num_samples,
@@ -47,38 +46,38 @@ class InverseClozeDataset(Dataset):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_idx, end_idx, doc_idx = self.samples_mapping[idx]
title = list(self.titles_dataset[int(doc_idx)])
context = [list(self.context_dataset[i]) for i in range(start_idx, end_idx)]
assert len(context) > 1
start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
title = list(self.title_dataset[int(doc_idx)])
block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
assert len(block) > 1
# avoid selecting the first or last sentence to be the query.
if len(context) == 2:
if len(block) == 2:
rand_sent_idx = int(self.rng.random() > 0.5)
else:
rand_sent_idx = self.rng.randint(1, len(context) - 2)
rand_sent_idx = self.rng.randint(1, len(block) - 2)
# keep the query in the context 10% of the time.
# keep the query in the block 10% of the time.
if self.rng.random() < 0.1:
input = context[rand_sent_idx].copy()
query = block[rand_sent_idx].copy()
else:
input = context.pop(rand_sent_idx)
query = block.pop(rand_sent_idx)
# may still need to truncate because blocks are concluded when
# still need to truncate because blocks are concluded when
# the sentence lengths have exceeded max_seq_length.
input = input[:self.max_seq_length - 2]
context = list(itertools.chain(*context))[:self.max_seq_length - (3 + len(title))]
query = query[:self.max_seq_length - 2]
block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(input)
context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(context, title)
query_tokens, query_token_types, query_pad_mask = self.concat_and_pad_tokens(query)
block_tokens, block_token_types, block_pad_mask = self.concat_and_pad_tokens(block, title)
sample = {
'input_text': np.array(input_tokens),
'input_types': np.array(input_token_types),
'input_pad_mask': np.array(input_pad_mask),
'context_text': np.array(context_tokens),
'context_types': np.array(context_token_types),
'context_pad_mask': np.array(context_pad_mask)
'query_tokens': np.array(query_tokens),
'query_types': np.array(query_token_types),
'query_pad_mask': np.array(query_pad_mask),
'block_tokens': np.array(block_tokens),
'block_types': np.array(block_token_types),
'block_pad_mask': np.array(block_pad_mask)
}
return sample
@@ -97,7 +96,7 @@ class InverseClozeDataset(Dataset):
return tokens, token_types, pad_mask
def get_samples_mapping(context_dataset,
def get_samples_mapping(block_dataset,
titles_dataset,
data_prefix,
num_epochs,
@@ -131,8 +130,8 @@ def get_samples_mapping(context_dataset,
'the indices on rank 0 ...'.format(indexmap_filename))
# Make sure the types match the helpers input types.
assert context_dataset.doc_idx.dtype == np.int64
assert context_dataset.sizes.dtype == np.int32
assert block_dataset.doc_idx.dtype == np.int64
assert block_dataset.sizes.dtype == np.int32
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
@@ -140,8 +139,8 @@ def get_samples_mapping(context_dataset,
print_rank_0(' > building samples index mapping for {} ...'.format(
name))
samples_mapping = helpers.build_blocks_mapping(
context_dataset.doc_idx,
context_dataset.sizes,
block_dataset.doc_idx,
block_dataset.sizes,
titles_dataset.sizes,
num_epochs,
max_num_samples,
......
@@ -918,10 +918,10 @@ class InverseClozeDataset(data.Dataset):
sample = {
'input_text': np.array(input_tokens),
'input_types': np.array(input_token_types),
'query_types': np.array(input_token_types),
'input_pad_mask': np.array(input_pad_mask),
'context_text': np.array(context_tokens),
'context_types': np.array(context_token_types),
'block_types': np.array(context_token_types),
'context_pad_mask': np.array(context_pad_mask)
}
......
@@ -215,6 +215,7 @@ class BertModel(MegatronModule):
class ICTBertModel(MegatronModule):
"""Bert-based module for Inverse Cloze task."""
def __init__(self,
ict_head_size,
num_tokentypes=0,
@@ -227,41 +228,38 @@ class ICTBertModel(MegatronModule):
parallel_output=parallel_output
)
self.question_model = BertModel(**bert_args)
self._question_key = 'question_model'
self.context_model = BertModel(**bert_args)
self._context_key = 'context_model'
# this model embeds (pseudo-)queries - Embed_input in the paper
self.query_model = BertModel(**bert_args)
self._query_key = 'question_model'
def forward(self, input_tokens, input_attention_mask, input_types,
context_tokens, context_attention_mask, context_types, return_logits=False):
# this model embeds evidence blocks - Embed_doc in the paper
self.block_model = BertModel(**bert_args)
self._block_key = 'context_model'
question_ict_logits, _ = self.question_model.forward(input_tokens, 1 - input_attention_mask, input_types)
context_ict_logits, _ = self.context_model.forward(context_tokens, 1 - context_attention_mask, context_types)
def forward(self, query_tokens, query_attention_mask, query_types,
block_tokens, block_attention_mask, block_types):
"""Run a forward pass for each of the models and compute the similarity scores."""
# [batch x h] * [h x batch]
retrieval_scores = question_ict_logits.matmul(torch.transpose(context_ict_logits, 0, 1))
if return_logits:
return question_ict_logits, context_ict_logits, retrieval_scores
return retrieval_scores
query_logits, _ = self.query_model.forward(query_tokens, 1 - query_attention_mask, query_types)
block_logits, _ = self.block_model.forward(block_tokens, 1 - block_attention_mask, block_types)
return query_logits, block_logits
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""Save dict with state dicts of each of the models."""
state_dict_ = {}
state_dict_[self._question_key] \
= self.question_model.state_dict_for_save_checkpoint(
state_dict_[self._query_key] \
= self.query_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._context_key] \
= self.context_model.state_dict_for_save_checkpoint(
state_dict_[self._block_key] \
= self.block_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
self.question_model.load_state_dict(
state_dict[self._question_key], strict=strict)
self.context_model.load_state_dict(
state_dict[self._context_key], strict=strict)
"""Load the state dicts of each of the models"""
self.query_model.load_state_dict(
state_dict[self._query_key], strict=strict)
self.block_model.load_state_dict(
state_dict[self._block_key], strict=strict)
@@ -262,19 +262,16 @@ def train_step(forward_step_func, data_iterator,
timers('forward').start()
loss, loss_reduced = forward_step_func(data_iterator, model)
timers('forward').stop()
torch.cuda.synchronize()
# Calculate gradients, reduce across processes, and clip.
timers('backward').start()
backward_step(optimizer, model, loss)
timers('backward').stop()
torch.cuda.synchronize()
# Update parameters.
timers('optimizer').start()
optimizer.step()
timers('optimizer').stop()
torch.cuda.synchronize()
# Update learning rate.
skipped_iter = 0
......
@@ -25,7 +25,6 @@ from megatron import print_rank_0
from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.model import ICTBertModel
from megatron.training import pretrain
from megatron.utils import make_data_loader
from megatron.utils import reduce_losses
num_batches = 0
@@ -46,8 +45,8 @@ def model_provider():
def get_batch(data_iterator):
# Items and their type.
keys = ['input_text', 'input_types', 'input_pad_mask',
'context_text', 'context_types', 'context_pad_mask']
keys = ['query_tokens', 'query_types', 'query_pad_mask',
'block_tokens', 'block_types', 'block_pad_mask']
datatype = torch.int64
# Broadcast data.
@@ -58,15 +57,15 @@ def get_batch(data_iterator):
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
input_tokens = data_b['input_text'].long()
input_types = data_b['input_types'].long()
input_pad_mask = data_b['input_pad_mask'].long()
context_tokens = data_b['context_text'].long()
context_types = data_b['context_types'].long()
context_pad_mask = data_b['context_pad_mask'].long()
query_tokens = data_b['query_tokens'].long()
query_types = data_b['query_types'].long()
query_pad_mask = data_b['query_pad_mask'].long()
block_tokens = data_b['block_tokens'].long()
block_types = data_b['block_types'].long()
block_pad_mask = data_b['block_pad_mask'].long()
return input_tokens, input_types, input_pad_mask,\
context_tokens, context_types, context_pad_mask
return query_tokens, query_types, query_pad_mask,\
block_tokens, block_types, block_pad_mask
def forward_step(data_iterator, model):
@@ -75,15 +74,18 @@ def forward_step(data_iterator, model):
# Get the batch.
timers('batch generator').start()
input_tokens, input_types, input_pad_mask,\
context_tokens, context_types, context_pad_mask = get_batch(data_iterator)
query_tokens, query_types, query_pad_mask,\
block_tokens, block_types, block_pad_mask = get_batch(data_iterator)
timers('batch generator').stop()
# Forward model.
retrieval_scores = model(input_tokens, input_pad_mask, input_types,
context_tokens, context_pad_mask, context_types).float()
query_logits, block_logits = model(query_tokens, query_pad_mask, query_types,
block_tokens, block_pad_mask, block_types)
# [batch x h] * [h x batch]
retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1)).float()
softmaxed = F.softmax(retrieval_scores, dim=1)
top5_vals, top5_indices = torch.topk(softmaxed, k=5, sorted=True)
batch_size = softmaxed.shape[0]
@@ -98,71 +100,29 @@ def forward_step(data_iterator, model):
'top5_acc': reduced_losses[2]}
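The middle of forward_step is collapsed in this hunk. For orientation only, here is a hedged sketch of what a standard in-batch ICT objective computes from retrieval_scores; the names retrieval_loss and top1_acc and the use of F.cross_entropy are assumptions for illustration, not the commit's actual code:

# Hypothetical sketch -- the real lines are elided from this diff.
# Block i is the positive example for query i, so the target for row i is i.
batch_size = retrieval_scores.shape[0]
labels = torch.arange(batch_size).long().cuda()
retrieval_loss = F.cross_entropy(retrieval_scores, labels)
top1_acc = (retrieval_scores.argmax(dim=1) == labels).float().mean()
# reduce_losses(...) would then produce the reduced_losses reported above.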
def get_train_val_test_data():
"""Load the data on rank zero and boradcast number of tokens to all GPUS."""
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid and test datasets."""
args = get_args()
print_rank_0('> building train, validation, and test datasets '
'for BERT ...')
(train_data, valid_data, test_data) = (None, None, None)
# Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0:
print_rank_0('> building train, validation, and test datasets '
'for BERT ...')
data_parallel_size = mpu.get_data_parallel_world_size()
data_parallel_rank = mpu.get_data_parallel_rank()
global_batch_size = args.batch_size * data_parallel_size
# Number of train/valid/test samples.
train_iters = args.train_iters
eval_iters = (train_iters // args.eval_iters + 1) * args.eval_iters
test_iters = args.eval_iters
train_val_test_num_samples = [train_iters * global_batch_size,
eval_iters * global_batch_size,
test_iters * global_batch_size]
print_rank_0(' > datasets target sizes (minimum size):')
print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
max_seq_length=args.seq_length,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
ict_dataset=True)
print_rank_0("> finished creating BERT ICT datasets ...")
train_data = make_data_loader(train_ds)
valid_data = make_data_loader(valid_ds)
test_data = make_data_loader(test_ds)
do_train = train_data is not None and args.train_iters > 0
do_valid = valid_data is not None and args.eval_iters > 0
do_test = test_data is not None and args.eval_iters > 0
# Need to broadcast num_tokens and num_type_tokens.
flags = torch.cuda.LongTensor(
[int(do_train), int(do_valid), int(do_test)])
else:
flags = torch.cuda.LongTensor([0, 0, 0])
# Broadcast num tokens.
torch.distributed.broadcast(flags,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
args.do_train = flags[0].item()
args.do_valid = flags[1].item()
args.do_test = flags[2].item()
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
max_seq_length=args.seq_length,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
ict_dataset=True)
print_rank_0("> finished creating BERT ICT datasets ...")
return train_data, valid_data, test_data
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(get_train_val_test_data, model_provider, forward_step,
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})