Commit 11f76cd3 authored by Neel Kant

Address comments from last week

parent 7a348580
@@ -411,7 +411,7 @@ def _add_realm_args(parser):
                        help='Path to titles dataset used for ICT')
     group.add_argument('--query-in-block-prob', type=float, default=0.1,
                        help='Probability of keeping query in block for ICT dataset')
-    group.add_argument('--ict-one-sent', action='store_true',
+    group.add_argument('--use-one-sent-docs', action='store_true',
                        help='Whether to use one sentence documents in ICT')

     # training
@@ -421,7 +421,13 @@ def _add_realm_args(parser):
     # faiss index
     group.add_argument('--faiss-use-gpu', action='store_true',
                        help='Whether create the FaissMIPSIndex on GPU')
-    group.add_argument('--block-data-path', type=str,
+    group.add_argument('--block-data-path', type=str, default=None,
                        help='Where to save/load BlockData to/from')
+
+    # indexer
+    group.add_argument('--indexer-batch-size', type=int, default=128,
+                       help='How large of batches to use when doing indexing jobs')
+    group.add_argument('--indexer-log-interval', type=int, default=1000,
+                       help='After how many batches should the indexer report progress')

     return parser
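Note: a minimal standalone sketch of how the two new indexer flags behave; it uses a throwaway argparse parser for illustration only, not Megatron's real argument setup.

```python
# Illustration only: a throwaway parser that mirrors the two flags added above.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('indexer')
group.add_argument('--indexer-batch-size', type=int, default=128,
                   help='How large of batches to use when doing indexing jobs')
group.add_argument('--indexer-log-interval', type=int, default=1000,
                   help='After how many batches should the indexer report progress')

args = parser.parse_args([])                                 # no flags -> defaults
print(args.indexer_batch_size, args.indexer_log_interval)    # 128 1000

args = parser.parse_args(['--indexer-batch-size', '256'])
print(args.indexer_batch_size)                               # 256
```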
@@ -21,9 +21,10 @@ import sys
 import numpy as np
 import torch
+from torch.nn.parallel import DistributedDataParallel as torchDDP
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

-from megatron import mpu
+from megatron import mpu, get_args
 from megatron import get_args
 from megatron import print_rank_0
@@ -244,3 +245,43 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
         print(' successfully loaded {}'.format(checkpoint_name))

     return iteration
+
+
+def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False):
+    """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
+    args = get_args()
+
+    if isinstance(model, torchDDP):
+        model = model.module
+
+    load_path = args.load if from_realm_chkpt else args.ict_load
+
+    tracker_filename = get_checkpoint_tracker_filename(load_path)
+    with open(tracker_filename, 'r') as f:
+        iteration = int(f.read().strip())
+
+    # assert iteration > 0
+    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
+    if mpu.get_data_parallel_rank() == 0:
+        print('global rank {} is loading checkpoint {}'.format(
+            torch.distributed.get_rank(), checkpoint_name))
+
+    state_dict = torch.load(checkpoint_name, map_location='cpu')
+    ict_state_dict = state_dict['model']
+    if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
+        print(" loading ICT state dict from REALM", flush=True)
+        ict_state_dict = ict_state_dict['retriever']['ict_model']
+
+    if only_query_model:
+        ict_state_dict.pop('context_model')
+    if only_block_model:
+        ict_state_dict.pop('question_model')
+
+    model.load_state_dict(ict_state_dict)
+    torch.distributed.barrier()
+
+    if mpu.get_data_parallel_rank() == 0:
+        print(' successfully loaded {}'.format(checkpoint_name))
+
+    return model
\ No newline at end of file
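Note: a toy sketch of the checkpoint state-dict shapes that load_ict_checkpoint navigates. The nesting ('model' -> 'retriever' -> 'ict_model') and the popped keys ('question_model', 'context_model') come from the function above; the weight values are placeholders.

```python
# Toy dictionaries standing in for torch.load results; not real checkpoints.
realm_style = {'model': {'retriever': {'ict_model': {
    'question_model': '<query-tower weights>',
    'context_model': '<block-tower weights>'}}}}
ict_style = {'model': {
    'question_model': '<query-tower weights>',
    'context_model': '<block-tower weights>'}}

def select_ict_weights(state_dict, from_realm_chkpt, only_query_model=False, only_block_model=False):
    ict_state_dict = state_dict['model']
    if from_realm_chkpt:
        ict_state_dict = ict_state_dict['retriever']['ict_model']
    if only_query_model:
        ict_state_dict.pop('context_model')    # keep only the query tower
    if only_block_model:
        ict_state_dict.pop('question_model')   # keep only the block tower
    return ict_state_dict

print(select_ict_weights(realm_style, from_realm_chkpt=True, only_block_model=True))
# {'question_model': '<query-tower weights>'} removed -> {'context_model': '<block-tower weights>'}
print(select_ict_weights(ict_style, from_realm_chkpt=False, only_query_model=True))
# {'question_model': '<query-tower weights>'}
```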
@@ -426,7 +426,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             block_dataset=indexed_dataset,
             title_dataset=title_dataset,
             query_in_block_prob=args.query_in_block_prob,
-            use_one_sent_docs=args.ict_one_sent,
+            use_one_sent_docs=args.use_one_sent_docs,
             **kwargs
         )
     else:
@@ -5,9 +5,36 @@ import numpy as np
 from torch.utils.data import Dataset

 from megatron import get_tokenizer
+from megatron import get_args
+from megatron.data.dataset_utils import get_indexed_dataset_
 from megatron.data.realm_dataset_utils import get_block_samples_mapping
+
+
+def get_ict_dataset(use_titles=True, query_in_block_prob=1):
+    """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
+    rather than for training, since it is only built with a single epoch sample mapping.
+    """
+    args = get_args()
+    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
+    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
+
+    kwargs = dict(
+        name='full',
+        block_dataset=block_dataset,
+        title_dataset=titles_dataset,
+        data_prefix=args.data_path,
+        num_epochs=1,
+        max_num_samples=None,
+        max_seq_length=args.seq_length,
+        seed=1,
+        query_in_block_prob=query_in_block_prob,
+        use_titles=use_titles,
+        use_one_sent_docs=args.use_one_sent_docs
+    )
+    dataset = ICTDataset(**kwargs)
+    return dataset
+

 class ICTDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
     def __init__(self, name, block_dataset, title_dataset, data_prefix,
@@ -35,7 +62,7 @@ class ICTDataset(Dataset):
         self.pad_id = self.tokenizer.pad

     def __len__(self):
-        return self.samples_mapping.shape[0]
+        return len(self.samples_mapping)

     def __getitem__(self, idx):
         """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
@@ -6,9 +6,59 @@ import torch
 from megatron import mpu, print_rank_0
 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
+from megatron.data.samplers import DistributedBatchSampler
 from megatron import get_args, get_tokenizer, print_rank_0, mpu
+
+
+def get_one_epoch_dataloader(dataset, batch_size=None):
+    """Specifically one epoch to be used in an indexing job."""
+    args = get_args()
+
+    world_size = mpu.get_data_parallel_world_size()
+    rank = mpu.get_data_parallel_rank()
+    if batch_size is None:
+        batch_size = args.batch_size
+    global_batch_size = batch_size * world_size
+    num_workers = args.num_workers
+
+    sampler = torch.utils.data.SequentialSampler(dataset)
+    # importantly, drop_last must be False to get all the data.
+    batch_sampler = DistributedBatchSampler(sampler,
+                                            batch_size=global_batch_size,
+                                            drop_last=False,
+                                            rank=rank,
+                                            world_size=world_size)
+
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=num_workers,
+                                       pin_memory=True)
+
+
+def get_ict_batch(data_iterator):
+    # Items and their type.
+    keys = ['query_tokens', 'query_pad_mask',
+            'block_tokens', 'block_pad_mask', 'block_data']
+    datatype = torch.int64
+
+    # Broadcast data.
+    if data_iterator is None:
+        data = None
+    else:
+        data = next(data_iterator)
+    data_b = mpu.broadcast_data(keys, data, datatype)
+
+    # Unpack.
+    query_tokens = data_b['query_tokens'].long()
+    query_pad_mask = data_b['query_pad_mask'].long()
+    block_tokens = data_b['block_tokens'].long()
+    block_pad_mask = data_b['block_pad_mask'].long()
+    block_indices = data_b['block_data'].long()
+
+    return query_tokens, query_pad_mask,\
+        block_tokens, block_pad_mask, block_indices
+

 def join_str_list(str_list):
     """Join a list of strings, handling spaces appropriately"""
     result = ""
@@ -46,10 +96,12 @@ class BlockSamplesMapping(object):
         # make sure that the array is compatible with BlockSampleData
         assert mapping_array.shape[1] == 4
         self.mapping_array = mapping_array
-        self.shape = self.mapping_array.shape
+
+    def __len__(self):
+        return self.mapping_array.shape[0]

     def __getitem__(self, idx):
-        """Get the data associated with a particular sample."""
+        """Get the data associated with an indexed sample."""
         sample_data = BlockSampleData(*self.mapping_array[idx])
         return sample_data
@@ -144,6 +196,6 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
     print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
         time.time() - start_time))
     print_rank_0(' total number of samples: {}'.format(
-        samples_mapping.shape[0]))
+        mapping_array.shape[0]))

     return samples_mapping
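Note: a self-contained illustration of the mapping that the new __len__ and __getitem__ walk over. The [N x 4] column layout follows the assert above and the column comment in the indexer change below; BlockSampleData itself is replaced here by a stand-in namedtuple for the sake of a runnable toy.

```python
from collections import namedtuple
import numpy as np

# Stand-in for BlockSampleData, just to show how a mapping row unpacks.
SampleStandIn = namedtuple('SampleStandIn', ['start_idx', 'end_idx', 'doc_idx', 'block_idx'])

# Toy mapping: 3 blocks, columns [start_idx, end_idx, doc_idx, block_idx].
mapping_array = np.array([[0,  4, 0, 0],
                          [4,  9, 0, 1],
                          [9, 12, 1, 2]])
assert mapping_array.shape[1] == 4

print(mapping_array.shape[0])                      # 3, what the new __len__ returns
print(SampleStandIn(*mapping_array[1].tolist()))   # SampleStandIn(start_idx=4, end_idx=9, doc_idx=0, block_idx=1)
```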
@@ -3,7 +3,6 @@ import os
 import pickle
 import shutil

-import faiss
 import numpy as np
 import torch
@@ -123,6 +122,11 @@ class FaissMIPSIndex(object):
     def _set_block_index(self):
         """Create a Faiss Flat index with inner product as the metric to search against"""
+        try:
+            import faiss
+        except ImportError:
+            raise Exception("Error: Please install faiss to use FaissMIPSIndex")
+
         print("\n> Building index", flush=True)
         self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
@@ -188,19 +192,18 @@ class FaissMIPSIndex(object):
             if False: return [num_queries x k] array of distances, and another for indices
         """
         query_embeds = np.float32(detach(query_embeds))
-        with torch.no_grad():
-            if reconstruct:
-                # get the vectors themselves
-                top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
-                return top_k_block_embeds
-            else:
-                # get distances and indices of closest vectors
-                distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
-                if self.use_gpu:
-                    fresh_indices = np.zeros(block_indices.shape)
-                    for i, j in itertools.product(block_indices.shape):
-                        fresh_indices[i, j] = self.id_map[block_indices[i, j]]
-                    block_indices = fresh_indices
-                return distances, block_indices
+        if reconstruct:
+            # get the vectors themselves
+            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
+            return top_k_block_embeds
+        else:
+            # get distances and indices of closest vectors
+            distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
+            if self.use_gpu:
+                fresh_indices = np.zeros(block_indices.shape)
+                for i, j in itertools.product(block_indices.shape):
+                    fresh_indices[i, j] = self.id_map[block_indices[i, j]]
+                block_indices = fresh_indices
+            return distances, block_indices
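Note: a numpy-only sketch of what a flat inner-product search returns and why the GPU branch above remaps result ids back through id_map. This stands in for faiss purely for intuition; it is not the FaissMIPSIndex implementation.

```python
import numpy as np

block_embeds = np.random.rand(1000, 128).astype(np.float32)   # embedded blocks in the index
query_embeds = np.random.rand(4, 128).astype(np.float32)
top_k = 5

# Max inner product search: score every block, keep the top_k per query.
scores = query_embeds @ block_embeds.T                    # [num_queries x num_blocks]
block_indices = np.argsort(-scores, axis=1)[:, :top_k]    # [num_queries x top_k]
distances = np.take_along_axis(scores, block_indices, axis=1)

# An index built with sequential ids has to map results back to the original
# block ids, analogous to self.id_map in the GPU branch above.
id_map = {i: 1000 + i for i in range(block_embeds.shape[0])}   # toy id mapping
fresh_indices = np.zeros(block_indices.shape, dtype=np.int64)
for i, j in np.ndindex(block_indices.shape):
    fresh_indices[i, j] = id_map[block_indices[i, j]]

print(distances.shape, fresh_indices.shape)   # (4, 5) (4, 5)
```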
 import torch
 import torch.distributed as dist
-from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

 from megatron import get_args
 from megatron import mpu
-from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
-from megatron.data.dataset_utils import get_indexed_dataset_
-from megatron.data.ict_dataset import ICTDataset
+from megatron.checkpointing import load_ict_checkpoint
+from megatron.data.ict_dataset import get_ict_dataset
+from megatron.data.realm_dataset_utils import get_one_epoch_dataloader
 from megatron.data.realm_index import detach, BlockData
-from megatron.data.samplers import DistributedBatchSampler
-from megatron.initialize import initialize_megatron
+from megatron.data.realm_dataset_utils import get_ict_batch
+from megatron.model.realm_model import general_ict_model_provider
 from megatron.training import get_model
-from pretrain_ict import get_batch, general_ict_model_provider
-
-
-def pprint(*args):
-    print(*args, flush=True)


 class IndexBuilder(object):
@@ -30,22 +24,27 @@ class IndexBuilder(object):
         assert not (args.load and args.ict_load)
         self.using_realm_chkpt = args.ict_load is None

+        self.log_interval = args.indexer_log_interval
+        self.batch_size = args.indexer_batch_size
+
         self.load_attributes()
-        self.is_main_builder = args.rank == 0
+        self.is_main_builder = mpu.get_data_parallel_rank() == 0
+        self.num_total_builders = mpu.get_data_parallel_world_size()
         self.iteration = self.total_processed = 0

     def load_attributes(self):
         """Load the necessary attributes: model, dataloader and empty BlockData"""
-        self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
+        model = get_model(lambda: general_ict_model_provider(only_block_model=True))
+        self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
         self.model.eval()
-        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
+        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset(), self.batch_size))
         self.block_data = BlockData(load_from_path=False)

     def track_and_report_progress(self, batch_size):
         """Utility function for tracking progress"""
         self.iteration += 1
-        self.total_processed += batch_size
-        if self.iteration % 10 == 0:
+        self.total_processed += batch_size * self.num_total_builders
+        if self.is_main_builder and self.iteration % self.log_interval == 0:
             print('Batch {:10d} | Total {:10d}'.format(self.iteration, self.total_processed), flush=True)

     def build_and_save_index(self):
@@ -58,17 +57,20 @@ class IndexBuilder(object):
         while True:
             try:
                 # batch also has query_tokens and query_pad_data
-                _, _, block_tokens, block_pad_mask, block_sample_data = get_batch(self.dataloader)
-            except:
+                _, _, block_tokens, block_pad_mask, block_sample_data = get_ict_batch(self.dataloader)
+            except StopIteration:
                 break

-            # detach, setup and add to BlockData
             unwrapped_model = self.model
             while not hasattr(unwrapped_model, 'embed_block'):
                 unwrapped_model = unwrapped_model.module
-            block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))

+            # detach, separate fields and add to BlockData
+            block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))
             detached_data = detach(block_sample_data)
+
+            # block_sample_data is a 2D array [batch x 4]
+            # with columns [start_idx, end_idx, doc_idx, block_idx] same as class BlockSampleData
             block_indices = detached_data[:, 3]
             block_metas = detached_data[:, :3]
@@ -86,98 +88,3 @@ class IndexBuilder(object):
         self.block_data.clear()
-
-
-def load_ict_checkpoint(only_query_model=False, only_block_model=False, from_realm_chkpt=False):
-    """load ICT checkpoints for indexing/retrieving. Arguments specify which parts of the state dict to actually use."""
-    args = get_args()
-    model = get_model(lambda: general_ict_model_provider(only_query_model, only_block_model))
-
-    if isinstance(model, torchDDP):
-        model = model.module
-
-    load_path = args.load if from_realm_chkpt else args.ict_load
-
-    tracker_filename = get_checkpoint_tracker_filename(load_path)
-    with open(tracker_filename, 'r') as f:
-        iteration = int(f.read().strip())
-
-    # assert iteration > 0
-    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
-    if mpu.get_data_parallel_rank() == 0:
-        print('global rank {} is loading checkpoint {}'.format(
-            torch.distributed.get_rank(), checkpoint_name))
-
-    state_dict = torch.load(checkpoint_name, map_location='cpu')
-    ict_state_dict = state_dict['model']
-    if from_realm_chkpt:
-        print(">>>> Attempting to get ict state dict from realm", flush=True)
-        ict_state_dict = ict_state_dict['retriever']['ict_model']
-
-    if only_query_model:
-        ict_state_dict.pop('context_model')
-    if only_block_model:
-        ict_state_dict.pop('question_model')
-
-    model.load_state_dict(ict_state_dict)
-    torch.distributed.barrier()
-
-    if mpu.get_data_parallel_rank() == 0:
-        print(' successfully loaded {}'.format(checkpoint_name))
-
-    return model
-
-
-def get_ict_dataset(use_titles=True, query_in_block_prob=1):
-    """Get a dataset which uses block samples mappings to get ICT/block indexing data"""
-    args = get_args()
-    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
-    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
-
-    kwargs = dict(
-        name='full',
-        block_dataset=block_dataset,
-        title_dataset=titles_dataset,
-        data_prefix=args.data_path,
-        num_epochs=1,
-        max_num_samples=None,
-        max_seq_length=args.seq_length,
-        seed=1,
-        query_in_block_prob=query_in_block_prob,
-        use_titles=use_titles,
-        use_one_sent_docs=True
-    )
-    dataset = ICTDataset(**kwargs)
-    return dataset
-
-
-def get_one_epoch_dataloader(dataset, batch_size=None):
-    """Specifically one epoch to be used in an indexing job."""
-    args = get_args()
-
-    world_size = mpu.get_data_parallel_world_size()
-    rank = mpu.get_data_parallel_rank()
-    if batch_size is None:
-        batch_size = args.batch_size
-    global_batch_size = batch_size * world_size
-    num_workers = args.num_workers
-
-    sampler = torch.utils.data.SequentialSampler(dataset)
-    # importantly, drop_last must be False to get all the data.
-    batch_sampler = DistributedBatchSampler(sampler,
-                                            batch_size=global_batch_size,
-                                            drop_last=False,
-                                            rank=rank,
-                                            world_size=world_size)
-
-    return torch.utils.data.DataLoader(dataset,
-                                       batch_sampler=batch_sampler,
-                                       num_workers=num_workers,
-                                       pin_memory=True)
-
-
-if __name__ == "__main__":
-    # This usage is for basic (as opposed to realm async) indexing jobs.
-    initialize_megatron(extra_args_provider=None,
-                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
-    index_builder = IndexBuilder()
-    index_builder.build_and_save_index()
 import os
 import torch

-from megatron import get_args
+from megatron import get_args, print_rank_0
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.model import BertModel
 from megatron.module import MegatronModule
@@ -13,6 +13,28 @@ from megatron.model.utils import scaled_init_method_normal
 from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
+
+
+def general_ict_model_provider(only_query_model=False, only_block_model=False):
+    """Build the model."""
+    args = get_args()
+    assert args.ict_head_size is not None, \
+        "Need to specify --ict-head-size to provide an ICTBertModel"
+    assert args.model_parallel_size == 1, \
+        "Model parallel size > 1 not supported for ICT"
+
+    print_rank_0('building ICTBertModel...')
+
+    # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
+    model = ICTBertModel(
+        ict_head_size=args.ict_head_size,
+        num_tokentypes=2,
+        parallel_output=True,
+        only_query_model=only_query_model,
+        only_block_model=only_block_model)
+
+    return model
+

 class ICTBertModel(MegatronModule):
     """Bert-based module for Inverse Cloze task."""
     def __init__(self,
@@ -27,33 +27,11 @@ from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.model import ICTBertModel
 from megatron.training import pretrain
 from megatron.utils import reduce_losses
+from megatron.model.realm_model import general_ict_model_provider
+from megatron.data.realm_dataset_utils import get_ict_batch

 num_batches = 0


-def general_ict_model_provider(only_query_model=False, only_block_model=False):
-    """Build the model."""
-    args = get_args()
-    assert args.ict_head_size is not None, \
-        "Need to specify --ict-head-size to provide an ICTBertModel"
-    assert args.model_parallel_size == 1, \
-        "Model parallel size > 1 not supported for ICT"
-
-    print_rank_0('building ICTBertModel...')
-
-    # simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
-    model = ICTBertModel(
-        ict_head_size=args.ict_head_size,
-        num_tokentypes=2,
-        parallel_output=True,
-        only_query_model=only_query_model,
-        only_block_model=only_block_model)
-
-    return model
-
-
-def model_provider():
+def pretrain_ict_model_provider():
     return general_ict_model_provider(False, False)
@@ -95,30 +73,6 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
         return output

-
-def get_batch(data_iterator):
-    # Items and their type.
-    keys = ['query_tokens', 'query_pad_mask',
-            'block_tokens', 'block_pad_mask', 'block_data']
-    datatype = torch.int64
-
-    # Broadcast data.
-    if data_iterator is None:
-        data = None
-    else:
-        data = next(data_iterator)
-    data_b = mpu.broadcast_data(keys, data, datatype)
-
-    # Unpack.
-    query_tokens = data_b['query_tokens'].long()
-    query_pad_mask = data_b['query_pad_mask'].long()
-    block_tokens = data_b['block_tokens'].long()
-    block_pad_mask = data_b['block_pad_mask'].long()
-    block_indices = data_b['block_data'].long()
-
-    return query_tokens, query_pad_mask,\
-        block_tokens, block_pad_mask, block_indices
-

 def forward_step(data_iterator, model):
     """Forward step."""
     args = get_args()
@@ -127,7 +81,7 @@ def forward_step(data_iterator, model):
     # Get the batch.
     timers('batch generator').start()
     query_tokens, query_pad_mask, \
-    block_tokens, block_pad_mask, block_indices = get_batch(data_iterator)
+    block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
     timers('batch generator').stop()
@@ -181,5 +135,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):

 if __name__ == "__main__":

-    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
+    pretrain(train_valid_test_datasets_provider, pretrain_ict_model_provider, forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
+import sys
+sys.path.append('../')
+
+from megatron.indexer import IndexBuilder
+from megatron.initialize import initialize_megatron
+
+
+def main():
+    """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
+    - Include all args needed for initial model specification
+
+    Other key args:
+        --block-data-path: path to write to
+        --ict-load or --realm-load: path to checkpoint with which to embed
+        --data-path and --titles-data-path: paths for dataset
+
+    Check README.md for example script
+    """
+    initialize_megatron(extra_args_provider=None,
+                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
+    index_builder = IndexBuilder()
+    index_builder.build_and_save_index()
+
+
+if __name__ == "__main__":
+    main()
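Note: a hypothetical invocation of the new indexing script, assembled as a Python list purely for illustration. The script filename, all paths, and the model-size flags are placeholders; the REALM/ICT-specific flags come from the docstring above and the arguments added in this commit.

```python
# Every path and the script name below are placeholders, not values from this commit.
cmd = [
    'python', 'create_block_data_index.py',       # hypothetical name for the script above
    '--num-layers', '12', '--hidden-size', '768', '--num-attention-heads', '12',
    '--seq-length', '256', '--max-position-embeddings', '256',
    '--ict-head-size', '128',                     # required by general_ict_model_provider
    '--ict-load', '/path/to/ict_checkpoint',      # or --load for a REALM checkpoint
    '--data-path', '/path/to/blocks_text_sentence',
    '--titles-data-path', '/path/to/titles_text_sentence',
    '--block-data-path', '/path/to/block_data.pkl',
    '--indexer-batch-size', '128', '--indexer-log-interval', '1000',
    '--vocab-file', '/path/to/bert-vocab.txt',
]
print(' '.join(cmd))
```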