Commit 220637f9 authored by Mostofa Patwary

DPR evaluation debugging

parent a8d172b3
@@ -478,6 +478,12 @@ def _add_learning_rate_args(parser):
group.add_argument('--min-lr', type=float, default=0.0,
help='Minimum value for learning rate. The scheduler '
'clips values below this threshold.')
group.add_argument('--override-lr-new', action='store_true',
help='Reset the values of the scheduler (learning rate, '
'warmup iterations, minimum learning rate, maximum '
'number of iterations, and decay style) from the input '
'arguments and ignore values from checkpoints. Note '
'that all the above values will be reset.')
group.add_argument('--override-lr-scheduler', action='store_true',
help='Reset the values of the scheduler (learning rate, '
'warmup iterations, minimum learning rate, maximum '
......
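For context, a minimal self-contained sketch (plain argparse, not the real Megatron parser) of how a store_true flag like the new --override-lr-new is registered and later read back as args.override_lr_new:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--override-lr-new', action='store_true',
                    help='Reset scheduler values from the input '
                         'arguments and ignore checkpoint values.')

# Dashes in the flag name become underscores on the parsed namespace.
args = parser.parse_args(['--override-lr-new'])
assert args.override_lr_new is True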
@@ -413,8 +413,11 @@ def load_biencoder_checkpoint(model, only_query_model=False,
if only_context_model:
ret_state_dict.pop('query_model')
assert len(model) == 1
model[0].load_state_dict(ret_state_dict)
#print_rank_0(len(model))
#sys.exit()
#assert len(model) == 1
#model[0].load_state_dict(ret_state_dict)
model.load_state_dict(ret_state_dict)
torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
......
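A hedged sketch of the two conventions this hunk switches between: Megatron's usual path returns the model as a one-element list of stages (hence the old assert and model[0]), while this debugging path holds a bare module. The helper name below is hypothetical:

import torch

def load_biencoder_state(model, state_dict):
    # Usual Megatron convention: a list of (virtual pipeline) stages.
    if isinstance(model, list):
        assert len(model) == 1
        model[0].load_state_dict(state_dict)
    # Debug convention in this commit: a bare nn.Module.
    else:
        model.load_state_dict(state_dict)

# Example with a stand-in module; both call forms work:
m = torch.nn.Linear(2, 2)
load_biencoder_state(m, m.state_dict())
load_biencoder_state([m], m.state_dict())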
@@ -2,7 +2,7 @@ import sys
import torch
import torch.distributed as dist
from megatron import get_args
from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
@@ -25,6 +25,8 @@ class IndexBuilder(object):
self.evidence_embedder_obj = None
self.biencoder_shared_query_context_model = \
args.biencoder_shared_query_context_model
self.pre_process = True
self.post_process = True
# need to know whether we're using a REALM checkpoint (args.load)
# or ICT checkpoint
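The new pre_process/post_process flags mirror Megatron's pipeline-parallel model providers; a purely illustrative sketch (all names are stand-ins) of what such flags usually gate — a single-stage run, as here, sets both to True:

def embed(x):        # stand-in for the input embedding
    return x

def trunk(x):        # stand-in for the transformer layers
    return x

def head(x):         # stand-in for the output head / pooler
    return x

def forward(x, pre_process=True, post_process=True):
    if pre_process:          # only the first pipeline stage embeds
        x = embed(x)
    x = trunk(x)
    if post_process:         # only the last stage applies the head
        x = head(x)
    return x

print(forward([1, 2, 3]))    # single-stage: both flags True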
@@ -47,15 +49,22 @@ class IndexBuilder(object):
if self.biencoder_shared_query_context_model:
only_context_model = False
model = get_model(lambda: biencoder_model_provider(only_context_model \
#model = get_model(lambda: biencoder_model_provider(only_context_model \
# = only_context_model, biencoder_shared_query_context_model = \
# self.biencoder_shared_query_context_model, \
# pre_process=self.pre_process, post_process=self.post_process))
model = biencoder_model_provider(only_context_model \
= only_context_model, biencoder_shared_query_context_model = \
self.biencoder_shared_query_context_model))
self.biencoder_shared_query_context_model, \
pre_process=self.pre_process, post_process=self.post_process)
self.model = load_biencoder_checkpoint(model,
only_context_model=only_context_model)
assert len(self.model) == 1
self.model[0].eval()
#assert len(self.model) == 1
#self.model[0].eval()
self.model.eval()
self.dataset = get_open_retrieval_wiki_dataset()
self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
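One reading of the change above (not stated in the commit): get_model normally wraps the provided module (fp16/DDP) and moves it to the current GPU, returning a list, whereas calling biencoder_model_provider directly yields a bare module that stays wherever it was built — which would explain the .device prints added further down. A toy sketch of the difference, with stand-in names:

import torch

def provider():
    return torch.nn.Linear(4, 4)   # stand-in for the biencoder

# get_model-style path: wrap in a list and move to this rank's GPU.
wrapped = [provider()]
if torch.cuda.is_available():
    wrapped[0].cuda()

# Direct-provider path used in this hunk: bare module, not moved.
bare = provider()
print(next(bare.parameters()).device)   # cpu until moved explicitly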
@@ -83,10 +92,12 @@ class IndexBuilder(object):
distributed setting will be consolidated by the rank 0 process
and saved as a final pickled BlockData.
"""
assert len(self.model) == 1
unwrapped_model = self.model[0]
#assert len(self.model) == 1
#unwrapped_model = self.model[0]
unwrapped_model = self.model
while not hasattr(unwrapped_model, 'embed_text'):
unwrapped_model = unwrapped_model.module
print_rank_0("hasattr")
while True:
try:
@@ -97,12 +108,26 @@ class IndexBuilder(object):
except (StopIteration, IndexError):
break
print_rank_0(context_tokens)
print_rank_0(context_mask)
print_rank_0(context_types)
#if torch.cuda.is_available():
# print_rank_0("cuda available")
#print_rank_0(torch.cuda.current_device())
#print_rank_0(torch.cuda.get_device_name())
print_rank_0(next(unwrapped_model.parameters()).device)
print_rank_0(next(unwrapped_model.context_model.parameters()).device)
#print_rank_0("After get_open_retrieval_batch")
# TODO: can we wrap this in torch.no_grad() to reduce memory usage?
# detach, separate fields and add to BlockData
assert context_mask.dtype == torch.bool
context_logits = unwrapped_model.embed_text(
unwrapped_model.context_model, context_tokens, context_mask,
context_types)
sys.exit()
context_logits = detach(context_logits)
row_id = detach(row_id)
......
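On the torch.no_grad() TODO above, a self-contained sketch of the effect: inside the context no autograd graph is recorded, so activation memory is released eagerly and the later detach() becomes a no-op for graph-freeing purposes:

import torch

model = torch.nn.Linear(8, 4)
x = torch.randn(2, 8)

with torch.no_grad():
    y = model(x)              # forward pass, no graph recorded

assert not y.requires_grad    # nothing to detach from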
@@ -18,6 +18,7 @@
import math
from megatron import print_rank_0
from megatron import get_args
class AnnealingLR(object):
"""Anneals the learning rate."""
@@ -59,6 +60,7 @@ class AnnealingLR(object):
"""Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
#print_rank_0("self.warmup_steps {} self.num_steps {} self.decay_steps {} self.min_lr {} self.maxlr {}".format(self.warmup_steps, self.num_steps, self.decay_steps, self.min_lr, self.max_lr))
# Use linear warmup for the initial part.
if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
return self.max_lr * float(self.num_steps) / \
@@ -87,7 +89,21 @@ class AnnealingLR(object):
else:
raise Exception('{} decay style is not supported.'.format(
self.decay_style))
args = get_args()
if args.override_lr_new:
mod_num_steps_ = min(self.num_steps, self.decay_steps - self.warmup_steps)
mod_num_steps_ = mod_num_steps_ - self.warmup_steps
use_lr = delta_lr * float(self.decay_steps - mod_num_steps_) / float(self.decay_steps)
should_use_lr = self.min_lr + coeff * delta_lr
print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} should_use_lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, should_use_lr, self.warmup_steps, self.num_steps, self.decay_steps))
else:
use_lr = self.min_lr + coeff * delta_lr
print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, self.warmup_steps, self.num_steps, self.decay_steps))
return use_lr
return self.min_lr + coeff * delta_lr
......
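To keep the quantities in these debug prints straight, a standalone sketch of the decay arithmetic being probed, assuming delta_lr = max_lr - min_lr and the usual linear/cosine coefficients (the override branch above experiments with a different step clamp; this shows only the baseline path):

import math

def annealed_lr(num_steps, warmup_steps, decay_steps,
                max_lr, min_lr, decay_style='linear'):
    # Linear warmup for the initial part.
    if warmup_steps > 0 and num_steps <= warmup_steps:
        return max_lr * float(num_steps) / float(warmup_steps)
    # Decay phase.
    num_steps_ = min(num_steps, decay_steps) - warmup_steps
    decay_steps_ = decay_steps - warmup_steps
    decay_ratio = float(num_steps_) / float(decay_steps_)
    delta_lr = max_lr - min_lr
    if decay_style == 'linear':
        coeff = 1.0 - decay_ratio
    elif decay_style == 'cosine':
        coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
    else:
        raise Exception('{} decay style is not supported.'.format(decay_style))
    return min_lr + coeff * delta_lr

print(annealed_lr(500, 100, 1000, 1e-4, 1e-5))   # mid-decay example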
@@ -266,6 +266,10 @@ class PretrainedBertModel(MegatronModule):
#extended_attention_mask = bert_extended_attention_mask(attention_mask)
position_ids = bert_position_ids(input_ids)
print_rank_0(input_ids.device)
print_rank_0(position_ids.device)
print_rank_0(extended_attention_mask.device)
print_rank_0(tokentype_ids.device)
lm_output = self.language_model(input_ids,
position_ids,
......
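The four prints above all chase the same failure mode; a small reusable sketch (hypothetical helper, not part of Megatron) that does the same device audit in one place:

import torch

def report_devices(**tensors):
    # Print the device of each named tensor; a CPU entry among CUDA
    # ones pinpoints the input that will break the first kernel call.
    for name, t in tensors.items():
        print('{}: {}'.format(name, t.device))

report_devices(input_ids=torch.zeros(2, 4, dtype=torch.long),
               position_ids=torch.arange(4).unsqueeze(0))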
@@ -18,7 +18,7 @@
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import get_args, print_rank_0
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import LayerType, AttnMaskType
@@ -338,6 +338,11 @@ class TransformerLanguageModel(MegatronModule):
get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
print_rank_0("before self.embedding")
print_rank_0(enc_input_ids.device)
print_rank_0(enc_position_ids.device)
print_rank_0(tokentype_ids.device)
# Embeddings.
if self.pre_process:
embedding_output = self.embedding(enc_input_ids, enc_position_ids,
......
@@ -16,6 +16,7 @@
"""Finetune utilities."""
from functools import partial
import sys
import torch
@@ -225,6 +226,9 @@ def _train(model, optimizer, lr_scheduler, forward_step,
valid_dataloader, model,
iteration, False)
#if iteration == 600:
# sys.exit()
# Checkpointing at the end of each epoch.
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
......
@@ -34,6 +34,8 @@ from pretrain_ict import get_group_world_size_rank
from tasks.finetune_utils import finetune
from tasks.orqa.supervised.eval_utils import accuracy_func_provider
from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
from tasks.orqa.evaluate_utils import ORQAEvaluator
from megatron.indexer import IndexBuilder
def orqa(Dataset): # , name_from_datapath_func):
@@ -226,14 +228,29 @@ def orqa(Dataset): # , name_from_datapath_func):
def main():
args = get_args()
if args.task == 'RET-FINETUNE-NQ':
from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
#if args.task == 'RET-FINETUNE-NQ':
# from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
#def name_from_datapath(datapath):
# return datapath[0].split('/')[-1].split('.')[0]
else:
raise NotImplementedError('ORQA task {} is not implemented.'.format(
args.task))
#else:
# raise NotImplementedError('ORQA task {} is not implemented.'.format(
# args.task))
#orqa(Dataset) #, name_from_datapath)
index_builder = IndexBuilder()
index_builder.build_and_save_index()
print_rank_0("Build and save indices: done!")
# Set up the model and evaluator
#evaluator = ORQAEvaluator()
# Run evaluation
#if args.qa_data_dev is not None:
# evaluator.evaluate(args.qa_data_dev, "DEV")
#if args.qa_data_test is not None:
# evaluator.evaluate(args.qa_data_test, "TEST")
orqa(Dataset) #, name_from_datapath)