Commit 4abd7ce2 authored by Neel Kant

Refactor before merge request

parent f1ad8c94
import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.ict_dataset import InverseClozeDataset
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.training import get_model
from pretrain_bert_ict import get_batch, model_provider


def main():
    """Run a trained ICT model over 100 batches and dump tokens and logits to disk."""
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    args = get_args()

    model = load_checkpoint()
    model.eval()
    dataset = get_dataset()
    data_iter = iter(get_dataloader(dataset))

    all_input_tokens = []
    all_input_logits = []
    all_block_tokens = []
    all_block_logits = []
    for _ in range(100):
        input_tokens, input_types, input_pad_mask, \
            block_tokens, block_token_types, block_pad_mask = get_batch(data_iter)

        # The refactored ICTBertModel.forward takes (tokens, pad_mask, types)
        # for the query and the block, and returns the two sets of logits.
        input_logits, block_logits = model.module.module.forward(
            input_tokens, input_pad_mask, input_types,
            block_tokens, block_pad_mask, block_token_types)

        all_input_tokens.append(input_tokens.detach().cpu().numpy())
        all_input_logits.append(input_logits.detach().cpu().numpy())
        all_block_tokens.append(block_tokens.detach().cpu().numpy())
        all_block_logits.append(block_logits.detach().cpu().numpy())

    # Flatten the batch dimension; 128 is the width of the ICT retrieval head.
    all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
    all_input_logits = np.array(all_input_logits).reshape(-1, 128)
    all_block_tokens = np.array(all_block_tokens).reshape(-1, args.seq_length)
    all_block_logits = np.array(all_block_logits).reshape(-1, 128)

    np.save('input_tokens.npy', all_input_tokens)
    np.save('input_logits.npy', all_input_logits)
    np.save('block_tokens.npy', all_block_tokens)
    np.save('doc_logits.npy', all_block_logits)


def load_checkpoint():
    """Build the ICT model and load its weights from the latest checkpoint in args.load."""
    args = get_args()
    model = get_model(model_provider)

    if isinstance(model, torchDDP):
        model = model.module
    tracker_filename = get_checkpoint_tracker_filename(args.load)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())
    assert iteration > 0

    checkpoint_name = get_checkpoint_name(args.load, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))
    return model


def get_dataset():
    """Build an InverseClozeDataset spanning every document in the data file."""
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.data_path + '-titles', 'mmap', True)

    # Use every document: narrow the doc-idx pointer to the full range.
    doc_idx_ptr = block_dataset.get_doc_idx()
    total_num_documents = block_dataset.doc_idx.shape[0] - 1
    block_dataset.set_doc_idx(doc_idx_ptr[0:total_num_documents])

    # Keyword names follow the refactored InverseClozeDataset signature below.
    kwargs = dict(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=None,
        max_num_samples=total_num_documents * 3,
        max_seq_length=288,  # doesn't matter
        short_seq_prob=0.0001,  # doesn't matter
        seed=1
    )
    dataset = InverseClozeDataset(**kwargs)
    return dataset


def get_dataloader(dataset):
    """Create a sequential, distributed dataloader over the dataset."""
    args = get_args()
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


if __name__ == "__main__":
    main()
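For context: retrieval quality can be sanity-checked offline from the arrays this script dumps. The sketch below is not part of the commit; it assumes the i-th query row lines up with the i-th block row (true because the script appends whole batches in order) and that the run's batch size is known.

import numpy as np

# Hypothetical offline check of the dumped logits (batch_size is assumed).
input_logits = np.load('input_logits.npy')  # [num_samples, 128] query embeddings
block_logits = np.load('doc_logits.npy')    # [num_samples, 128] block embeddings

batch_size = 32  # assumption: use the --batch-size the script ran with
num_batches = input_logits.shape[0] // batch_size
correct = 0
for b in range(num_batches):
    q = input_logits[b * batch_size:(b + 1) * batch_size]
    d = block_logits[b * batch_size:(b + 1) * batch_size]
    scores = q @ d.T  # [batch x h] * [h x batch], as in forward_step
    correct += int((scores.argmax(axis=1) == np.arange(batch_size)).sum())
print('top-1 in-batch retrieval accuracy:', correct / (num_batches * batch_size))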
...
@@ -42,9 +42,9 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
                                           skip_warmup)
     if ict_dataset:
-        titles_dataset = get_indexed_dataset_(data_prefix + '-titles',
-                                              data_impl,
-                                              skip_warmup)
+        title_dataset = get_indexed_dataset_(data_prefix + '-titles',
+                                             data_impl,
+                                             skip_warmup)

     # Get start and end indices of train/valid/test into doc-idx
     # Note that doc-idx is designed to be num-docs + 1 so we can
@@ -54,6 +54,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     # Print stats about the splits.
     print_rank_0(' > dataset split:')
+
     def print_split_stats(name, index):
         print_rank_0('    {}:'.format(name))
         print_rank_0('     document indices in [{}, {}) total of {} '
@@ -82,7 +83,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
         # Build the dataset accordingly.
         kwargs = dict(
             name=name,
-            context_dataset=indexed_dataset,
             data_prefix=data_prefix,
             num_epochs=None,
             max_num_samples=train_valid_test_num_samples[index],
@@ -92,9 +92,17 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
         )
         if ict_dataset:
-            dataset = InverseClozeDataset(titles_dataset=titles_dataset, **kwargs)
+            dataset = InverseClozeDataset(
+                block_dataset=indexed_dataset,
+                title_dataset=title_dataset,
+                **kwargs
+            )
         else:
-            dataset = BertDataset(masked_lm_prob=masked_lm_prob, **kwargs)
+            dataset = BertDataset(
+                indexed_dataset=indexed_dataset,
+                masked_lm_prob=masked_lm_prob,
+                **kwargs
+            )

         # Set the original pointer so dataset remains the main dataset.
         indexed_dataset.set_doc_idx(doc_idx_ptr)

         # Checks.
...
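The "set the original pointer" context line above refers to the doc-idx save/narrow/restore pattern that both this builder and the evaluation script at the top of the commit rely on. A minimal sketch, assuming get_doc_idx()/set_doc_idx() behave as used there (build_split is a hypothetical name, not a function in the commit):

def build_split(indexed_dataset, start, end, **kwargs):
    # Save the original document-index pointer (length num-docs + 1).
    doc_idx_ptr = indexed_dataset.get_doc_idx()
    # Narrow the dataset to one split while the split dataset is built.
    indexed_dataset.set_doc_idx(doc_idx_ptr[start:end + 1])
    dataset = InverseClozeDataset(block_dataset=indexed_dataset, **kwargs)
    # Restore the pointer so the indexed dataset remains the main dataset.
    indexed_dataset.set_doc_idx(doc_idx_ptr)
    return dataset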
...
@@ -452,6 +452,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     // Current map index.
     uint64_t map_index = 0;
+    int32_t block_id = 0;

     // For each epoch:
     for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
@@ -514,14 +515,16 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
                     // Populate the map.
                     if (second) {
-                        const auto map_index_0 = 3 * map_index;
+                        const auto map_index_0 = 4 * map_index;
                         maps[map_index_0] = static_cast<DocIdx>(prev_start_index);
                         maps[map_index_0 + 1] = static_cast<DocIdx>(sent_index + 1);
                         maps[map_index_0 + 2] = static_cast<DocIdx>(doc);
+                        maps[map_index_0 + 3] = static_cast<DocIdx>(block_id);
                     }

                     // Update indices / counters.
                     ++map_index;
+                    ++block_id;
                     prev_start_index = sent_index + 1;
                     seq_len = 0;
                     num_sent = 0;
@@ -529,6 +532,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
             } // for (auto sent_index=sent_index_first; ...
         } // if (num_remain_sent > 1) {
     } // for (int doc=0; doc < num_docs; ++doc) {
+    block_id = 0;
 } // for (int epoch=0; epoch < num_epochs; ++epoch) {

 if (!second) {
@@ -538,7 +542,7 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     }
     assert(maps == NULL);
     assert(num_samples < 0);
-    maps = new DocIdx[3*map_index];
+    maps = new DocIdx[4*map_index];
     num_samples = static_cast<int64_t>(map_index);
 }
@@ -550,12 +554,13 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     std::mt19937_64 rand64_gen(seed + 1);
     for (auto i=(num_samples - 1); i > 0; --i) {
         const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
-        const auto i0 = 3 * i;
-        const auto j0 = 3 * j;
+        const auto i0 = 4 * i;
+        const auto j0 = 4 * j;
         // Swap values.
         swap(maps[i0], maps[j0]);
         swap(maps[i0 + 1], maps[j0 + 1]);
         swap(maps[i0 + 2], maps[j0 + 2]);
+        swap(maps[i0 + 3], maps[j0 + 3]);
     }

     // Method to deallocate memory.
@@ -566,8 +571,8 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
     // Return the numpy array.
     const auto byte_size = sizeof(DocIdx);
-    return py::array(std::vector<int64_t>{num_samples, 3}, // shape
-                     {3*byte_size, byte_size}, // C-style contiguous strides
+    return py::array(std::vector<int64_t>{num_samples, 4}, // shape
+                     {4*byte_size, byte_size}, // C-style contiguous strides
                      maps, // the data pointer
                      free_when_done); // numpy array references
...
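The net effect of this hunk set is that each row of the samples mapping widens from three to four entries: the new block_id column assigns every block a sequential id that resets at each epoch boundary. A sketch of the resulting layout, with illustrative values only, matching the (num_samples, 4) array returned by build_blocks_mapping:

import numpy as np

# Illustrative rows of the widened mapping ([start, end, doc, block_id]).
samples_mapping = np.array([[0, 3, 0, 0],
                            [3, 7, 0, 1],
                            [7, 9, 1, 2]], dtype=np.int64)

start_idx, end_idx, doc_idx, block_idx = samples_mapping[1]
# start/end: sentence range of the block; doc: source document (also indexes
# the titles dataset); block_id: per-epoch id, reset to 0 each epoch.
print(start_idx, end_idx, doc_idx, block_idx)  # 3 7 0 1

This is the four-tuple that InverseClozeDataset.__getitem__ unpacks in the next file.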
 import itertools
 import random
 import os
-import sys
 import time

 import numpy as np
...
@@ -16,19 +15,19 @@ from megatron.data import helpers

 class InverseClozeDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
-    def __init__(self, name, context_dataset, titles_dataset, data_prefix,
+    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                  num_epochs, max_num_samples, max_seq_length,
                  short_seq_prob, seed):
         self.name = name
         self.seed = seed
         self.max_seq_length = max_seq_length
-        self.context_dataset = context_dataset
-        self.titles_dataset = titles_dataset
+        self.block_dataset = block_dataset
+        self.title_dataset = title_dataset
         self.short_seq_prob = short_seq_prob
         self.rng = random.Random(self.seed)
-        self.samples_mapping = get_samples_mapping(self.context_dataset,
-                                                   self.titles_dataset,
+        self.samples_mapping = get_samples_mapping(self.block_dataset,
+                                                   self.title_dataset,
                                                    data_prefix,
                                                    num_epochs,
                                                    max_num_samples,
@@ -47,38 +46,38 @@ class InverseClozeDataset(Dataset):
         return self.samples_mapping.shape[0]

     def __getitem__(self, idx):
-        start_idx, end_idx, doc_idx = self.samples_mapping[idx]
-        title = list(self.titles_dataset[int(doc_idx)])
-        context = [list(self.context_dataset[i]) for i in range(start_idx, end_idx)]
-        assert len(context) > 1
+        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
+        title = list(self.title_dataset[int(doc_idx)])
+        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
+        assert len(block) > 1

         # avoid selecting the first or last sentence to be the query.
-        if len(context) == 2:
+        if len(block) == 2:
             rand_sent_idx = int(self.rng.random() > 0.5)
         else:
-            rand_sent_idx = self.rng.randint(1, len(context) - 2)
+            rand_sent_idx = self.rng.randint(1, len(block) - 2)

-        # keep the query in the context 10% of the time.
+        # keep the query in the block 10% of the time.
         if self.rng.random() < 0.1:
-            input = context[rand_sent_idx].copy()
+            query = block[rand_sent_idx].copy()
         else:
-            input = context.pop(rand_sent_idx)
+            query = block.pop(rand_sent_idx)

-        # may still need to truncate because blocks are concluded when
+        # still need to truncate because blocks are concluded when
         # the sentence lengths have exceeded max_seq_length.
-        input = input[:self.max_seq_length - 2]
-        context = list(itertools.chain(*context))[:self.max_seq_length - (3 + len(title))]
+        query = query[:self.max_seq_length - 2]
+        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]

-        input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(input)
-        context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(context, title)
+        query_tokens, query_token_types, query_pad_mask = self.concat_and_pad_tokens(query)
+        block_tokens, block_token_types, block_pad_mask = self.concat_and_pad_tokens(block, title)

         sample = {
-            'input_text': np.array(input_tokens),
-            'input_types': np.array(input_token_types),
-            'input_pad_mask': np.array(input_pad_mask),
-            'context_text': np.array(context_tokens),
-            'context_types': np.array(context_token_types),
-            'context_pad_mask': np.array(context_pad_mask)
+            'query_tokens': np.array(query_tokens),
+            'query_types': np.array(query_token_types),
+            'query_pad_mask': np.array(query_pad_mask),
+            'block_tokens': np.array(block_tokens),
+            'block_types': np.array(block_token_types),
+            'block_pad_mask': np.array(block_pad_mask)
         }
         return sample
@@ -97,7 +96,7 @@ class InverseClozeDataset(Dataset):
         return tokens, token_types, pad_mask


-def get_samples_mapping(context_dataset,
+def get_samples_mapping(block_dataset,
                         titles_dataset,
                         data_prefix,
                         num_epochs,
@@ -131,8 +130,8 @@ def get_samples_mapping(context_dataset,
                 'the indices on rank 0 ...'.format(indexmap_filename))

         # Make sure the types match the helpers input types.
-        assert context_dataset.doc_idx.dtype == np.int64
-        assert context_dataset.sizes.dtype == np.int32
+        assert block_dataset.doc_idx.dtype == np.int64
+        assert block_dataset.sizes.dtype == np.int32

         # Build samples mapping
         verbose = torch.distributed.get_rank() == 0
@@ -140,8 +139,8 @@ def get_samples_mapping(context_dataset,
         print_rank_0(' > building samples index mapping for {} ...'.format(
             name))
         samples_mapping = helpers.build_blocks_mapping(
-            context_dataset.doc_idx,
-            context_dataset.sizes,
+            block_dataset.doc_idx,
+            block_dataset.sizes,
             titles_dataset.sizes,
             num_epochs,
             max_num_samples,
...
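concat_and_pad_tokens itself is outside this diff. Below is a hypothetical sketch consistent with the truncation margins above (max_seq_length - 2 for the query, max_seq_length - (3 + len(title)) for the block); the token ids and the exact layout are assumptions, not code from the commit:

def concat_and_pad_tokens(tokens, title=None, max_seq_length=288,
                          cls_id=101, sep_id=102, pad_id=0):
    # Assumed layout: query -> [CLS] tokens [SEP];
    #                 block -> [CLS] title [SEP] tokens [SEP],
    # which is where the 2 and (3 + len(title)) margins would come from.
    if title is None:
        tokens = [cls_id] + tokens + [sep_id]
    else:
        tokens = [cls_id] + title + [sep_id] + tokens + [sep_id]
    num_pad = max_seq_length - len(tokens)
    # 1 marks padding; the model applies (1 - pad_mask) as the attention mask.
    pad_mask = [0] * len(tokens) + [1] * num_pad
    tokens = tokens + [pad_id] * num_pad
    token_types = [0] * max_seq_length
    return tokens, token_types, pad_mask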
...
@@ -918,10 +918,10 @@ class InverseClozeDataset(data.Dataset):
         sample = {
             'input_text': np.array(input_tokens),
-            'input_types': np.array(input_token_types),
+            'query_types': np.array(input_token_types),
             'input_pad_mask': np.array(input_pad_mask),
             'context_text': np.array(context_tokens),
-            'context_types': np.array(context_token_types),
+            'block_types': np.array(context_token_types),
             'context_pad_mask': np.array(context_pad_mask)
         }
...
...
@@ -215,6 +215,7 @@ class BertModel(MegatronModule):

 class ICTBertModel(MegatronModule):
+    """Bert-based module for Inverse Cloze task."""

     def __init__(self,
                  ict_head_size,
                  num_tokentypes=0,
@@ -227,41 +228,38 @@ class ICTBertModel(MegatronModule):
             parallel_output=parallel_output
         )

-        self.question_model = BertModel(**bert_args)
-        self._question_key = 'question_model'
-        self.context_model = BertModel(**bert_args)
-        self._context_key = 'context_model'
+        # this model embeds (pseudo-)queries - Embed_input in the paper
+        self.query_model = BertModel(**bert_args)
+        self._query_key = 'question_model'
+
+        # this model embeds evidence blocks - Embed_doc in the paper
+        self.block_model = BertModel(**bert_args)
+        self._block_key = 'context_model'

-    def forward(self, input_tokens, input_attention_mask, input_types,
-                context_tokens, context_attention_mask, context_types, return_logits=False):
-
-        question_ict_logits, _ = self.question_model.forward(input_tokens, 1 - input_attention_mask, input_types)
-        context_ict_logits, _ = self.context_model.forward(context_tokens, 1 - context_attention_mask, context_types)
-
-        # [batch x h] * [h x batch]
-        retrieval_scores = question_ict_logits.matmul(torch.transpose(context_ict_logits, 0, 1))
-        if return_logits:
-            return question_ict_logits, context_ict_logits, retrieval_scores
-        return retrieval_scores
+    def forward(self, query_tokens, query_attention_mask, query_types,
+                block_tokens, block_attention_mask, block_types):
+        """Run a forward pass for each of the models and return the query and block logits."""
+        query_logits, _ = self.query_model.forward(query_tokens, 1 - query_attention_mask, query_types)
+        block_logits, _ = self.block_model.forward(block_tokens, 1 - block_attention_mask, block_types)
+
+        return query_logits, block_logits

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
+        """Save dict with state dicts of each of the models."""
         state_dict_ = {}
-        state_dict_[self._question_key] \
-            = self.question_model.state_dict_for_save_checkpoint(
+        state_dict_[self._query_key] \
+            = self.query_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        state_dict_[self._context_key] \
-            = self.context_model.state_dict_for_save_checkpoint(
+        state_dict_[self._block_key] \
+            = self.block_model.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
         return state_dict_

     def load_state_dict(self, state_dict, strict=True):
-        """Customized load."""
-        self.question_model.load_state_dict(
-            state_dict[self._question_key], strict=strict)
-        self.context_model.load_state_dict(
-            state_dict[self._context_key], strict=strict)
+        """Load the state dicts of each of the models."""
+        self.query_model.load_state_dict(
+            state_dict[self._query_key], strict=strict)
+        self.block_model.load_state_dict(
+            state_dict[self._block_key], strict=strict)
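With forward() now returning the two embedding sets instead of the score matrix, the caller owns the scoring. A sketch of the usual in-batch ICT objective (the cross-entropy loss is an assumption; this commit's forward_step only computes scores and accuracies): every query is scored against every block in the batch, and the matching block sits on the diagonal.

import torch
import torch.nn.functional as F

batch_size, ict_head_size = 4, 128
query_logits = torch.randn(batch_size, ict_head_size)
block_logits = torch.randn(batch_size, ict_head_size)

# [batch x h] * [h x batch]: row i holds query i's score against every block.
retrieval_scores = query_logits.matmul(block_logits.t())
labels = torch.arange(batch_size)  # the correct block shares the row index
loss = F.cross_entropy(retrieval_scores, labels)
top1_acc = (retrieval_scores.argmax(dim=1) == labels).float().mean()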
...
@@ -262,19 +262,16 @@ def train_step(forward_step_func, data_iterator,
     timers('forward').start()
     loss, loss_reduced = forward_step_func(data_iterator, model)
     timers('forward').stop()
-    torch.cuda.synchronize()

     # Calculate gradients, reduce across processes, and clip.
     timers('backward').start()
     backward_step(optimizer, model, loss)
     timers('backward').stop()
-    torch.cuda.synchronize()

     # Update parameters.
     timers('optimizer').start()
     optimizer.step()
     timers('optimizer').stop()
-    torch.cuda.synchronize()

     # Update learning rate.
     skipped_iter = 0
...
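Dropping the explicit torch.cuda.synchronize() calls is safe provided the timers synchronize before reading the wall clock; otherwise the reported times would only measure kernel-launch overhead. A minimal sketch of that assumption (not Megatron's actual Timers code):

import time
import torch

class SyncTimer:
    # Synchronizes internally, so callers need no explicit torch.cuda.synchronize().
    def start(self):
        torch.cuda.synchronize()  # drain queued kernels before timestamping
        self._start = time.time()

    def stop(self):
        torch.cuda.synchronize()  # ensure the timed kernels have finished
        return time.time() - self._start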
...
@@ -25,7 +25,6 @@ from megatron import print_rank_0
 from megatron.data.bert_dataset import build_train_valid_test_datasets
 from megatron.model import ICTBertModel
 from megatron.training import pretrain
-from megatron.utils import make_data_loader
 from megatron.utils import reduce_losses

 num_batches = 0
@@ -46,8 +45,8 @@ def model_provider():

 def get_batch(data_iterator):
     # Items and their type.
-    keys = ['input_text', 'input_types', 'input_pad_mask',
-            'context_text', 'context_types', 'context_pad_mask']
+    keys = ['query_tokens', 'query_types', 'query_pad_mask',
+            'block_tokens', 'block_types', 'block_pad_mask']
     datatype = torch.int64

     # Broadcast data.
@@ -58,15 +57,15 @@ def get_batch(data_iterator):
     data_b = mpu.broadcast_data(keys, data, datatype)

     # Unpack.
-    input_tokens = data_b['input_text'].long()
-    input_types = data_b['input_types'].long()
-    input_pad_mask = data_b['input_pad_mask'].long()
-    context_tokens = data_b['context_text'].long()
-    context_types = data_b['context_types'].long()
-    context_pad_mask = data_b['context_pad_mask'].long()
+    query_tokens = data_b['query_tokens'].long()
+    query_types = data_b['query_types'].long()
+    query_pad_mask = data_b['query_pad_mask'].long()
+    block_tokens = data_b['block_tokens'].long()
+    block_types = data_b['block_types'].long()
+    block_pad_mask = data_b['block_pad_mask'].long()

-    return input_tokens, input_types, input_pad_mask,\
-           context_tokens, context_types, context_pad_mask
+    return query_tokens, query_types, query_pad_mask,\
+           block_tokens, block_types, block_pad_mask


 def forward_step(data_iterator, model):
@@ -75,15 +74,18 @@ def forward_step(data_iterator, model):
     # Get the batch.
     timers('batch generator').start()
-    input_tokens, input_types, input_pad_mask,\
-    context_tokens, context_types, context_pad_mask = get_batch(data_iterator)
+    query_tokens, query_types, query_pad_mask,\
+    block_tokens, block_types, block_pad_mask = get_batch(data_iterator)
     timers('batch generator').stop()

     # Forward model.
-    retrieval_scores = model(input_tokens, input_pad_mask, input_types,
-                             context_tokens, context_pad_mask, context_types).float()
+    query_logits, block_logits = model(query_tokens, query_pad_mask, query_types,
+                                       block_tokens, block_pad_mask, block_types)
+
+    # [batch x h] * [h x batch]
+    retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1)).float()

     softmaxed = F.softmax(retrieval_scores, dim=1)
     top5_vals, top5_indices = torch.topk(softmaxed, k=5, sorted=True)
     batch_size = softmaxed.shape[0]
@@ -98,71 +100,29 @@ def forward_step(data_iterator, model):
             'top5_acc': reduced_losses[2]}


-def get_train_val_test_data():
-    """Load the data on rank zero and broadcast number of tokens to all GPUs."""
+def train_valid_test_datasets_provider(train_val_test_num_samples):
+    """Build train, valid and test datasets."""
     args = get_args()
+    print_rank_0('> building train, validation, and test datasets '
+                 'for BERT ...')

-    (train_data, valid_data, test_data) = (None, None, None)
-
-    # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0:
-        print_rank_0('> building train, validation, and test datasets '
-                     'for BERT ...')
-
-        data_parallel_size = mpu.get_data_parallel_world_size()
-        data_parallel_rank = mpu.get_data_parallel_rank()
-        global_batch_size = args.batch_size * data_parallel_size
-
-        # Number of train/valid/test samples.
-        train_iters = args.train_iters
-        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
-        test_iters = args.eval_iters
-        train_val_test_num_samples = [train_iters * global_batch_size,
-                                      eval_iters * global_batch_size,
-                                      test_iters * global_batch_size]
-        print_rank_0(' > datasets target sizes (minimum size):')
-        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
-        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
-        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
-
-        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
-            data_prefix=args.data_path,
-            data_impl=args.data_impl,
-            splits_string=args.split,
-            train_valid_test_num_samples=train_val_test_num_samples,
-            max_seq_length=args.seq_length,
-            masked_lm_prob=args.mask_prob,
-            short_seq_prob=args.short_seq_prob,
-            seed=args.seed,
-            skip_warmup=(not args.mmap_warmup),
-            ict_dataset=True)
-        print_rank_0("> finished creating BERT ICT datasets ...")
-
-        train_data = make_data_loader(train_ds)
-        valid_data = make_data_loader(valid_ds)
-        test_data = make_data_loader(test_ds)
-
-        do_train = train_data is not None and args.train_iters > 0
-        do_valid = valid_data is not None and args.eval_iters > 0
-        do_test = test_data is not None and args.eval_iters > 0
-
-        # Need to broadcast num_tokens and num_type_tokens.
-        flags = torch.cuda.LongTensor(
-            [int(do_train), int(do_valid), int(do_test)])
-    else:
-        flags = torch.cuda.LongTensor([0, 0, 0])
-
-    # Broadcast num tokens.
-    torch.distributed.broadcast(flags,
-                                mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    args.do_train = flags[0].item()
-    args.do_valid = flags[1].item()
-    args.do_test = flags[2].item()
+    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+        data_prefix=args.data_path,
+        data_impl=args.data_impl,
+        splits_string=args.split,
+        train_valid_test_num_samples=train_val_test_num_samples,
+        max_seq_length=args.seq_length,
+        masked_lm_prob=args.mask_prob,
+        short_seq_prob=args.short_seq_prob,
+        seed=args.seed,
+        skip_warmup=(not args.mmap_warmup),
+        ict_dataset=True)
+    print_rank_0("> finished creating BERT ICT datasets ...")

-    return train_data, valid_data, test_data
+    return train_ds, valid_ds, test_ds


 if __name__ == "__main__":
-    pretrain(get_train_val_test_data, model_provider, forward_step,
+    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
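The refactor moves the sample-count arithmetic, dataloader construction, and the do_train/do_valid/do_test broadcast out of this script and into the shared pretrain() path; the provider now only builds datasets from the sample targets it is handed. A sketch of that arithmetic, reconstructed from the deleted lines above (the helper name is hypothetical):

def compute_train_val_test_num_samples(train_iters, eval_interval, eval_iters,
                                       global_batch_size):
    # Same arithmetic as the deleted rank-0 block above.
    eval_total = (train_iters // eval_interval + 1) * eval_iters
    return [train_iters * global_batch_size,
            eval_total * global_batch_size,
            eval_iters * global_batch_size]

print(compute_train_val_test_num_samples(
    train_iters=1000, eval_interval=100, eval_iters=10, global_batch_size=32))
# [32000, 3520, 320]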