Commit 2eaf6c79 authored by Mostofa Patwary

cleaning the code

parent 7a0710ec
@@ -26,8 +26,8 @@ class IndexBuilder(object):
         self.evidence_embedder_obj = None
         self.biencoder_shared_query_context_model = \
             args.biencoder_shared_query_context_model
-        self.pre_process = True
-        self.post_process = True
+        #self.pre_process = True
+        #self.post_process = True
 
         # need to know whether we're using a REALM checkpoint (args.load)
         # or ICT checkpoint
@@ -46,7 +46,7 @@ class IndexBuilder(object):
         """
         Load the necessary attributes: model, dataloader and empty BlockData
         """
-        args = get_args()
+        #args = get_args()
         only_context_model = True
         if self.biencoder_shared_query_context_model:
             only_context_model = False
@@ -56,7 +56,7 @@ class IndexBuilder(object):
         #model = get_model(biencoder_model_provider)
-        model = get_model(get_model_provider(only_context_model=only_context_model,
+        model = get_model(get_model_provider(only_context_model=only_context_model,
             biencoder_shared_query_context_model=self.biencoder_shared_query_context_model))
         #model = get_model(lambda: biencoder_model_provider(only_context_model \
@@ -103,12 +103,12 @@ class IndexBuilder(object):
         while not hasattr(unwrapped_model, 'embed_text'):
             unwrapped_model = unwrapped_model.module
 
-        counter = 0
-        start_time = time.time()
-        cur_time = start_time
+        #counter = 0
+        #start_time = time.time()
+        #cur_time = start_time
         while True:
             #start_time = time.time()
-            t1 = time.time()
+            #t1 = time.time()
             try:
                 # batch also has query_tokens and query_pad_data
                 row_id, context_tokens, context_mask, context_types, \
@@ -118,7 +118,7 @@ class IndexBuilder(object):
                 break
             #print_rank_0("get batch time {}".format(cur_time - time.time()))
-            t2 = time.time()
+            #t2 = time.time()
             # TODO: can we add with torch.no_grad() to reduce memory usage
             # detach, separate fields and add to BlockData
             assert context_mask.dtype == torch.bool
@@ -129,17 +129,17 @@ class IndexBuilder(object):
             context_logits = detach(context_logits)
             row_id = detach(row_id)
             #print_rank_0("embed text {}".format(cur_time - time.time()))
-            t3 = time.time()
+            #t3 = time.time()
 
             self.evidence_embedder_obj.add_block_data(row_id, context_logits)
             self.track_and_report_progress(batch_size=len(row_id))
             #print_rank_0("add block time {}".format(cur_time - time.time()))
-            t4 = time.time()
-            counter += 1
-            if counter % 1000 == 0:
-                print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
-                print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3))
-                cur_time = time.time()
+            #t4 = time.time()
+            #counter += 1
+            #if counter % 1000 == 0:
+            #    print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
+            #    print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3))
+            #    cur_time = time.time()
 
         # This process signals to finalize its shard and then synchronize with
         # the other processes
         self.evidence_embedder_obj.save_shard()
......
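An aside on the TODO in the loop above ("can we add with torch.no_grad() to reduce memory usage"): index building only runs the context encoder forward, so the whole batch loop can sit under torch.no_grad(). A minimal sketch of the idea, not the committed code; embed_all_blocks is a hypothetical helper, and it assumes the unwrapped model exposes the context_model attribute and embed_text method referenced above:

import torch

def embed_all_blocks(unwrapped_model, batches, evidence_embedder_obj):
    # Index building is inference-only: running under torch.no_grad() skips
    # building the autograd graph and frees the activation memory the TODO
    # is asking about.
    with torch.no_grad():
        for row_id, context_tokens, context_mask, context_types in batches:
            assert context_mask.dtype == torch.bool
            context_logits = unwrapped_model.embed_text(
                unwrapped_model.context_model, context_tokens,
                context_mask, context_types)
            # Mirrors the detach helper above: off the graph, onto the CPU.
            evidence_embedder_obj.add_block_data(
                row_id.detach().cpu().numpy(),
                context_logits.detach().cpu().numpy())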
@@ -15,17 +15,17 @@ from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule
 
-def get_model_provider(only_query_model=False, only_context_model=False,
+def get_model_provider(only_query_model=False, only_context_model=False,
         biencoder_shared_query_context_model=False):
 
     def model_provider(pre_process=True, post_process=True):
         """Build the model."""
 
         print_rank_0('building Bienoder model ...')
-        model = biencoder_model_provider(only_query_model=only_query_model,
-                only_context_model = only_context_model,
+        model = biencoder_model_provider(only_query_model=only_query_model,
+                only_context_model = only_context_model,
                 biencoder_shared_query_context_model = \
-                biencoder_shared_query_context_model,
+                biencoder_shared_query_context_model,
                 pre_process=True, post_process=True)
         return model
@@ -33,21 +33,12 @@ def get_model_provider(only_query_model=False, only_context_model=False,
     return model_provider
 
-#def biencoder_model_provider(pre_process=True,
-#                             post_process=True):
 def biencoder_model_provider(only_query_model=False,
                              only_context_model=False,
                              biencoder_shared_query_context_model=False,
                              pre_process=True,
                              post_process=True):
     """Build the model."""
-    #args = get_args()
-    #biencoder_shared_query_context_model = args.biencoder_shared_query_context_model
-    #only_context_model = args.only_context_model
-    #only_query_model = args.only_query_model
 
     assert mpu.get_tensor_model_parallel_world_size() == 1 and \
         mpu.get_pipeline_model_parallel_world_size() == 1, \
......@@ -63,7 +54,7 @@ def biencoder_model_provider(only_query_model=False,
only_query_model=only_query_model,
only_context_model=only_context_model,
biencoder_shared_query_context_model=\
biencoder_shared_query_context_model,
biencoder_shared_query_context_model,
pre_process=pre_process,
post_process=post_process)
@@ -114,9 +105,9 @@ class BiEncoderModel(MegatronModule):
     def set_input_tensor(self, input_tensor):
         """See megatron.model.transformer.set_input_tensor()"""
-        #this is just a placeholder and will be needed when model
-        #parallelism will be used
-        #self.language_model.set_input_tensor(input_tensor)
+        # this is just a placeholder and will be needed when model
+        # parallelism will be used
+        # self.language_model.set_input_tensor(input_tensor)
         return
 
     def forward(self, query_tokens, query_attention_mask, query_types,
......
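One detail worth noting in the get_model_provider hunk above: the inner model_provider accepts pre_process/post_process from get_model() but then passes pre_process=True, post_process=True to biencoder_model_provider, so the pipeline-stage flags are pinned rather than forwarded. A forwarding variant of the same factory-closure pattern would look roughly like this; a minimal sketch, assuming the biencoder_model_provider signature shown in the diff:

def get_model_provider(only_query_model=False, only_context_model=False,
                       biencoder_shared_query_context_model=False):
    """Return a model_provider callback with the biencoder flags baked in."""

    def model_provider(pre_process=True, post_process=True):
        # get_model() invokes this per pipeline stage; the biencoder flags
        # are captured from the enclosing scope, and the stage flags are
        # forwarded instead of being hardcoded to True.
        return biencoder_model_provider(
            only_query_model=only_query_model,
            only_context_model=only_context_model,
            biencoder_shared_query_context_model=biencoder_shared_query_context_model,
            pre_process=pre_process,
            post_process=post_process)

    return model_provider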
@@ -18,7 +18,7 @@
 import torch
 import torch.nn.functional as F
 
-from megatron import get_args, print_rank_0
+from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
 from megatron.model.enums import LayerType, AttnMaskType
......
@@ -36,7 +36,7 @@ def pretrain_ict_model_provider():
     #args.only_context_model = False
     #args.only_query_model = False
     #model = biencoder_model_provider()
     model = biencoder_model_provider(
                 only_context_model=False,
                 only_query_model=False,
......
@@ -16,7 +16,6 @@
 """Finetune utilities."""
 
 from functools import partial
-import sys
 
 import torch
@@ -81,7 +80,7 @@ def _cross_entropy_forward_step(batch, model):
     return output_tensor, partial(cross_entropy_loss_func, labels)
 
-def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
+def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
         task_collate_fn=None):
     """Data loader. Note that batch-size is the local (per GPU) batch-size."""
@@ -190,7 +189,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                 continue
 
             # Set to zero so the next epoch does not skip any batches.
             start_iteration = 0
 
             # Train for one step.
             out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
@@ -226,9 +225,6 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                            valid_dataloader, model,
                                            iteration, False)
 
-            #if iteration == 600:
-            #    sys.exit()
-
         # Checkpointing at the end of each epoch.
         if args.save:
             save_checkpoint(iteration, model, optimizer, lr_scheduler)
......
@@ -89,8 +89,8 @@ def get_tasks_args(parser):
     #                    help="Av.rank validation: batch size to process passages")
     #group.add_argument("--val-av-rank-max-qs", type=int, default=10000,
     #                    help="Av.rank validation: max num of questions")
 
     return parser
......
@@ -15,18 +15,6 @@
 """Main tasks functionality."""
 
 import os
 import sys
-#sys.path.append(
-#    os.path.abspath(
-#        os.path.join(
-#            os.path.join(os.path.dirname(__file__), os.path.pardir),
-#            os.path.pardir,
-#        )
-#    )
-#)
 
 from megatron import get_args, print_rank_0
 from megatron.indexer import IndexBuilder
 from tasks.orqa.evaluate_utils import ORQAEvaluator
@@ -35,30 +23,23 @@ def main():
     """
     Main program
     """
-    #initialize_megatron(extra_args_provider=None,
-    #    args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
 
     args = get_args()
 
-    """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
-    - Include all args needed for initial model specification
-
-    Other key args:
-        --block-data-path: path to write to
-        --ict-load or --realm-load: path to checkpoint with which to embed
-        --data-path and --titles-data-path: paths for dataset
-        --indexer-log-interval: reporting interval
-        --indexer-batch-size: size specific for indexer jobs
-
-    Check README.md for example script
-    """
+    """
+    Create a BlockData data structure by running an IndexBuilder over an
+    ICT Dataset and then evaluate on NQ task
+    """
-    #print_rank_0("Starting index builder!")
     print_rank_0("Starting index builder!")
 
     index_builder = IndexBuilder()
     index_builder.build_and_save_index()
     print_rank_0("Build and save indices: done!")
 
     print_rank_0("Starting evaluations!")
 
     # Set up the model and evaluator
     evaluator = ORQAEvaluator()
@@ -68,4 +49,4 @@ def main():
 
     if args.qa_data_test is not None:
         evaluator.evaluate(args.qa_data_test, "TEST")
@@ -47,10 +47,9 @@ class ORQAEvaluator(object):
         #args.only_query_model = only_query_model
         #args.only_context_model = False
-        model = get_model(get_model_provider(only_query_model=only_query_model,
+        model = get_model(get_model_provider(only_query_model=only_query_model,
             biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))
-        #model = get_model(lambda: biencoder_model_provider(only_query_model=\
+        #model = get_model(lambda: biencoder_model_provider(only_query_model=\
         #    only_query_model, biencoder_shared_query_context_model=\
......
@@ -104,9 +104,9 @@ def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
     return enc_ids, tokentypes_enc, pad_mask
 
-def build_sample(query_ids, query_types, query_pad_mask,
+def build_sample(query_ids, query_types, query_pad_mask,
                 ctx_ids, ctx_types, ctx_pad_mask, answers,
-                neg_ctx_id_list=None, neg_ctx_types_list=None,
+                neg_ctx_id_list=None, neg_ctx_types_list=None,
                 include_neg=False):
     """Convert to numpy and return a sample consumed by the batch producer."""
@@ -295,5 +295,3 @@ class NQSupervisedDataset(OpenRetrievalAbstractDataset):
         print_rank_0(' >> processed {} samples.'.format(len(samples)))
 
         return samples
-
-
@@ -34,7 +34,6 @@ def task_collate_fn(batch_data):
     for d in batch_data:
         for k, v in d.items():
             tensorized.setdefault(k, []).append(v)
-    # assert len(tensorized) == 12
 
     tensorized['query'] = torch.LongTensor(tensorized['query'])
     tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask'])
@@ -90,8 +89,6 @@ def process_batch(batch):
            neg_context_tokens, neg_context_mask, neg_context_types, reference
 
 def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
-    #, datapath,
-    # rank0sampler=False):
     """Provide function that calculates accuracies."""
     args = get_args()
@@ -112,9 +109,7 @@ def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
                                  args.eval_micro_batch_size,
                                  num_workers=args.num_workers,
                                  drop_last=drop_last,
-                                 task_collate_fn=task_collate_fn)
-                                 #shuffle=False,
-                                 #rank0sampler=rank0sampler)
+                                 task_collate_fn=task_collate_fn)
 
     dataloaders = (dataset.dataset_name, dataloader)
 
     def metrics_func(model, epoch, output_predictions=False):
@@ -197,7 +192,7 @@ def retrieval_loss(model, dataloader):
         losses = average_losses_across_data_parallel_group([rank, \
                                             *topk_accs])
 
-        # create stats_dict with retrieval loss and all specified
+        # create stats_dict with retrieval loss and all specified
         # top-k accuracies
         topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
                         zip(args.retriever_report_topk_accuracies, losses[1:])}
......
@@ -22,27 +22,21 @@ import math
 import torch
 import torch.nn.functional as F
 
-from megatron import get_args
-from megatron import get_timers
-from megatron import get_tokenizer
-from megatron import mpu
-from megatron import print_rank_0
-from megatron.utils import average_losses_across_data_parallel_group
+from megatron import get_args, get_timers, get_tokenizer
+from megatron import mpu, print_rank_0
+from megatron.indexer import IndexBuilder
 from megatron.model.biencoder_model import biencoder_model_provider
-#from tasks.t5_model_utils.finetune_utils_open_retrieval import accuracy_func_provider
-#from tasks.t5_model_utils.finetune_utils_open_retrieval import finetune
+from megatron.utils import average_losses_across_data_parallel_group
 from pretrain_ict import get_group_world_size_rank
 from tasks.finetune_utils import finetune
 from tasks.orqa.supervised.eval_utils import accuracy_func_provider
 from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
 from tasks.orqa.evaluate_utils import ORQAEvaluator
-from megatron.indexer import IndexBuilder
 
-def orqa(Dataset): # , name_from_datapath_func):
+def orqa(Dataset):
 
     def cross_entropy_forward_step(batch, model):
         """Simple forward step with cross-entropy loss."""
         args = get_args()
         timers = get_timers()
         tokenizer = get_tokenizer()
@@ -73,17 +67,15 @@ def orqa(Dataset): # , name_from_datapath_func):
             context_types = torch.cat([context_types, neg_context_types])
 
         # Forward model.
-        #query_logits, context_logits = model(query_tokens, query_mask,
-        output_tensor = model(query_tokens, query_mask,
-                              query_types, context_tokens,
+        output_tensor = model(query_tokens, query_mask,
+                              query_types, context_tokens,
                               context_mask, context_types)
 
-        return output_tensor, partial(cross_entropy_loss_func_, query_tokens, context_tokens)
+        return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens)
 
-    #def cross_entropy_loss_func(labels, output_tensor):
-    def cross_entropy_loss_func_(query_tokens, context_tokens, output_tensor):
-        args = get_args()
+    def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
+        args = get_args()
 
         local_batch_size = query_tokens.shape[0]
         group, rank, world_size = get_group_world_size_rank()
@@ -184,12 +176,9 @@ def orqa(Dataset): # , name_from_datapath_func):
         """Build the model."""
         args = get_args()
         print_rank_0('building retriever model for {} ...'.format(args.task))
-        #args.only_context_model=False
-        #args.only_query_model=False
-        #model = biencoder_model_provider()
         model = biencoder_model_provider(only_context_model=False,
-                                         only_query_model=False,
+                                         only_query_model=False,
                                          biencoder_shared_query_context_model=\
                                          args.biencoder_shared_query_context_model,
                                          pre_process=pre_process, post_process=post_process)
@@ -200,7 +189,6 @@ def orqa(Dataset): # , name_from_datapath_func):
         args = get_args()
         tokenizer = get_tokenizer()
 
-        #name = name_from_datapath_func(datapath)
         name = datapath[0].split('/')[-1].split('.')[0]
 
         return Dataset(name,
                        datapath,
@@ -208,41 +196,25 @@ def orqa(Dataset): # , name_from_datapath_func):
                        args.retriever_seq_length,
                        evaluate=True)
 
-    #def distributed_metrics_func_provider():
     def metrics_func_provider():
         """Provide metrics callback function."""
-        #def name_from_datapath(datapath):
-        #    return datapath[0].split('/')[-1].split('.')[0]
         return accuracy_func_provider(single_dataset_provider)
 
-    #def rank0_metrics_func_provider(datapath):
-    #    """Provide metrics callback function."""
-    #    return accuracy_func_provider(single_dataset_provider, datapath,
-    #        rank0sampler=True)
 
     """Finetune/evaluate."""
     finetune(train_valid_datasets_provider,
             model_provider,
             forward_step=cross_entropy_forward_step,
             end_of_epoch_callback_provider=metrics_func_provider,
             task_collate_fn=task_collate_fn)
-    #,end_of_training_callback_provider=rank0_metrics_func_provider)
 
 
 def main():
     args = get_args()
 
     if args.task == 'RET-FINETUNE-NQ':
         from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
-        #def name_from_datapath(datapath):
-        #    return datapath[0].split('/')[-1].split('.')[0]
     else:
         raise NotImplementedError('ORQA task {} is not implemented.'.format(
             args.task))
 
-    orqa(Dataset) #, name_from_datapath)
+    orqa(Dataset)
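
Taken together, the cleaned-up main entry point performs the two steps its new docstring names: build the evidence index, then run NQ evaluation. A minimal sketch of that control flow, assuming Megatron has already been initialized (e.g. via the initialize_megatron call that is commented out above) and that a dev split is wired through an assumed args.qa_data_dev flag alongside the args.qa_data_test handling shown in the diff:

from megatron import get_args, print_rank_0
from megatron.indexer import IndexBuilder
from tasks.orqa.evaluate_utils import ORQAEvaluator

def main():
    args = get_args()

    # Step 1: embed every evidence block with the context encoder and
    # write the BlockData shards to disk (the --block-data-path described
    # in the old docstring removed above).
    print_rank_0("Starting index builder!")
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()
    print_rank_0("Build and save indices: done!")

    # Step 2: load the query encoder and evaluate retrieval on NQ.
    print_rank_0("Starting evaluations!")
    evaluator = ORQAEvaluator()
    if args.qa_data_dev is not None:   # assumed dev-split argument
        evaluator.evaluate(args.qa_data_dev, "DEV")
    if args.qa_data_test is not None:
        evaluator.evaluate(args.qa_data_test, "TEST")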