Commit b9bd1a11 authored by Neel Kant

Additional refactoring

parent ca0cdfaa
import os
import sys
import time
import torch
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args, get_adlr_autoresume, print_rank_0
from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset import ICTDataset
from megatron.data.realm_dataset_utils import BlockSampleData
from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
from megatron.data.ict_dataset import ICTDataset
from megatron.data.realm_index import detach, BlockData
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.training import get_model
from pretrain_bert_ict import get_batch, general_ict_model_provider
from pretrain_ict import get_batch, general_ict_model_provider
def pprint(*args):
@@ -30,17 +25,21 @@ class IndexBuilder(object):
self.model = None
self.dataloader = None
self.block_data = None
# need to know whether we're using a REALM checkpoint (args.load) or ICT checkpoint
assert not (args.load and args.ict_load)
self.using_realm_chkpt = args.ict_load is None
self.load_attributes()
self.is_main_builder = args.rank == 0
self.iteration = self.total_processed = 0
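# Hedged illustration (not part of the diff), assuming args.load / args.ict_load are set by the
# usual --load / --ict-load command-line flags: passing only --ict-load gives
# using_realm_chkpt == False, passing only --load gives using_realm_chkpt == True,
# and passing both trips the assertion above.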
def load_attributes(self):
"""Load the necessary attributes: model, dataloader and empty BlockData"""
# TODO: handle from_realm_chkpt correctly
self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=False)
self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
self.model.eval()
self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
self.block_data = BlockData()
self.block_data = BlockData(load_from_path=False)
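# Note (hedged reading of the change): the index builder starts from an empty BlockData
# (load_from_path=False) because it is about to embed every block itself, whereas consumers of a
# finished index keep the default load_from_path=True to unpickle previously saved embeddings.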
def track_and_report_progress(self, batch_size):
"""Utility function for tracking progress"""
@@ -141,7 +140,6 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1):
num_epochs=1,
max_num_samples=None,
max_seq_length=args.seq_length,
short_seq_prob=0.0001, # doesn't matter
seed=1,
query_in_block_prob=query_in_block_prob,
use_titles=use_titles,
......
@@ -417,7 +417,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
num_epochs=None,
max_num_samples=train_valid_test_num_samples[index],
max_seq_length=max_seq_length,
short_seq_prob=short_seq_prob,
seed=seed
)
@@ -434,6 +433,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
dataset = BertDataset(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
short_seq_prob=short_seq_prob,
**kwargs
)
......
import collections
import itertools
import random
import numpy as np
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron.data.realm_dataset_utils import BlockSampleData, get_block_samples_mapping, join_str_list
class ICTDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
short_seq_prob, seed, use_titles=True, use_one_sent_docs=False):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
self.query_in_block_prob = query_in_block_prob
self.block_dataset = block_dataset
self.title_dataset = title_dataset
self.short_seq_prob = short_seq_prob
self.rng = random.Random(self.seed)
self.use_titles = use_titles
self.use_one_sent_docs = use_one_sent_docs
self.samples_mapping = get_block_samples_mapping(
block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
self.tokenizer = get_tokenizer()
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
self.vocab_id_to_token_list = self.tokenizer.inv_vocab
self.cls_id = self.tokenizer.cls
self.sep_id = self.tokenizer.sep
self.mask_id = self.tokenizer.mask
self.pad_id = self.tokenizer.pad
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
"""Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
sample_data = self.samples_mapping[idx]
start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()
if self.use_titles:
title = self.title_dataset[int(doc_idx)]
title_pad_offset = 3 + len(title)
else:
title = None
title_pad_offset = 2
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1
# randint() is inclusive for Python rng
rand_sent_idx = self.rng.randint(0, len(block) - 1)
# keep the query in the block (its context) query_in_block_prob fraction of the time.
if self.rng.random() < self.query_in_block_prob:
query = block[rand_sent_idx].copy()
else:
query = block.pop(rand_sent_idx)
# still need to truncate because a block is only closed off once
# the cumulative sentence length has exceeded max_seq_length.
query = query[:self.max_seq_length - 2]
block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
block_data = sample_data.as_array()
sample = {
'query_tokens': query_tokens,
'query_pad_mask': query_pad_mask,
'block_tokens': block_tokens,
'block_pad_mask': block_pad_mask,
'block_data': block_data,
}
return sample
def get_block(self, start_idx, end_idx, doc_idx):
"""Get the IDs for an evidence block plus the title of the corresponding document"""
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
title = self.title_dataset[int(doc_idx)]
block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return block_tokens, block_pad_mask
def get_null_block(self):
"""Get empty block and title - used in REALM pretraining"""
block, title = [], []
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return block_tokens, block_pad_mask
def concat_and_pad_tokens(self, tokens, title=None):
"""Concat with special tokens and pad sequence to self.max_seq_length"""
tokens = list(tokens)
if title is None:
tokens = [self.cls_id] + tokens + [self.sep_id]
else:
title = list(title)
tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
assert len(tokens) <= self.max_seq_length
num_pad = self.max_seq_length - len(tokens)
pad_mask = [1] * len(tokens) + [0] * num_pad
tokens += [self.pad_id] * num_pad
return np.array(tokens), np.array(pad_mask)
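# Worked example (a sketch, not part of the diff), with hypothetical ids cls=101, sep=102, pad=0
# and max_seq_length=8:
#   title = [2023], tokens = [7592, 2088]
#   tokens   -> [101, 2023, 102, 7592, 2088, 102, 0, 0]   # [CLS] title [SEP] block [SEP] + padding
#   pad_mask -> [1,   1,    1,   1,    1,    1,   0, 0]
# This is why __getitem__ reserves title_pad_offset = 3 + len(title) positions: one [CLS],
# two [SEP]s, plus the title itself.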
from collections import defaultdict
import itertools
import os
import pickle
@@ -8,7 +7,7 @@ import faiss
import numpy as np
import torch
from megatron import get_args, mpu
from megatron import get_args
def detach(tensor):
@@ -17,7 +16,7 @@ def detach(tensor):
class BlockData(object):
"""Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM"""
def __init__(self, block_data_path=None, rank=None):
def __init__(self, block_data_path=None, load_from_path=True, rank=None):
self.embed_data = dict()
self.meta_data = dict()
if block_data_path is None:
@@ -27,6 +26,9 @@ class BlockData(object):
self.block_data_path = block_data_path
self.rank = rank
if load_from_path:
self.load_from_file()
block_data_name = os.path.splitext(self.block_data_path)[0]
self.temp_dir_name = block_data_name + '_tmp'
@@ -43,18 +45,23 @@ class BlockData(object):
"""
self.embed_data = dict()
@classmethod
def load_from_file(cls, fname):
def load_from_file(self):
"""Populate members from instance saved to file"""
print("\n> Unpickling BlockData", flush=True)
state_dict = pickle.load(open(fname, 'rb'))
state_dict = pickle.load(open(self.block_data_path, 'rb'))
print(">> Finished unpickling BlockData\n", flush=True)
new_index = cls()
new_index.embed_data = state_dict['embed_data']
new_index.meta_data = state_dict['meta_data']
return new_index
self.embed_data = state_dict['embed_data']
self.meta_data = state_dict['meta_data']
def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
"""Add data for set of blocks
:param block_indices: 1D array of unique int ids for the blocks
:param block_embeds: 2D array of embeddings of the blocks
:param block_metas: 2D array of metadata for the blocks.
In the case of REALM this will be [start_idx, end_idx, doc_idx]
"""
for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
if not allow_overwrite and idx in self.embed_data:
raise ValueError("Unexpectedly tried to overwrite block data")
@@ -63,6 +70,7 @@ class BlockData(object):
self.meta_data[idx] = meta
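# Illustrative usage (a hedged sketch, not part of the diff); shapes follow the docstring above:
#   indices = np.array([0, 1])                              # unique int ids for two blocks
#   embeds  = np.random.rand(2, 128).astype(np.float32)     # hypothetical embedding size of 128
#   metas   = np.array([[0, 5, 7], [5, 9, 7]])              # [start_idx, end_idx, doc_idx] per block
#   block_data.add_block_data(indices, embeds, metas)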
def save_shard(self):
"""Save the block data that was created this in this process"""
if not os.path.isdir(self.temp_dir_name):
os.makedirs(self.temp_dir_name, exist_ok=True)
@@ -104,9 +112,9 @@ class BlockData(object):
class FaissMIPSIndex(object):
"""Wrapper object for a BlockData which similarity search via FAISS under the hood"""
def __init__(self, index_type, embed_size, use_gpu=False):
self.index_type = index_type
def __init__(self, embed_size, block_data=None, use_gpu=False):
self.embed_size = embed_size
self.block_data = block_data
self.use_gpu = use_gpu
self.id_map = dict()
@@ -114,10 +122,7 @@ class FaissMIPSIndex(object):
self._set_block_index()
def _set_block_index(self):
INDEX_TYPES = ['flat_ip']
if self.index_type not in INDEX_TYPES:
raise ValueError("Invalid index type specified")
"""Create a Faiss Flat index with inner product as the metric to search against"""
print("\n> Building index", flush=True)
self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
@@ -129,29 +134,52 @@ class FaissMIPSIndex(object):
config.useFloat16 = True
self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
print(">>> Finished building index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True)
print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
else:
# CPU index supports IDs so wrap with IDMap
self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
print(">> Finished building index\n", flush=True)
print(">> Initialized index on CPU", flush=True)
# if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
if self.block_data is not None:
self.add_block_embed_data(self.block_data)
def reset_index(self):
"""Delete existing index and create anew"""
del self.block_mips_index
# reset the block data so that _set_block_index will reload it as well
if self.block_data is not None:
block_data_path = self.block_data.block_data_path
del self.block_data
self.block_data = BlockData.load_from_file(block_data_path)
self._set_block_index()
def add_block_embed_data(self, all_block_data):
"""Add the embedding of each block to the underlying FAISS index"""
# this assumes the embed_data is a dict : {int: np.array<float>}
block_indices, block_embeds = zip(*all_block_data.embed_data.items())
# the embeddings have to be entered in as float32 even though the math internally is done with float16.
block_embeds_arr = np.float32(np.array(block_embeds))
block_indices_arr = np.array(block_indices)
# faiss GpuIndex doesn't work with the IDMap wrapper, so store ids separately to map results back
if self.use_gpu:
for i, idx in enumerate(block_indices):
self.id_map[i] = idx
# we no longer need the embedding data since it's in the index now
all_block_data.clear()
if self.use_gpu:
self.block_mips_index.add(np.float32(np.array(block_embeds)))
self.block_mips_index.add(block_embeds_arr)
else:
self.block_mips_index.add_with_ids(np.float32(np.array(block_embeds)), np.array(block_indices))
self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)
print(">>> Finished adding block data to index", flush=True)
def search_mips_index(self, query_embeds, top_k, reconstruct=True):
"""Get the top-k blocks by the index distance metric.
@@ -160,12 +188,15 @@ class FaissMIPSIndex(object):
if False: return [num_queries x k] array of distances, and another for indices
"""
query_embeds = np.float32(detach(query_embeds))
with torch.no_grad():
if reconstruct:
# get the vectors themselves
top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
return top_k_block_embeds
else:
# get distances and indices of closest vectors
distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
if self.use_gpu:
fresh_indices = np.zeros(block_indices.shape)
......
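# End-to-end sketch of the refactored API (hedged illustration; the path and sizes are hypothetical):
#   block_data = BlockData(block_data_path='block_data.pkl')   # load_from_path=True unpickles it
#   mips_index = FaissMIPSIndex(embed_size=128, block_data=block_data, use_gpu=False)
#   # query_embeds: [num_queries x embed_size] tensor from the query encoder
#   top_k_block_embeds = mips_index.search_mips_index(query_embeds, top_k=5, reconstruct=True)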
@@ -31,11 +31,11 @@ def build_tokenizer(args):
# Select and instantiate the tokenizer.
assert args.vocab_file is not None
if args.tokenizer_type == 'BertWordPieceLowerCase':
tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
lower_case=True)
tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
lower_case=True)
elif args.tokenizer_type == 'BertWordPieceCase':
tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
lower_case=False)
tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
lower_case=False)
elif args.tokenizer_type == 'GPT2BPETokenizer':
assert args.merge_file is not None
tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
@@ -124,7 +124,7 @@ class AbstractTokenizer(ABC):
'tokenizer'.format(self.name))
class BertWordPieceTokenizer(AbstractTokenizer):
class _BertWordPieceTokenizer(AbstractTokenizer):
"""Original BERT wordpiece tokenizer."""
def __init__(self, vocab_file, lower_case=True):
......
@@ -74,11 +74,6 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
args = get_args()
timers = get_timers()
if args.rank == 0 and args.cased_data_path is not None:
import stanza
stanza.download('en', processors={'ner': 'conll03'}, dir='stanza')
# Model, optimizer, and learning rate.
timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
@@ -237,8 +232,6 @@ def setup_model_and_optimizer(model_provider_func):
def backward_step(optimizer, model, loss):
"""Backward step."""
# if args.rank == 0:
# torch.save(lick)
args = get_args()
timers = get_timers()
......