Commit 11f76cd3 authored by Neel Kant's avatar Neel Kant

Address comments from last week

parent 7a348580
@@ -411,7 +411,7 @@ def _add_realm_args(parser):
                       help='Path to titles dataset used for ICT')
    group.add_argument('--query-in-block-prob', type=float, default=0.1,
                       help='Probability of keeping query in block for ICT dataset')
    group.add_argument('--ict-one-sent', action='store_true',
    group.add_argument('--use-one-sent-docs', action='store_true',
                       help='Whether to use one sentence documents in ICT')

    # training
@@ -421,7 +421,13 @@ def _add_realm_args(parser):
    # faiss index
    group.add_argument('--faiss-use-gpu', action='store_true',
                       help='Whether create the FaissMIPSIndex on GPU')
    group.add_argument('--block-data-path', type=str,
    group.add_argument('--block-data-path', type=str, default=None,
                       help='Where to save/load BlockData to/from')

    # indexer
    group.add_argument('--indexer-batch-size', type=int, default=128,
                       help='How large of batches to use when doing indexing jobs')
    group.add_argument('--indexer-log-interval', type=int, default=1000,
                       help='After how many batches should the indexer report progress')

    return parser
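As a quick orientation aid (not part of the commit), here is a minimal standalone argparse sketch of just the two new indexer flags, showing the defaults an indexing job sees when they are omitted; the parser below is an illustrative stand-in, not Megatron's actual parser.

import argparse

# Illustrative only: mirrors the two new indexer flags added above.
parser = argparse.ArgumentParser()
group = parser.add_argument_group('realm')
group.add_argument('--indexer-batch-size', type=int, default=128,
                   help='How large of batches to use when doing indexing jobs')
group.add_argument('--indexer-log-interval', type=int, default=1000,
                   help='After how many batches should the indexer report progress')

args = parser.parse_args([])               # no flags given on the command line
assert args.indexer_batch_size == 128      # dashes become underscores
assert args.indexer_log_interval == 1000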
@@ -21,9 +21,10 @@ import sys
import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import mpu
from megatron import mpu, get_args
from megatron import get_args
from megatron import print_rank_0
@@ -244,3 +245,43 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
        print(' successfully loaded {}'.format(checkpoint_name))

    return iteration


def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False):
    """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
    args = get_args()

    if isinstance(model, torchDDP):
        model = model.module

    load_path = args.load if from_realm_chkpt else args.ict_load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    # assert iteration > 0
    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ict_state_dict = state_dict['model']
    if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
        print(" loading ICT state dict from REALM", flush=True)
        ict_state_dict = ict_state_dict['retriever']['ict_model']

    if only_query_model:
        ict_state_dict.pop('context_model')
    if only_block_model:
        ict_state_dict.pop('question_model')

    model.load_state_dict(ict_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model
\ No newline at end of file
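Usage sketch (assuming initialize_megatron has already run with the ICT/REALM arguments set): the indexer further down builds a block-only model and hands it to this helper, which is exactly the call pattern this refactor enables.

from megatron.checkpointing import load_ict_checkpoint
from megatron.model.realm_model import general_ict_model_provider
from megatron.training import get_model

# Build only the block encoder, then populate it from an ICT (or REALM) checkpoint.
model = get_model(lambda: general_ict_model_provider(only_block_model=True))
model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=False)
model.eval()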
@@ -426,7 +426,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
            block_dataset=indexed_dataset,
            title_dataset=title_dataset,
            query_in_block_prob=args.query_in_block_prob,
            use_one_sent_docs=args.ict_one_sent,
            use_one_sent_docs=args.use_one_sent_docs,
            **kwargs
        )
    else:
@@ -5,9 +5,36 @@ import numpy as np
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron import get_args
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset_utils import get_block_samples_mapping
def get_ict_dataset(use_titles=True, query_in_block_prob=1):
    """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
    rather than for training, since it is only built with a single epoch sample mapping.
    """
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)

    kwargs = dict(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=args.seq_length,
        seed=1,
        query_in_block_prob=query_in_block_prob,
        use_titles=use_titles,
        use_one_sent_docs=args.use_one_sent_docs
    )
    dataset = ICTDataset(**kwargs)
    return dataset
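For orientation (an illustrative sketch, not part of the commit): once Megatron has been initialized with --data-path, --titles-data-path and --seq-length, an indexing job can instantiate the dataset directly. The per-sample keys listed below are taken from get_ict_batch in realm_dataset_utils.py; treat the exact shapes as assumptions.

ds = get_ict_dataset(use_titles=True, query_in_block_prob=1)
sample = ds[0]
# Each sample is a dict with the fields consumed downstream by get_ict_batch():
# 'query_tokens', 'query_pad_mask', 'block_tokens', 'block_pad_mask', 'block_data'
print(len(ds), sorted(sample.keys()))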
class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
@@ -35,7 +62,7 @@ class ICTDataset(Dataset):
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]
        return len(self.samples_mapping)

    def __getitem__(self, idx):
        """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
@@ -6,9 +6,59 @@ import torch
from megatron import mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
from megatron.data.samplers import DistributedBatchSampler
from megatron import get_args, get_tokenizer, print_rank_0, mpu
def get_one_epoch_dataloader(dataset, batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if batch_size is None:
        batch_size = args.batch_size
    global_batch_size = batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # importantly, drop_last must be False to get all the data.
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=False,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
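The drop_last=False choice is the important detail here: for indexing, every block must be embedded exactly once, so the final short global batch is kept rather than discarded. A small arithmetic sketch with made-up numbers:

# Hypothetical sizes, purely for illustration.
num_samples = 10_000       # blocks in the dataset
batch_size = 128           # per data-parallel rank (--indexer-batch-size)
world_size = 4             # data-parallel ranks
global_batch_size = batch_size * world_size

full_batches, remainder = divmod(num_samples, global_batch_size)
print(full_batches, remainder)   # 19 full global batches, plus 272 leftover blocks
# With drop_last=False the 272 leftover blocks are still emitted as a short batch;
# with drop_last=True they would silently never be indexed.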
def get_ict_batch(data_iterator):
    # Items and their type.
    keys = ['query_tokens', 'query_pad_mask',
            'block_tokens', 'block_pad_mask', 'block_data']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    query_tokens = data_b['query_tokens'].long()
    query_pad_mask = data_b['query_pad_mask'].long()
    block_tokens = data_b['block_tokens'].long()
    block_pad_mask = data_b['block_pad_mask'].long()
    block_indices = data_b['block_data'].long()

    return query_tokens, query_pad_mask,\
        block_tokens, block_pad_mask, block_indices
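A consumption sketch, mirroring IndexBuilder.build_and_save_index in megatron/indexer.py further down (it assumes Megatron has been initialized with the data and indexer arguments; the embedding step is elided since it depends on the loaded model):

data_iter = iter(get_one_epoch_dataloader(get_ict_dataset(), batch_size=128))
while True:
    try:
        # Only the block half of each batch is needed for indexing.
        _, _, block_tokens, block_pad_mask, block_data = get_ict_batch(data_iter)
    except StopIteration:
        break
    # ... embed block_tokens/block_pad_mask here and store the result keyed by block_data ...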
def join_str_list(str_list):
    """Join a list of strings, handling spaces appropriately"""
    result = ""
@@ -46,10 +96,12 @@ class BlockSamplesMapping(object):
        # make sure that the array is compatible with BlockSampleData
        assert mapping_array.shape[1] == 4
        self.mapping_array = mapping_array
        self.shape = self.mapping_array.shape

    def __len__(self):
        return self.mapping_array.shape[0]

    def __getitem__(self, idx):
        """Get the data associated with a particular sample."""
        """Get the data associated with an indexed sample."""
        sample_data = BlockSampleData(*self.mapping_array[idx])
        return sample_data
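To make the 4-column invariant and the new __len__ concrete, a small sketch with a dummy mapping array (it assumes, as the shown lines suggest, that the constructor takes just the array; the column meanings follow the comment in megatron/indexer.py below):

import numpy as np

# Two fake samples; columns are [start_idx, end_idx, doc_idx, block_idx].
mapping_array = np.array([[0, 5, 0, 0],
                          [5, 9, 0, 1]], dtype=np.int64)

mapping = BlockSamplesMapping(mapping_array)
assert len(mapping) == 2        # uses the new __len__
sample = mapping[1]             # a BlockSampleData built from row 1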
@@ -144,6 +196,6 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
    print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0(' total number of samples: {}'.format(
        samples_mapping.shape[0]))
        mapping_array.shape[0]))

    return samples_mapping
@@ -3,7 +3,6 @@ import os
import pickle
import shutil
import faiss
import numpy as np
import torch
@@ -123,6 +122,11 @@ class FaissMIPSIndex(object):
    def _set_block_index(self):
        """Create a Faiss Flat index with inner product as the metric to search against"""
        try:
            import faiss
        except ImportError:
            raise Exception("Error: Please install faiss to use FaissMIPSIndex")

        print("\n> Building index", flush=True)
        self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
@@ -188,19 +192,18 @@ class FaissMIPSIndex(object):
            if False: return [num_queries x k] array of distances, and another for indices
        """
        query_embeds = np.float32(detach(query_embeds))

        with torch.no_grad():
            if reconstruct:
                # get the vectors themselves
                top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
                return top_k_block_embeds
            else:
                # get distances and indices of closest vectors
                distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
                if self.use_gpu:
                    fresh_indices = np.zeros(block_indices.shape)
                    for i, j in itertools.product(block_indices.shape):
                        fresh_indices[i, j] = self.id_map[block_indices[i, j]]
                    block_indices = fresh_indices
                return distances, block_indices

        if reconstruct:
            # get the vectors themselves
            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
            return top_k_block_embeds
        else:
            # get distances and indices of closest vectors
            distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
            if self.use_gpu:
                fresh_indices = np.zeros(block_indices.shape)
                for i, j in itertools.product(block_indices.shape):
                    fresh_indices[i, j] = self.id_map[block_indices[i, j]]
                block_indices = fresh_indices
            return distances, block_indices
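For readers who have not used faiss, a minimal standalone sketch of the flat inner-product index these calls rely on (assumes faiss and numpy are installed; the sizes are made up):

import numpy as np
import faiss   # e.g. pip install faiss-cpu

embed_size, num_blocks, top_k = 128, 1000, 5
index = faiss.index_factory(embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)

block_embeds = np.random.rand(num_blocks, embed_size).astype('float32')
index.add(block_embeds)                                    # index the block embeddings

query_embeds = np.random.rand(3, embed_size).astype('float32')
distances, indices = index.search(query_embeds, top_k)     # each is [3 x top_k]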
import torch
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.ict_dataset import ICTDataset
from megatron.checkpointing import load_ict_checkpoint
from megatron.data.ict_dataset import get_ict_dataset
from megatron.data.realm_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, BlockData
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.data.realm_dataset_utils import get_ict_batch
from megatron.model.realm_model import general_ict_model_provider
from megatron.training import get_model
from pretrain_ict import get_batch, general_ict_model_provider
def pprint(*args):
    print(*args, flush=True)


class IndexBuilder(object):
@@ -30,22 +24,27 @@ class IndexBuilder(object):
        assert not (args.load and args.ict_load)

        self.using_realm_chkpt = args.ict_load is None
        self.log_interval = args.indexer_log_interval
        self.batch_size = args.indexer_batch_size

        self.load_attributes()
        self.is_main_builder = args.rank == 0
        self.is_main_builder = mpu.get_data_parallel_rank() == 0
        self.num_total_builders = mpu.get_data_parallel_world_size()
        self.iteration = self.total_processed = 0

    def load_attributes(self):
        """Load the necessary attributes: model, dataloader and empty BlockData"""
        self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
        model = get_model(lambda: general_ict_model_provider(only_block_model=True))
        self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
        self.model.eval()
        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset(), self.batch_size))
        self.block_data = BlockData(load_from_path=False)

    def track_and_report_progress(self, batch_size):
        """Utility function for tracking progress"""
        self.iteration += 1
        self.total_processed += batch_size
        if self.iteration % 10 == 0:
        self.total_processed += batch_size * self.num_total_builders
        if self.is_main_builder and self.iteration % self.log_interval == 0:
            print('Batch {:10d} | Total {:10d}'.format(self.iteration, self.total_processed), flush=True)
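The new accounting is easy to sanity-check in isolation: each call now adds the global batch (local batch size times the number of data-parallel builders), and only the main builder prints, every --indexer-log-interval batches. A quick arithmetic sketch with hypothetical numbers:

num_total_builders = 4        # data-parallel ranks
batch_size = 128              # blocks per builder per batch

total_processed = 0
for iteration in range(1, 2001):
    total_processed += batch_size * num_total_builders
    if iteration % 1000 == 0:     # default --indexer-log-interval
        print('Batch {:10d} | Total {:10d}'.format(iteration, total_processed))
# Prints at iterations 1000 and 2000; Total reads 512000, then 1024000.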
def build_and_save_index(self):
@@ -58,17 +57,20 @@ class IndexBuilder(object):
        while True:
            try:
                # batch also has query_tokens and query_pad_data
                _, _, block_tokens, block_pad_mask, block_sample_data = get_batch(self.dataloader)
            except:
                _, _, block_tokens, block_pad_mask, block_sample_data = get_ict_batch(self.dataloader)
            except StopIteration:
                break

            # detach, setup and add to BlockData
            unwrapped_model = self.model
            while not hasattr(unwrapped_model, 'embed_block'):
                unwrapped_model = unwrapped_model.module
            block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))

            # detach, separate fields and add to BlockData
            block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))
            detached_data = detach(block_sample_data)

            # block_sample_data is a 2D array [batch x 4]
            # with columns [start_idx, end_idx, doc_idx, block_idx] same as class BlockSampleData
            block_indices = detached_data[:, 3]
            block_metas = detached_data[:, :3]
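To make the column slicing concrete, a tiny numpy sketch with a fake [batch x 4] sample-data array (column order as documented in the comment above; the values are invented):

import numpy as np

# Two fake block records: [start_idx, end_idx, doc_idx, block_idx]
detached_data = np.array([[0,  5, 7, 100],
                          [5, 11, 7, 101]])

block_indices = detached_data[:, 3]   # array([100, 101]) -- the block id column
block_metas = detached_data[:, :3]    # shape (2, 3): start/end/doc columns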
@@ -86,98 +88,3 @@ class IndexBuilder(object):
        self.block_data.clear()


def load_ict_checkpoint(only_query_model=False, only_block_model=False, from_realm_chkpt=False):
    """load ICT checkpoints for indexing/retrieving. Arguments specify which parts of the state dict to actually use."""
    args = get_args()
    model = get_model(lambda: general_ict_model_provider(only_query_model, only_block_model))

    if isinstance(model, torchDDP):
        model = model.module

    load_path = args.load if from_realm_chkpt else args.ict_load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    # assert iteration > 0
    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ict_state_dict = state_dict['model']
    if from_realm_chkpt:
        print(">>>> Attempting to get ict state dict from realm", flush=True)
        ict_state_dict = ict_state_dict['retriever']['ict_model']

    if only_query_model:
        ict_state_dict.pop('context_model')
    if only_block_model:
        ict_state_dict.pop('question_model')

    model.load_state_dict(ict_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model
def get_ict_dataset(use_titles=True, query_in_block_prob=1):
    """Get a dataset which uses block samples mappings to get ICT/block indexing data"""
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)

    kwargs = dict(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=args.seq_length,
        seed=1,
        query_in_block_prob=query_in_block_prob,
        use_titles=use_titles,
        use_one_sent_docs=True
    )
    dataset = ICTDataset(**kwargs)
    return dataset
def get_one_epoch_dataloader(dataset, batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if batch_size is None:
        batch_size = args.batch_size
    global_batch_size = batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # importantly, drop_last must be False to get all the data.
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=False,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


if __name__ == "__main__":
    # This usage is for basic (as opposed to realm async) indexing jobs.
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()
import os
import torch
from megatron import get_args
from megatron import get_args, print_rank_0
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.model import BertModel
from megatron.module import MegatronModule
@@ -13,6 +13,28 @@ from megatron.model.utils import scaled_init_method_normal
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids


def general_ict_model_provider(only_query_model=False, only_block_model=False):
    """Build the model."""
    args = get_args()
    assert args.ict_head_size is not None, \
        "Need to specify --ict-head-size to provide an ICTBertModel"
    assert args.model_parallel_size == 1, \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building ICTBertModel...')

    # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
    model = ICTBertModel(
        ict_head_size=args.ict_head_size,
        num_tokentypes=2,
        parallel_output=True,
        only_query_model=only_query_model,
        only_block_model=only_block_model)

    return model
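A brief usage sketch (illustrative, not from the commit; it assumes --ict-head-size is set and model-parallel size is 1, per the asserts above). The same provider now serves both halves of the retriever:

# Block encoder only, as used by the indexing job in megatron/indexer.py:
block_model = general_ict_model_provider(only_block_model=True)

# Query encoder only, e.g. for retrieval against a prebuilt index:
query_model = general_ict_model_provider(only_query_model=True)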
class ICTBertModel(MegatronModule):
    """Bert-based module for Inverse Cloze task."""
    def __init__(self,
@@ -27,33 +27,11 @@ from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import ICTBertModel
from megatron.training import pretrain
from megatron.utils import reduce_losses
from megatron.model.realm_model import general_ict_model_provider
from megatron.data.realm_dataset_utils import get_ict_batch
num_batches = 0
def general_ict_model_provider(only_query_model=False, only_block_model=False):
    """Build the model."""
    args = get_args()
    assert args.ict_head_size is not None, \
        "Need to specify --ict-head-size to provide an ICTBertModel"
    assert args.model_parallel_size == 1, \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building ICTBertModel...')

    # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
    model = ICTBertModel(
        ict_head_size=args.ict_head_size,
        num_tokentypes=2,
        parallel_output=True,
        only_query_model=only_query_model,
        only_block_model=only_block_model)

    return model


def model_provider():
def pretrain_ict_model_provider():
    return general_ict_model_provider(False, False)
@@ -95,30 +73,6 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
        return output


def get_batch(data_iterator):
    # Items and their type.
    keys = ['query_tokens', 'query_pad_mask',
            'block_tokens', 'block_pad_mask', 'block_data']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    query_tokens = data_b['query_tokens'].long()
    query_pad_mask = data_b['query_pad_mask'].long()
    block_tokens = data_b['block_tokens'].long()
    block_pad_mask = data_b['block_pad_mask'].long()
    block_indices = data_b['block_data'].long()

    return query_tokens, query_pad_mask,\
        block_tokens, block_pad_mask, block_indices
def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
@@ -127,7 +81,7 @@ def forward_step(data_iterator, model):
    # Get the batch.
    timers('batch generator').start()
    query_tokens, query_pad_mask, \
        block_tokens, block_pad_mask, block_indices = get_batch(data_iterator)
        block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
    timers('batch generator').stop()
@@ -181,5 +135,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):

if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
    pretrain(train_valid_test_datasets_provider, pretrain_ict_model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
import sys
sys.path.append('../')
from megatron.indexer import IndexBuilder
from megatron.initialize import initialize_megatron
def main():
    """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset

    - Include all args needed for initial model specification

    Other key args:
        --block-data-path: path to write to
        --ict-load or --realm-load: path to checkpoint with which to embed
        --data-path and --titles-data-path: paths for dataset
    Check README.md for example script
    """
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()


if __name__ == "__main__":
    main()