Commit 2eaf6c79 authored by Mostofa Patwary

cleaning the code

parent 7a0710ec
@@ -26,8 +26,8 @@ class IndexBuilder(object):
         self.evidence_embedder_obj = None
         self.biencoder_shared_query_context_model = \
             args.biencoder_shared_query_context_model
-        self.pre_process = True
-        self.post_process = True
+        #self.pre_process = True
+        #self.post_process = True

         # need to know whether we're using a REALM checkpoint (args.load)
         # or ICT checkpoint
@@ -46,7 +46,7 @@ class IndexBuilder(object):
         """
         Load the necessary attributes: model, dataloader and empty BlockData
         """
-        args = get_args()
+        #args = get_args()
         only_context_model = True
         if self.biencoder_shared_query_context_model:
             only_context_model = False
@@ -103,12 +103,12 @@ class IndexBuilder(object):
         while not hasattr(unwrapped_model, 'embed_text'):
             unwrapped_model = unwrapped_model.module

-        counter = 0
-        start_time = time.time()
-        cur_time = start_time
+        #counter = 0
+        #start_time = time.time()
+        #cur_time = start_time
         while True:
             #start_time = time.time()
-            t1 = time.time()
+            #t1 = time.time()
             try:
                 # batch also has query_tokens and query_pad_data
                 row_id, context_tokens, context_mask, context_types, \
@@ -118,7 +118,7 @@ class IndexBuilder(object):
                 break
             #print_rank_0("get batch time {}".format(cur_time - time.time()))
-            t2 = time.time()
+            #t2 = time.time()

             # TODO: can we add with torch.no_grad() to reduce memory usage
             # detach, separate fields and add to BlockData
             assert context_mask.dtype == torch.bool
@@ -129,17 +129,17 @@ class IndexBuilder(object):
             context_logits = detach(context_logits)
             row_id = detach(row_id)
             #print_rank_0("embed text {}".format(cur_time - time.time()))
-            t3 = time.time()
+            #t3 = time.time()

             self.evidence_embedder_obj.add_block_data(row_id, context_logits)
             self.track_and_report_progress(batch_size=len(row_id))
             #print_rank_0("add block time {}".format(cur_time - time.time()))
-            t4 = time.time()
-            counter += 1
-            if counter % 1000 == 0:
-                print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
-                print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 - t3))
-                cur_time = time.time()
+            #t4 = time.time()
+            #counter += 1
+            #if counter % 1000 == 0:
+            #    print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
+            #    print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 - t3))
+            #    cur_time = time.time()

         # This process signals to finalize its shard and then synchronize with
         # the other processes
         self.evidence_embedder_obj.save_shard()
......
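The TODO in the indexer hunk above asks whether the embedding forward pass can run under torch.no_grad() to reduce memory. A minimal sketch of that change, assuming embed_text takes the context model plus the batch fields shown in this diff (the exact argument list is not visible here); it is a hypothetical follow-up, not part of this commit:

# Skip autograd bookkeeping while building the index; the embeddings are
# detached immediately afterwards anyway, so no gradients are ever needed.
with torch.no_grad():
    context_logits = unwrapped_model.embed_text(
        unwrapped_model.context_model, context_tokens,
        context_mask, context_types)
context_logits = detach(context_logits)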
@@ -33,21 +33,12 @@ def get_model_provider(only_query_model=False, only_context_model=False,
     return model_provider

-#def biencoder_model_provider(pre_process=True,
-#                             post_process=True):
 def biencoder_model_provider(only_query_model=False,
                              only_context_model=False,
                              biencoder_shared_query_context_model=False,
                              pre_process=True,
                              post_process=True):
     """Build the model."""

-    #args = get_args()
-    #biencoder_shared_query_context_model = args.biencoder_shared_query_context_model
-    #only_context_model = args.only_context_model
-    #only_query_model = args.only_query_model
-
     assert mpu.get_tensor_model_parallel_world_size() == 1 and \
         mpu.get_pipeline_model_parallel_world_size() == 1, \
@@ -114,9 +105,9 @@ class BiEncoderModel(MegatronModule):
     def set_input_tensor(self, input_tensor):
         """See megatron.model.transformer.set_input_tensor()"""
-        #this is just a placeholder and will be needed when model
-        #parallelism will be used
-        #self.language_model.set_input_tensor(input_tensor)
+        # this is just a placeholder and will be needed when model
+        # parallelism will be used
+        # self.language_model.set_input_tensor(input_tensor)
         return

     def forward(self, query_tokens, query_attention_mask, query_types,
......
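With model selection now passed explicitly through the signature above rather than read from global args, a caller names exactly the towers it needs. A hedged usage sketch following the only_context_model logic from the IndexBuilder hunk at the top of this commit (args is assumed to come from Megatron's get_args(); this is not code from the diff):

# Build only the context tower for indexing, unless query and context
# encoders share weights, in which case the shared model must be built.
only_context_model = not args.biencoder_shared_query_context_model
model = biencoder_model_provider(
    only_query_model=False,
    only_context_model=only_context_model,
    biencoder_shared_query_context_model=args.biencoder_shared_query_context_model,
    pre_process=True,
    post_process=True)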
@@ -18,7 +18,7 @@
 import torch
 import torch.nn.functional as F

-from megatron import get_args, print_rank_0
+from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
 from megatron.model.enums import LayerType, AttnMaskType
......
@@ -16,7 +16,6 @@
 """Finetune utilities."""

 from functools import partial
-import sys

 import torch
@@ -226,9 +225,6 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                            valid_dataloader, model,
                                            iteration, False)

-            #if iteration == 600:
-            #    sys.exit()
-
         # Checkpointing at the end of each epoch.
         if args.save:
             save_checkpoint(iteration, model, optimizer, lr_scheduler)
......
@@ -15,18 +15,6 @@
 """Main tasks functionality."""

-import os
-import sys
-#sys.path.append(
-#    os.path.abspath(
-#        os.path.join(
-#            os.path.join(os.path.dirname(__file__), os.path.pardir),
-#            os.path.pardir,
-#        )
-#    )
-#)
-
 from megatron import get_args, print_rank_0
 from megatron.indexer import IndexBuilder
 from tasks.orqa.evaluate_utils import ORQAEvaluator
@@ -35,30 +23,23 @@ def main():
     """
     Main program
     """
-    #initialize_megatron(extra_args_provider=None,
-    #    args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()

-    """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
-    - Include all args needed for initial model specification
-
-    Other key args:
-      --block-data-path: path to write to
-      --ict-load or --realm-load: path to checkpoint with which to embed
-      --data-path and --titles-data-path: paths for dataset
-      --indexer-log-interval: reporting interval
-      --indexer-batch-size: size specific for indexer jobs
-    Check README.md for example script
+    """
+    Create a BlockData data structure by running an IndexBuilder over an
+    ICT Dataset and then evaluate on NQ task
     """
-    #print_rank_0("Starting index builder!")
+
+    print_rank_0("Starting index builder!")
     index_builder = IndexBuilder()
     index_builder.build_and_save_index()
     print_rank_0("Build and save indices: done!")

+    print_rank_0("Starting evaluations!")
+
     # Set up the model and evaluator
     evaluator = ORQAEvaluator()
......
@@ -50,7 +50,6 @@ class ORQAEvaluator(object):
             model = get_model(get_model_provider(only_query_model=only_query_model,
                 biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))

-            #model = get_model(lambda: biencoder_model_provider(only_query_model=\
             #model = get_model(lambda: biencoder_model_provider(only_query_model=\
             #    only_query_model, biencoder_shared_query_context_model=\
......
@@ -295,5 +295,3 @@ class NQSupervisedDataset(OpenRetrievalAbstractDataset):
     print_rank_0(' >> processed {} samples.'.format(len(samples)))
     return samples
-
-
......
@@ -34,7 +34,6 @@ def task_collate_fn(batch_data):
     for d in batch_data:
         for k, v in d.items():
             tensorized.setdefault(k, []).append(v)
-    # assert len(tensorized) == 12

     tensorized['query'] = torch.LongTensor(tensorized['query'])
     tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask'])
@@ -90,8 +89,6 @@ def process_batch(batch):
         neg_context_tokens, neg_context_mask, neg_context_types, reference

 def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
-#, datapath,
-#                           rank0sampler=False):
     """Provide function that calculates accuracies."""
     args = get_args()
@@ -113,8 +110,6 @@ def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
                                 num_workers=args.num_workers,
                                 drop_last=drop_last,
                                 task_collate_fn=task_collate_fn)
-                                #shuffle=False,
-                                #rank0sampler=rank0sampler)
     dataloaders = (dataset.dataset_name, dataloader)

     def metrics_func(model, epoch, output_predictions=False):
......
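The task_collate_fn hunk above turns a list of per-sample dicts into a dict of lists, then stacks selected keys into LongTensors. A self-contained toy illustration of that pattern (the token values are invented for the example):

import torch

batch_data = [
    {'query': [101, 7592, 102], 'query_mask': [1, 1, 1]},
    {'query': [101, 2088, 102], 'query_mask': [1, 1, 0]},
]
tensorized = {}
for d in batch_data:
    for k, v in d.items():
        tensorized.setdefault(k, []).append(v)
# Equal-length samples stack cleanly into a [batch, seq] LongTensor.
tensorized['query'] = torch.LongTensor(tensorized['query'])
tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask'])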
@@ -22,27 +22,21 @@
 import math

 import torch
 import torch.nn.functional as F

-from megatron import get_args
-from megatron import get_timers
-from megatron import get_tokenizer
-from megatron import mpu
-from megatron import print_rank_0
-from megatron.utils import average_losses_across_data_parallel_group
+from megatron import get_args, get_timers, get_tokenizer
+from megatron import mpu, print_rank_0
+from megatron.indexer import IndexBuilder
 from megatron.model.biencoder_model import biencoder_model_provider
-#from tasks.t5_model_utils.finetune_utils_open_retrieval import accuracy_func_provider
-#from tasks.t5_model_utils.finetune_utils_open_retrieval import finetune
+from megatron.utils import average_losses_across_data_parallel_group
 from pretrain_ict import get_group_world_size_rank
 from tasks.finetune_utils import finetune
 from tasks.orqa.supervised.eval_utils import accuracy_func_provider
 from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
 from tasks.orqa.evaluate_utils import ORQAEvaluator
-from megatron.indexer import IndexBuilder

-def orqa(Dataset): # , name_from_datapath_func):
+def orqa(Dataset):

     def cross_entropy_forward_step(batch, model):
         """Simple forward step with cross-entropy loss."""
-        args = get_args()
         timers = get_timers()
         tokenizer = get_tokenizer()
@@ -73,16 +67,14 @@ def orqa(Dataset):
         context_types = torch.cat([context_types, neg_context_types])

         # Forward model.
-        #query_logits, context_logits = model(query_tokens, query_mask,
         output_tensor = model(query_tokens, query_mask,
                               query_types, context_tokens,
                               context_mask, context_types)

-        return output_tensor, partial(cross_entropy_loss_func_, query_tokens, context_tokens)
+        return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens)

-    #def cross_entropy_loss_func(labels, output_tensor):
-    def cross_entropy_loss_func_(query_tokens, context_tokens, output_tensor):
+    def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
         args = get_args()
         local_batch_size = query_tokens.shape[0]
@@ -184,9 +176,6 @@ def orqa(Dataset):
         """Build the model."""
         args = get_args()
         print_rank_0('building retriever model for {} ...'.format(args.task))
-        #args.only_context_model=False
-        #args.only_query_model=False
-        #model = biencoder_model_provider()
         model = biencoder_model_provider(only_context_model=False,
                                          only_query_model=False,
@@ -200,7 +189,6 @@ def orqa(Dataset):
         args = get_args()
         tokenizer = get_tokenizer()

-        #name = name_from_datapath_func(datapath)
         name = datapath[0].split('/')[-1].split('.')[0]

         return Dataset(name,
                        datapath,
@@ -208,41 +196,25 @@ def orqa(Dataset):
                        args.retriever_seq_length,
                        evaluate=True)

-    #def distributed_metrics_func_provider():
     def metrics_func_provider():
         """Provide metrics callback function."""
-        #def name_from_datapath(datapath):
-        #    return datapath[0].split('/')[-1].split('.')[0]
         return accuracy_func_provider(single_dataset_provider)

-    #def rank0_metrics_func_provider(datapath):
-    #    """Provide metrics callback function."""
-    #    return accuracy_func_provider(single_dataset_provider, datapath,
-    #                                  rank0sampler=True)
-
     """Finetune/evaluate."""
     finetune(train_valid_datasets_provider,
              model_provider,
              forward_step=cross_entropy_forward_step,
              end_of_epoch_callback_provider=metrics_func_provider,
              task_collate_fn=task_collate_fn)
-    #,end_of_training_callback_provider=rank0_metrics_func_provider)

 def main():
     args = get_args()

     if args.task == 'RET-FINETUNE-NQ':
         from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
-        #def name_from_datapath(datapath):
-        #    return datapath[0].split('/')[-1].split('.')[0]
     else:
         raise NotImplementedError('ORQA task {} is not implemented.'.format(
             args.task))

-    orqa(Dataset) #, name_from_datapath)
+    orqa(Dataset)
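In the hunks above, cross_entropy_forward_step returns the model output together with a loss callable built via functools.partial, so the training loop can compute the loss later from the output tensor alone. A stand-in sketch of that contract outside Megatron, assuming a toy model whose output is a [batch, batch] query-context score matrix (names and shapes here are invented for illustration):

from functools import partial

import torch
import torch.nn.functional as F

def loss_func(query_tokens, context_tokens, output_tensor):
    # In-batch negatives: the i-th query is paired with the i-th context,
    # so the target class for row i of the score matrix is simply i.
    labels = torch.arange(output_tensor.shape[0])
    return F.cross_entropy(output_tensor, labels)

def forward_step(batch, model):
    query_tokens, context_tokens = batch
    output_tensor = model(query_tokens, context_tokens)  # [batch, batch] scores
    # Defer the loss: the trainer later calls loss_fn(output_tensor).
    return output_tensor, partial(loss_func, query_tokens, context_tokens)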