Commit 4ac91436 authored by Jared Casper

Merge branch 'indexer-merge' into 'master'

REALM indexer and index data structures

See merge request ADLR/megatron-lm!96
parents 07ebf714 33a666d4
...@@ -319,6 +319,31 @@ python pretrain_ict.py \
</pre>
### Building an Index of Block Embeddings
Once an ICT model has been trained, you can embed an entire dataset of blocks by creating a `BlockData` structure. After that data has been saved, you can load it
and wrap it with a `FaissMIPSIndex` to do the fast similarity search that is key to the learned information retrieval pipeline. The initial index can be built with the following script, which is meant to be run in an interactive session. It can leverage multiple GPUs on multiple nodes to index large datasets much more quickly. A short sketch of loading and querying the saved index follows the script.
<pre>
python tools/create_doc_index.py \
--num-layers 12 \
--hidden-size 768 \
--ict-head-size 128 \
--num-attention-heads 12 \
--batch-size 128 \
--checkpoint-activations \
--seq-length 256 \
--max-position-embeddings 256 \
--ict-load /path/to/pretrained_ict \
--data-path /path/to/indexed_dataset \
--titles-data-path /path/to/titles_indexed_dataset \
--block-data-path embedded_blocks.pkl \
--indexer-log-interval 1000 \
--indexer-batch-size 128 \
--vocab-file /path/to/vocab.txt \
--num-workers 2 \
--fp16
</pre>
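As a reference, the saved block data can then be loaded and queried from Python. The following is a minimal sketch (not one of the scripts in this merge request) that assumes the `BlockData` and `FaissMIPSIndex` classes added in `megatron/data/realm_index.py`, a 128-dimensional ICT head matching `--ict-head-size 128` above, and pre-computed query embeddings from the ICT query model:
<pre>
from megatron.data.realm_index import BlockData, FaissMIPSIndex

# load the pickled block embeddings produced by tools/create_doc_index.py
block_data = BlockData('embedded_blocks.pkl', load_from_path=True)

# wrap the block data in a FAISS inner-product index
# (embed_size matches --ict-head-size used during indexing)
mips_index = FaissMIPSIndex(embed_size=128, block_data=block_data, use_gpu=False)

# query_embeds: a torch tensor of shape [num_queries, 128] from the ICT query model
distances, block_indices = mips_index.search_mips_index(
    query_embeds, top_k=5, reconstruct=False)
</pre>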
<a id="evaluation-and-tasks"></a> <a id="evaluation-and-tasks"></a>
# Evaluation and Tasks # Evaluation and Tasks
......
...@@ -411,12 +411,23 @@ def _add_realm_args(parser):
help='Path to titles dataset used for ICT')
group.add_argument('--query-in-block-prob', type=float, default=0.1,
help='Probability of keeping query in block for ICT dataset')
group.add_argument('--use-one-sent-docs', action='store_true',
help='Whether to use one sentence documents in ICT')
# training
group.add_argument('--report-topk-accuracies', nargs='+', default=[],
help="Which top-k accuracies to report (e.g. '1 5 20')")
# faiss index
group.add_argument('--faiss-use-gpu', action='store_true',
help='Whether to create the FaissMIPSIndex on GPU')
group.add_argument('--block-data-path', type=str, default=None,
help='Where to save/load BlockData to/from')
# indexer
group.add_argument('--indexer-batch-size', type=int, default=128,
help='How large of batches to use when doing indexing jobs')
group.add_argument('--indexer-log-interval', type=int, default=1000,
help='After how many batches should the indexer report progress')
return parser
...@@ -21,9 +21,9 @@ import sys
import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP
from megatron import mpu, get_args
from megatron import get_args
from megatron import print_rank_0
...@@ -244,3 +244,43 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
print(' successfully loaded {}'.format(checkpoint_name))
return iteration
def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False):
"""selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
args = get_args()
if isinstance(model, torchDDP):
model = model.module
load_path = args.load if from_realm_chkpt else args.ict_load
tracker_filename = get_checkpoint_tracker_filename(load_path)
with open(tracker_filename, 'r') as f:
iteration = int(f.read().strip())
# assert iteration > 0
checkpoint_name = get_checkpoint_name(load_path, iteration, False)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
state_dict = torch.load(checkpoint_name, map_location='cpu')
ict_state_dict = state_dict['model']
if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
print(" loading ICT state dict from REALM", flush=True)
ict_state_dict = ict_state_dict['retriever']['ict_model']
if only_query_model:
ict_state_dict.pop('context_model')
if only_block_model:
ict_state_dict.pop('question_model')
model.load_state_dict(ict_state_dict)
torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
return model
...@@ -417,7 +417,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
num_epochs=None,
max_num_samples=train_valid_test_num_samples[index],
max_seq_length=max_seq_length,
short_seq_prob=short_seq_prob,
seed=seed
)
...@@ -427,13 +426,14 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
block_dataset=indexed_dataset,
title_dataset=title_dataset,
query_in_block_prob=args.query_in_block_prob,
use_one_sent_docs=args.use_one_sent_docs,
**kwargs
)
else:
dataset = BertDataset(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
short_seq_prob=short_seq_prob,
**kwargs
)
...
...@@ -5,21 +5,47 @@ import numpy as np
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron import get_args
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset_utils import get_block_samples_mapping
def get_ict_dataset(use_titles=True, query_in_block_prob=1):
"""Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
rather than for training, since it is only built with a single epoch sample mapping.
"""
args = get_args()
block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
kwargs = dict(
name='full',
block_dataset=block_dataset,
title_dataset=titles_dataset,
data_prefix=args.data_path,
num_epochs=1,
max_num_samples=None,
max_seq_length=args.seq_length,
seed=1,
query_in_block_prob=query_in_block_prob,
use_titles=use_titles,
use_one_sent_docs=args.use_one_sent_docs
)
dataset = ICTDataset(**kwargs)
return dataset
class ICTDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
seed, use_titles=True, use_one_sent_docs=False):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
self.query_in_block_prob = query_in_block_prob
self.block_dataset = block_dataset
self.title_dataset = title_dataset
self.short_seq_prob = short_seq_prob
self.rng = random.Random(self.seed)
self.use_titles = use_titles
self.use_one_sent_docs = use_one_sent_docs
...@@ -36,11 +62,13 @@ class ICTDataset(Dataset):
self.pad_id = self.tokenizer.pad
def __len__(self):
return len(self.samples_mapping)
def __getitem__(self, idx):
"""Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
sample_data = self.samples_mapping[idx]
start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()
if self.use_titles:
title = self.title_dataset[int(doc_idx)]
title_pad_offset = 3 + len(title)
...@@ -48,7 +76,7 @@ class ICTDataset(Dataset):
title = None
title_pad_offset = 2
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1
# randint() is inclusive for Python rng
rand_sent_idx = self.rng.randint(0, len(block) - 1)
...@@ -66,7 +94,7 @@ class ICTDataset(Dataset):
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
block_data = sample_data.as_array()
sample = {
'query_tokens': query_tokens,
...
...@@ -5,6 +5,58 @@ import numpy as np
import torch
from megatron import mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
from megatron.data.samplers import DistributedBatchSampler
from megatron import get_args, get_tokenizer, print_rank_0, mpu
def get_one_epoch_dataloader(dataset, batch_size=None):
"""Specifically one epoch to be used in an indexing job."""
args = get_args()
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
if batch_size is None:
batch_size = args.batch_size
global_batch_size = batch_size * world_size
num_workers = args.num_workers
sampler = torch.utils.data.SequentialSampler(dataset)
# importantly, drop_last must be False to get all the data.
batch_sampler = DistributedBatchSampler(sampler,
batch_size=global_batch_size,
drop_last=False,
rank=rank,
world_size=world_size)
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=True)
def get_ict_batch(data_iterator):
# Items and their type.
keys = ['query_tokens', 'query_pad_mask',
'block_tokens', 'block_pad_mask', 'block_data']
datatype = torch.int64
# Broadcast data.
if data_iterator is None:
data = None
else:
data = next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
query_tokens = data_b['query_tokens'].long()
query_pad_mask = data_b['query_pad_mask'].long()
block_tokens = data_b['block_tokens'].long()
block_pad_mask = data_b['block_pad_mask'].long()
block_indices = data_b['block_data'].long()
return query_tokens, query_pad_mask,\
block_tokens, block_pad_mask, block_indices
def join_str_list(str_list):
...@@ -18,10 +70,50 @@ def join_str_list(str_list):
return result
class BlockSampleData(object):
"""A struct for fully describing a fixed-size block of data as used in REALM
:param start_idx: index of the first sentence of the block
:param end_idx: index of the last sentence of the block (may be partially truncated in sample construction)
:param doc_idx: the index of the document from which the block comes in the original indexed dataset
:param block_idx: a unique integer identifier given to every block.
"""
def __init__(self, start_idx, end_idx, doc_idx, block_idx):
self.start_idx = start_idx
self.end_idx = end_idx
self.doc_idx = doc_idx
self.block_idx = block_idx
def as_array(self):
return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)
def as_tuple(self):
return self.start_idx, self.end_idx, self.doc_idx, self.block_idx
class BlockSamplesMapping(object):
def __init__(self, mapping_array):
# make sure that the array is compatible with BlockSampleData
assert mapping_array.shape[1] == 4
self.mapping_array = mapping_array
def __len__(self):
return self.mapping_array.shape[0]
def __getitem__(self, idx):
"""Get the data associated with an indexed sample."""
sample_data = BlockSampleData(*self.mapping_array[idx])
return sample_data
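# A minimal usage sketch (hypothetical values) of the structures above, assuming a
# pre-built (N, 4) mapping array with columns [start_idx, end_idx, doc_idx, block_idx]:
#
#     import numpy as np
#     mapping = BlockSamplesMapping(np.array([[0, 3, 7, 42]], dtype=np.int64))
#     sample = mapping[0]                                   # -> BlockSampleData
#     start_idx, end_idx, doc_idx, block_idx = sample.as_tuple()
#     row = sample.as_array()                               # int64 array of the same four fields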
def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
"""Get samples mapping for a dataset over fixed size blocks. This function also requires
a dataset of the titles for the source documents since their lengths must be taken into account.
:return: samples_mapping (BlockSamplesMapping)
"""
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
...@@ -58,27 +150,33 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
start_time = time.time()
print_rank_0(' > building samples index mapping for {} ...'.format(
name))
# compile/bind the C++ helper code
from megatron.data.dataset_utils import compile_helper
compile_helper()
from megatron.data import helpers
mapping_array = helpers.build_blocks_mapping(
block_dataset.doc_idx,
block_dataset.sizes,
title_dataset.sizes,
num_epochs,
max_num_samples,
max_seq_length - 3, # account for added tokens
seed,
verbose,
use_one_sent_docs)
print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, mapping_array, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
indexmap_filename))
# Make sure all the ranks have built the mapping
print_rank_0(' > elapsed time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
...@@ -91,10 +189,13 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
samples_mapping = BlockSamplesMapping(mapping_array)
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
mapping_array.shape[0]))
return samples_mapping
import itertools
import os
import pickle
import shutil
import numpy as np
import torch
from megatron import get_args
from megatron import mpu
def detach(tensor):
return tensor.detach().cpu().numpy()
class BlockData(object):
"""Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM"""
def __init__(self, block_data_path=None, load_from_path=True, rank=None):
self.embed_data = dict()
self.meta_data = dict()
if block_data_path is None:
args = get_args()
block_data_path = args.block_data_path
rank = args.rank
self.block_data_path = block_data_path
self.rank = rank
if load_from_path:
self.load_from_file()
block_data_name = os.path.splitext(self.block_data_path)[0]
self.temp_dir_name = block_data_name + '_tmp'
def state(self):
return {
'embed_data': self.embed_data,
'meta_data': self.meta_data,
}
def clear(self):
"""Clear the embedding data structures to save memory.
The metadata ends up getting used, and is also much smaller in dimensionality
so it isn't really worth clearing.
"""
self.embed_data = dict()
def load_from_file(self):
"""Populate members from instance saved to file"""
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Unpickling BlockData", flush=True)
state_dict = pickle.load(open(self.block_data_path, 'rb'))
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Finished unpickling BlockData\n", flush=True)
self.embed_data = state_dict['embed_data']
self.meta_data = state_dict['meta_data']
def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
"""Add data for set of blocks
:param block_indices: 1D array of unique int ids for the blocks
:param block_embeds: 2D array of embeddings of the blocks
:param block_metas: 2D array of metadata for the blocks.
In the case of REALM this will be [start_idx, end_idx, doc_idx]
"""
for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
if not allow_overwrite and idx in self.embed_data:
raise ValueError("Unexpectedly tried to overwrite block data")
self.embed_data[idx] = np.float16(embed)
self.meta_data[idx] = meta
def save_shard(self):
"""Save the block data that was created this in this process"""
if not os.path.isdir(self.temp_dir_name):
os.makedirs(self.temp_dir_name, exist_ok=True)
# save the data for each shard
with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as data_file:
pickle.dump(self.state(), data_file)
def merge_shards_and_save(self):
"""Combine all the shards made using self.save_shard()"""
shard_names = os.listdir(self.temp_dir_name)
seen_own_shard = False
for fname in os.listdir(self.temp_dir_name):
shard_rank = int(os.path.splitext(fname)[0])
if shard_rank == self.rank:
seen_own_shard = True
continue
with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
data = pickle.load(f)
old_size = len(self.embed_data)
shard_size = len(data['embed_data'])
# add the shard's data and check to make sure there is no overlap
self.embed_data.update(data['embed_data'])
self.meta_data.update(data['meta_data'])
assert len(self.embed_data) == old_size + shard_size
assert seen_own_shard
# save the consolidated shards and remove temporary directory
with open(self.block_data_path, 'wb') as final_file:
pickle.dump(self.state(), final_file)
shutil.rmtree(self.temp_dir_name, ignore_errors=True)
print("Finished merging {} shards for a total of {} embeds".format(
len(shard_names), len(self.embed_data)), flush=True)
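# A minimal sketch (hypothetical paths and data) of the shard workflow above: each rank
# accumulates embeddings with add_block_data() and writes its own shard, after which
# rank 0 merges every shard into the final pickle.
#
#     import numpy as np
#     block_data = BlockData(block_data_path='embedded_blocks.pkl',
#                            load_from_path=False, rank=0)
#     block_indices = np.array([42])             # unique block ids
#     block_embeds = np.random.randn(1, 128)     # one 128-dim block embedding
#     block_metas = np.array([[0, 3, 7]])        # [start_idx, end_idx, doc_idx]
#     block_data.add_block_data(block_indices, block_embeds, block_metas)
#     block_data.save_shard()                    # writes embedded_blocks_tmp/0.pkl
#     block_data.merge_shards_and_save()         # rank 0 only, in the real pipeline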
class FaissMIPSIndex(object):
"""Wrapper object for a BlockData which similarity search via FAISS under the hood"""
def __init__(self, embed_size, block_data=None, use_gpu=False):
self.embed_size = embed_size
self.block_data = block_data
self.use_gpu = use_gpu
self.id_map = dict()
self.block_mips_index = None
self._set_block_index()
def _set_block_index(self):
"""Create a Faiss Flat index with inner product as the metric to search against"""
try:
import faiss
except ImportError:
raise Exception("Error: Please install faiss to use FaissMIPSIndex")
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Building index", flush=True)
self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
if self.use_gpu:
# create resources and config for GpuIndex
res = faiss.StandardGpuResources()
config = faiss.GpuIndexFlatConfig()
config.device = torch.cuda.current_device()
config.useFloat16 = True
self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
else:
# CPU index supports IDs so wrap with IDMap
self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on CPU", flush=True)
# if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
if self.block_data is not None:
self.add_block_embed_data(self.block_data)
def reset_index(self):
"""Delete existing index and create anew"""
del self.block_mips_index
# reset the block data so that _set_block_index will reload it as well
if self.block_data is not None:
block_data_path = self.block_data.block_data_path
del self.block_data
self.block_data = BlockData(block_data_path)
self._set_block_index()
def add_block_embed_data(self, all_block_data):
"""Add the embedding of each block to the underlying FAISS index"""
# this assumes the embed_data is a dict : {int: np.array<float>}
block_indices, block_embeds = zip(*all_block_data.embed_data.items())
# the embeddings have to be entered in as float32 even though the math internally is done with float16.
block_embeds_arr = np.float32(np.array(block_embeds))
block_indices_arr = np.array(block_indices)
# faiss GpuIndex doesn't work with IDMap wrapper so store ids to map back with
if self.use_gpu:
for i, idx in enumerate(block_indices):
self.id_map[i] = idx
# we no longer need the embedding data since it's in the index now
all_block_data.clear()
if self.use_gpu:
self.block_mips_index.add(block_embeds_arr)
else:
self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">>> Finished adding block data to index", flush=True)
def search_mips_index(self, query_embeds, top_k, reconstruct=True):
"""Get the top-k blocks by the index distance metric.
:param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
if False: return [num_queries x k] array of distances, and another for indices
"""
query_embeds = np.float32(detach(query_embeds))
if reconstruct:
# get the vectors themselves
top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
return top_k_block_embeds
else:
# get distances and indices of closest vectors
distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
if self.use_gpu:
fresh_indices = np.zeros(block_indices.shape)
for i, j in itertools.product(range(block_indices.shape[0]), range(block_indices.shape[1])):
fresh_indices[i, j] = self.id_map[block_indices[i, j]]
block_indices = fresh_indices
return distances, block_indices
import torch
import torch.distributed as dist
from megatron import get_args
from megatron import mpu
from megatron.checkpointing import load_ict_checkpoint
from megatron.data.ict_dataset import get_ict_dataset
from megatron.data.realm_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, BlockData
from megatron.data.realm_dataset_utils import get_ict_batch
from megatron.model.realm_model import general_ict_model_provider
from megatron.training import get_model
class IndexBuilder(object):
"""Object for taking one pass over a dataset and creating a BlockData of its embeddings"""
def __init__(self):
args = get_args()
self.model = None
self.dataloader = None
self.block_data = None
# need to know whether we're using a REALM checkpoint (args.load) or ICT checkpoint
assert not (args.load and args.ict_load)
self.using_realm_chkpt = args.ict_load is None
self.log_interval = args.indexer_log_interval
self.batch_size = args.indexer_batch_size
self.load_attributes()
self.is_main_builder = mpu.get_data_parallel_rank() == 0
self.num_total_builders = mpu.get_data_parallel_world_size()
self.iteration = self.total_processed = 0
def load_attributes(self):
"""Load the necessary attributes: model, dataloader and empty BlockData"""
model = get_model(lambda: general_ict_model_provider(only_block_model=True))
self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
self.model.eval()
self.dataset = get_ict_dataset()
self.dataloader = iter(get_one_epoch_dataloader(self.dataset, self.batch_size))
self.block_data = BlockData(load_from_path=False)
def track_and_report_progress(self, batch_size):
"""Utility function for tracking progress"""
self.iteration += 1
self.total_processed += batch_size * self.num_total_builders
if self.is_main_builder and self.iteration % self.log_interval == 0:
print('Batch {:10d} | Total {:10d}'.format(self.iteration, self.total_processed), flush=True)
def build_and_save_index(self):
"""Goes through one epoch of the dataloader and adds all data to this instance's BlockData.
The copy of BlockData is saved as a shard, which when run in a distributed setting will be
consolidated by the rank 0 process and saved as a final pickled BlockData.
"""
while True:
try:
# batch also has query_tokens and query_pad_data
_, _, block_tokens, block_pad_mask, block_sample_data = get_ict_batch(self.dataloader)
except (StopIteration, IndexError):
break
unwrapped_model = self.model
while not hasattr(unwrapped_model, 'embed_block'):
unwrapped_model = unwrapped_model.module
# detach, separate fields and add to BlockData
block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))
detached_data = detach(block_sample_data)
# block_sample_data is a 2D array [batch x 4]
# with columns [start_idx, end_idx, doc_idx, block_idx] same as class BlockSampleData
block_indices = detached_data[:, 3]
block_metas = detached_data[:, :3]
self.block_data.add_block_data(block_indices, block_logits, block_metas)
self.track_and_report_progress(batch_size=block_tokens.shape[0])
# Each process saves its own shard and then synchronizes with the other processes
self.block_data.save_shard()
torch.distributed.barrier()
del self.model
# rank 0 process builds the final copy
if self.is_main_builder:
self.block_data.merge_shards_and_save()
# make sure that every single piece of data was embedded
assert len(self.block_data.embed_data) == len(self.dataset)
self.block_data.clear()
import os
import torch
from megatron import get_args, print_rank_0
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.model import BertModel
from megatron.module import MegatronModule
...@@ -13,6 +13,28 @@ from megatron.model.utils import scaled_init_method_normal
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
def general_ict_model_provider(only_query_model=False, only_block_model=False):
"""Build the model."""
args = get_args()
assert args.ict_head_size is not None, \
"Need to specify --ict-head-size to provide an ICTBertModel"
assert args.model_parallel_size == 1, \
"Model parallel size > 1 not supported for ICT"
print_rank_0('building ICTBertModel...')
# simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
model = ICTBertModel(
ict_head_size=args.ict_head_size,
num_tokentypes=2,
parallel_output=True,
only_query_model=only_query_model,
only_block_model=only_block_model)
return model
class ICTBertModel(MegatronModule):
"""Bert-based module for Inverse Cloze task."""
def __init__(self,
...
...@@ -21,6 +21,7 @@ from .data import broadcast_data
from .grads import clip_grad_norm
from .initialize import is_unitialized
from .initialize import destroy_model_parallel
from .initialize import get_data_parallel_group
from .initialize import get_data_parallel_rank
...
...@@ -31,6 +31,11 @@ _MPU_WORLD_SIZE = None
_MPU_RANK = None
def is_unitialized():
"""Useful for code segments that may be accessed with or without mpu initialization"""
return _DATA_PARALLEL_GROUP is None
def initialize_model_parallel(model_parallel_size_):
"""
Initialize model data parallel groups.
...
...@@ -24,37 +24,14 @@ from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import ICTBertModel
from megatron.training import pretrain
from megatron.utils import reduce_losses
from megatron.model.realm_model import general_ict_model_provider
from megatron.data.realm_dataset_utils import get_ict_batch
num_batches = 0
def pretrain_ict_model_provider():
return general_ict_model_provider(False, False)
"""Build the model."""
args = get_args()
assert args.ict_head_size is not None, \
"Need to specify --ict-head-size to provide an ICTBertModel"
assert args.model_parallel_size == 1, \
"Model parallel size > 1 not supported for ICT"
print_rank_0('building ICTBertModel...')
# simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
model = ICTBertModel(
ict_head_size=args.ict_head_size,
num_tokentypes=2,
parallel_output=True,
only_query_model=only_query_model,
only_block_model=only_block_model)
return model
def model_provider():
return general_model_provider(False, False)
def get_group_world_size_rank():
...@@ -95,30 +72,6 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
return output
def get_batch(data_iterator):
# Items and their type.
keys = ['query_tokens', 'query_pad_mask',
'block_tokens', 'block_pad_mask', 'block_data']
datatype = torch.int64
# Broadcast data.
if data_iterator is None:
data = None
else:
data = next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
query_tokens = data_b['query_tokens'].long()
query_pad_mask = data_b['query_pad_mask'].long()
block_tokens = data_b['block_tokens'].long()
block_pad_mask = data_b['block_pad_mask'].long()
block_indices = data_b['block_data'].long()
return query_tokens, query_pad_mask,\
block_tokens, block_pad_mask, block_indices
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
...@@ -127,7 +80,7 @@ def forward_step(data_iterator, model):
# Get the batch.
timers('batch generator').start()
query_tokens, query_pad_mask, \
block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
timers('batch generator').stop()
...@@ -181,5 +134,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, pretrain_ict_model_provider, forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
import sys
sys.path.append('../')
from megatron.indexer import IndexBuilder
from megatron.initialize import initialize_megatron
def main():
"""Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
- Include all args needed for initial model specification
Other key args:
--block-data-path: path to write to
--ict-load or --realm-load: path to checkpoint with which to embed
--data-path and --titles-data-path: paths for dataset
--indexer-log-interval: reporting interval
--indexer-batch-size: size specific for indexer jobs
Check README.md for example script
"""
initialize_megatron(extra_args_provider=None,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
index_builder = IndexBuilder()
index_builder.build_and_save_index()
if __name__ == "__main__":
main()