Commit b9bd1a11 authored by Neel Kant

Additional refactoring

parent ca0cdfaa
import os
import sys
import time
import torch
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

-from megatron import get_args, get_adlr_autoresume, print_rank_0
+from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.dataset_utils import get_indexed_dataset_
-from megatron.data.realm_dataset import ICTDataset
-from megatron.data.realm_dataset_utils import BlockSampleData
-from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
+from megatron.data.ict_dataset import ICTDataset
+from megatron.data.realm_index import detach, BlockData
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.training import get_model
-from pretrain_bert_ict import get_batch, general_ict_model_provider
+from pretrain_ict import get_batch, general_ict_model_provider
def pprint(*args):

@@ -30,17 +25,21 @@ class IndexBuilder(object):
        self.model = None
        self.dataloader = None
        self.block_data = None
+        # need to know whether we're using a REALM checkpoint (args.load) or ICT checkpoint
+        assert not (args.load and args.ict_load)
+        self.using_realm_chkpt = args.ict_load is None
        self.load_attributes()
        self.is_main_builder = args.rank == 0
        self.iteration = self.total_processed = 0

    def load_attributes(self):
        """Load the necessary attributes: model, dataloader and empty BlockData"""
-        # TODO: handle from_realm_chkpt correctly
-        self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=False)
+        self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
        self.model.eval()
        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
-        self.block_data = BlockData()
+        self.block_data = BlockData(load_from_path=False)

    def track_and_report_progress(self, batch_size):
        """Utility function for tracking progress"""

@@ -141,7 +140,6 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1):
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=args.seq_length,
-        short_seq_prob=0.0001,  # doesn't matter
        seed=1,
        query_in_block_prob=query_in_block_prob,
        use_titles=use_titles,
...
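
For orientation, a hedged driver sketch for this indexer, assumed to live in the same module as IndexBuilder. The build_and_save_index entry point and the argument-free IndexBuilder() construction are illustrative assumptions and do not appear in this diff.

from megatron.initialize import initialize_megatron


def main():
    # Parse args and set up the distributed environment before touching the model.
    initialize_megatron(extra_args_provider=None)

    # The constructor calls load_attributes(), which loads the block model,
    # a one-epoch dataloader, and an empty BlockData.
    builder = IndexBuilder()

    # Assumed entry point: embed every block on this rank and write the shard to disk,
    # reporting progress via track_and_report_progress().
    builder.build_and_save_index()


if __name__ == "__main__":
    main()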
@@ -417,7 +417,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
        num_epochs=None,
        max_num_samples=train_valid_test_num_samples[index],
        max_seq_length=max_seq_length,
-        short_seq_prob=short_seq_prob,
        seed=seed
    )

@@ -434,6 +433,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
    dataset = BertDataset(
        indexed_dataset=indexed_dataset,
        masked_lm_prob=masked_lm_prob,
+        short_seq_prob=short_seq_prob,
        **kwargs
    )
...
import collections
import itertools
import random
import numpy as np
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron.data.realm_dataset_utils import BlockSampleData, get_block_samples_mapping, join_str_list

class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
                 short_seq_prob, seed, use_titles=True, use_one_sent_docs=False):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.query_in_block_prob = query_in_block_prob
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)
        self.use_titles = use_titles
        self.use_one_sent_docs = use_one_sent_docs

        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
        sample_data = self.samples_mapping[idx]
        start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()

        if self.use_titles:
            title = self.title_dataset[int(doc_idx)]
            title_pad_offset = 3 + len(title)
        else:
            title = None
            title_pad_offset = 2
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1

        # randint() is inclusive for Python rng
        rand_sent_idx = self.rng.randint(0, len(block) - 1)

        # keep the query in the context query_in_block_prob fraction of the time.
        if self.rng.random() < self.query_in_block_prob:
            query = block[rand_sent_idx].copy()
        else:
            query = block.pop(rand_sent_idx)

        # still need to truncate because blocks are concluded when
        # the sentence lengths have exceeded max_seq_length.
        query = query[:self.max_seq_length - 2]
        block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
        block_data = sample_data.as_array()

        sample = {
            'query_tokens': query_tokens,
            'query_pad_mask': query_pad_mask,
            'block_tokens': block_tokens,
            'block_pad_mask': block_pad_mask,
            'block_data': block_data,
        }
        return sample

    def get_block(self, start_idx, end_idx, doc_idx):
        """Get the IDs for an evidence block plus the title of the corresponding document"""
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        title = self.title_dataset[int(doc_idx)]

        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def get_null_block(self):
        """Get empty block and title - used in REALM pretraining"""
        block, title = [], []
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def concat_and_pad_tokens(self, tokens, title=None):
        """Concat with special tokens and pad sequence to self.max_seq_length"""
        tokens = list(tokens)
        if title is None:
            tokens = [self.cls_id] + tokens + [self.sep_id]
        else:
            title = list(title)
            tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
        assert len(tokens) <= self.max_seq_length

        num_pad = self.max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad

        return np.array(tokens), np.array(pad_mask)
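
To make the layout in concat_and_pad_tokens concrete, here is a standalone sketch of the same logic with made-up token ids (cls=101, sep=102, pad=0) and a small max_seq_length of 12; none of these values come from the commit.

import numpy as np


def concat_and_pad(tokens, title=None, max_seq_length=12, cls_id=101, sep_id=102, pad_id=0):
    """Standalone mirror of ICTDataset.concat_and_pad_tokens with explicit arguments."""
    tokens = list(tokens)
    if title is None:
        # query layout: [CLS] tokens [SEP] (this is why __getitem__ truncates the query to max_seq_length - 2)
        tokens = [cls_id] + tokens + [sep_id]
    else:
        # block layout: [CLS] title [SEP] tokens [SEP] (this is why title_pad_offset = 3 + len(title))
        tokens = [cls_id] + list(title) + [sep_id] + tokens + [sep_id]
    assert len(tokens) <= max_seq_length

    num_pad = max_seq_length - len(tokens)
    pad_mask = [1] * len(tokens) + [0] * num_pad
    tokens += [pad_id] * num_pad
    return np.array(tokens), np.array(pad_mask)


# A four-token pseudo-query padded to length 12:
toks, mask = concat_and_pad([2023, 2003, 1037, 3231])
# toks -> [101 2023 2003 1037 3231 102   0   0   0   0   0   0]
# mask -> [  1    1    1    1    1   1   0   0   0   0   0   0]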
-from collections import defaultdict
import itertools
import os
import pickle

@@ -8,7 +7,7 @@ import faiss
import numpy as np
import torch

-from megatron import get_args, mpu
+from megatron import get_args


def detach(tensor):

@@ -17,7 +16,7 @@ def detach(tensor):
class BlockData(object):
    """Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM"""

-    def __init__(self, block_data_path=None, rank=None):
+    def __init__(self, block_data_path=None, load_from_path=True, rank=None):
        self.embed_data = dict()
        self.meta_data = dict()
        if block_data_path is None:

@@ -27,6 +26,9 @@ class BlockData(object):
        self.block_data_path = block_data_path
        self.rank = rank

+        if load_from_path:
+            self.load_from_file()

        block_data_name = os.path.splitext(self.block_data_path)[0]
        self.temp_dir_name = block_data_name + '_tmp'

@@ -43,18 +45,23 @@ class BlockData(object):
        """
        self.embed_data = dict()

-    @classmethod
-    def load_from_file(cls, fname):
+    def load_from_file(self):
+        """Populate members from instance saved to file"""
        print("\n> Unpickling BlockData", flush=True)
-        state_dict = pickle.load(open(fname, 'rb'))
+        state_dict = pickle.load(open(self.block_data_path, 'rb'))
        print(">> Finished unpickling BlockData\n", flush=True)
-        new_index = cls()
-        new_index.embed_data = state_dict['embed_data']
-        new_index.meta_data = state_dict['meta_data']
-        return new_index
+        self.embed_data = state_dict['embed_data']
+        self.meta_data = state_dict['meta_data']

    def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
+        """Add data for set of blocks
+        :param block_indices: 1D array of unique int ids for the blocks
+        :param block_embeds: 2D array of embeddings of the blocks
+        :param block_metas: 2D array of metadata for the blocks.
+            In the case of REALM this will be [start_idx, end_idx, doc_idx]
+        """
        for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
            if not allow_overwrite and idx in self.embed_data:
                raise ValueError("Unexpectedly tried to overwrite block data")

@@ -63,6 +70,7 @@ class BlockData(object):
            self.meta_data[idx] = meta

    def save_shard(self):
"""Save the block data that was created this in this process"""
        if not os.path.isdir(self.temp_dir_name):
            os.makedirs(self.temp_dir_name, exist_ok=True)
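
A minimal producer-side sketch of the new constructor flag, for context; the path, rank, embedding width, and example arrays below are placeholders rather than values from this commit. The flow is: build an empty store with load_from_path=False, add embeddings, then save this rank's shard.

import numpy as np

from megatron.data.realm_index import BlockData

# Placeholder path and rank; in the indexer these come from the parsed args.
block_data = BlockData('block_data.pkl', load_from_path=False, rank=0)

# Two fake blocks: unique int ids, float embeddings, and [start_idx, end_idx, doc_idx] metadata.
block_data.add_block_data(
    block_indices=np.array([0, 1]),
    block_embeds=np.random.rand(2, 128).astype(np.float32),
    block_metas=np.array([[0, 3, 7], [3, 5, 7]]),
)

# Writes this process's shard into the '<block_data_path minus extension>_tmp' directory.
block_data.save_shard()

The per-rank shards are presumably merged into a single block data file by the main builder afterwards; that code path is not shown in this diff.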
@@ -104,9 +112,9 @@ class BlockData(object):

class FaissMIPSIndex(object):
"""Wrapper object for a BlockData which similarity search via FAISS under the hood""" """Wrapper object for a BlockData which similarity search via FAISS under the hood"""
-    def __init__(self, index_type, embed_size, use_gpu=False):
-        self.index_type = index_type
+    def __init__(self, embed_size, block_data=None, use_gpu=False):
        self.embed_size = embed_size
+        self.block_data = block_data
        self.use_gpu = use_gpu
        self.id_map = dict()

@@ -114,10 +122,7 @@ class FaissMIPSIndex(object):
        self._set_block_index()

    def _set_block_index(self):
-        INDEX_TYPES = ['flat_ip']
-        if self.index_type not in INDEX_TYPES:
-            raise ValueError("Invalid index type specified")
+        """Create a Faiss Flat index with inner product as the metric to search against"""
        print("\n> Building index", flush=True)
        self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)

@@ -129,29 +134,52 @@ class FaissMIPSIndex(object):
            config.useFloat16 = True
            self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
-            print(">>> Finished building index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True)
+            print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
        else:
            # CPU index supports IDs so wrap with IDMap
            self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
-            print(">> Finished building index\n", flush=True)
+            print(">> Initialized index on CPU", flush=True)

+        # if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
+        if self.block_data is not None:
+            self.add_block_embed_data(self.block_data)

    def reset_index(self):
        """Delete existing index and create anew"""
        del self.block_mips_index
+        # reset the block data so that _set_block_index will reload it as well
+        if self.block_data is not None:
+            block_data_path = self.block_data.block_data_path
+            del self.block_data
+            self.block_data = BlockData(block_data_path)
        self._set_block_index()

    def add_block_embed_data(self, all_block_data):
        """Add the embedding of each block to the underlying FAISS index"""
+        # this assumes the embed_data is a dict : {int: np.array<float>}
        block_indices, block_embeds = zip(*all_block_data.embed_data.items())
+        # the embeddings have to be entered in as float32 even though the math internally is done with float16.
+        block_embeds_arr = np.float32(np.array(block_embeds))
+        block_indices_arr = np.array(block_indices)

+        # faiss GpuIndex doesn't work with IDMap wrapper so store ids to map back with
        if self.use_gpu:
            for i, idx in enumerate(block_indices):
                self.id_map[i] = idx

+        # we no longer need the embedding data since it's in the index now
        all_block_data.clear()

        if self.use_gpu:
-            self.block_mips_index.add(np.float32(np.array(block_embeds)))
+            self.block_mips_index.add(block_embeds_arr)
        else:
-            self.block_mips_index.add_with_ids(np.float32(np.array(block_embeds)), np.array(block_indices))
+            self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)

+        print(">>> Finished adding block data to index", flush=True)

    def search_mips_index(self, query_embeds, top_k, reconstruct=True):
        """Get the top-k blocks by the index distance metric.

@@ -160,12 +188,15 @@ class FaissMIPSIndex(object):
                 if False: return [num_queries x k] array of distances, and another for indices
        """
        query_embeds = np.float32(detach(query_embeds))

        with torch.no_grad():
            if reconstruct:
+                # get the vectors themselves
                top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
                return top_k_block_embeds
            else:
+                # get distances and indices of closest vectors
                distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
                if self.use_gpu:
                    fresh_indices = np.zeros(block_indices.shape)
...
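
Putting the refactored pieces together, a hedged consumer-side sketch; the path, embedding width, and query shapes are placeholders. A BlockData now loads itself from disk on construction, and passing it to FaissMIPSIndex ingests its embeddings as soon as the FAISS structure is built.

import torch

from megatron.data.realm_index import BlockData, FaissMIPSIndex

# Placeholder path and width; in practice these come from args.block_data_path and the model.
block_data = BlockData('block_data.pkl')            # load_from_path=True by default
mips_index = FaissMIPSIndex(embed_size=128,
                            block_data=block_data,  # added to the index inside _set_block_index()
                            use_gpu=False)

# Fake query embeddings of shape [num_queries x embed_size]; search_mips_index detaches
# and casts them to float32 before querying FAISS.
query_embeds = torch.randn(4, 128)
distances, block_indices = mips_index.search_mips_index(query_embeds, top_k=5, reconstruct=False)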
@@ -31,11 +31,11 @@ def build_tokenizer(args):
    # Select and instantiate the tokenizer.
    assert args.vocab_file is not None
    if args.tokenizer_type == 'BertWordPieceLowerCase':
-        tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
-                                           lower_case=True)
+        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
+                                            lower_case=True)
    elif args.tokenizer_type == 'BertWordPieceCase':
-        tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
-                                           lower_case=False)
+        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
+                                            lower_case=False)
    elif args.tokenizer_type == 'GPT2BPETokenizer':
        assert args.merge_file is not None
        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)

@@ -124,7 +124,7 @@ class AbstractTokenizer(ABC):
                                  'tokenizer'.format(self.name))


-class BertWordPieceTokenizer(AbstractTokenizer):
+class _BertWordPieceTokenizer(AbstractTokenizer):
    """Original BERT wordpiece tokenizer."""

    def __init__(self, vocab_file, lower_case=True):
...
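
With the wordpiece tokenizer class now module-private (note the leading underscore), callers are expected to go through the factory instead of instantiating it directly. A small hedged sketch of that usage; the sample sentence is arbitrary.

from megatron import get_tokenizer

# build_tokenizer(args) has already chosen _BertWordPieceTokenizer (or another type)
# based on args.tokenizer_type; downstream code sees only the shared tokenizer interface.
tokenizer = get_tokenizer()
token_ids = tokenizer.tokenize("an inverse cloze task example")

# Special-token ids consumed by ICTDataset above.
print(tokenizer.cls, tokenizer.sep, tokenizer.mask, tokenizer.pad)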
@@ -74,11 +74,6 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
    args = get_args()
    timers = get_timers()

-    if args.rank == 0 and args.cased_data_path is not None:
-        import stanza
-        stanza.download('en', processors={'ner': 'conll03'}, dir='stanza')
    # Model, optimizer, and learning rate.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)

@@ -237,8 +232,6 @@ def setup_model_and_optimizer(model_provider_func):
def backward_step(optimizer, model, loss):
    """Backward step."""
-    # if args.rank == 0:
-    #     torch.save(lick)
    args = get_args()
    timers = get_timers()
...