Commit 25293807 authored by Mostofa Patwary's avatar Mostofa Patwary

additional cleaning

parent 2eaf6c79
@@ -479,12 +479,6 @@ def _add_learning_rate_args(parser):
    group.add_argument('--min-lr', type=float, default=0.0,
                       help='Minumum value for learning rate. The scheduler'
                       'clip values below this threshold.')
-    group.add_argument('--override-lr-new', action='store_true',
-                       help='Reset the values of the scheduler (learning rate,'
-                       'warmup iterations, minimum learning rate, maximum '
-                       'number of iterations, and decay style from input '
-                       'arguments and ignore values from checkpoints. Note'
-                       'that all the above values will be reset.')
    group.add_argument('--override-lr-scheduler', action='store_true',
                       help='Reset the values of the scheduler (learning rate,'
                       'warmup iterations, minimum learning rate, maximum '
...
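The removed `--override-lr-new` flag duplicated what the surviving `--override-lr-scheduler` already does. As a hedged illustration of how such a flag is typically consumed when a scheduler is restored from a checkpoint (the loading code is outside this diff, so the function name and state-dict keys below are hypothetical):

```python
# Hypothetical sketch: when --override-lr-scheduler is set, keep the
# values parsed from the command line and ignore the checkpointed ones.
def load_scheduler_state(scheduler, state_dict, args):
    if args.override_lr_scheduler:
        return  # command-line values already populate the scheduler
    scheduler.max_lr = state_dict['max_lr']
    scheduler.min_lr = state_dict['min_lr']
    scheduler.warmup_steps = state_dict['warmup_steps']
    scheduler.decay_steps = state_dict['decay_steps']
    scheduler.decay_style = state_dict['decay_style']
```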
@@ -419,7 +419,6 @@ def load_biencoder_checkpoint(model, only_query_model=False,
    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)

    torch.distributed.barrier()
-
    if mpu.get_data_parallel_rank() == 0:
...
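In the hunk above, the `torch.distributed.barrier()` that remains makes every rank wait until the state dict has been loaded everywhere before data-parallel rank 0 continues with rank-specific work; synchronizing before rank-0-only code is the usual pattern after a collective checkpoint load.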
@@ -26,13 +26,10 @@ class IndexBuilder(object):
        self.evidence_embedder_obj = None
        self.biencoder_shared_query_context_model = \
            args.biencoder_shared_query_context_model
-        #self.pre_process = True
-        #self.post_process = True

        # need to know whether we're using a REALM checkpoint (args.load)
        # or ICT checkpoint
        assert not (args.load and args.ict_load)
-        #self.using_realm_chkpt = args.ict_load is None

        self.log_interval = args.indexer_log_interval
        self.batch_size = args.indexer_batch_size
@@ -46,24 +43,13 @@ class IndexBuilder(object):
        """
        Load the necessary attributes: model, dataloader and empty BlockData
        """
-        #args = get_args()
        only_context_model = True
        if self.biencoder_shared_query_context_model:
            only_context_model = False
-        #args.only_context_model = only_context_model
-        #args.only_query_model = False
-
-        #model = get_model(biencoder_model_provider)
-        model = get_model(get_model_provider(only_context_model=only_context_model,
-            biencoder_shared_query_context_model=self.biencoder_shared_query_context_model))
-        #model = get_model(lambda: biencoder_model_provider(only_context_model \
-        #model = get_model(lambda: biencoder_model_provider(only_context_model \
-        #    = only_context_model, biencoder_shared_query_context_model = \
-        #    self.biencoder_shared_query_context_model,
-        #    pre_process=True, post_process=True)
+        model = get_model(get_model_provider(only_context_model=\
+            only_context_model, biencoder_shared_query_context_model=\
+            self.biencoder_shared_query_context_model))

        self.model = load_biencoder_checkpoint(model,
            only_context_model=only_context_model)
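The rewrapped `get_model(get_model_provider(...))` call above replaces the commented-out lambda variants. A sketch of how such a provider factory can be structured, assuming, as the deleted lambdas suggest, that `get_model` invokes the provider with `pre_process`/`post_process` keyword arguments:

```python
def get_model_provider(only_context_model=False,
                       biencoder_shared_query_context_model=False):
    # Hypothetical factory: the biencoder options are captured in a
    # closure so get_model() can keep calling a uniform
    # model_provider(pre_process, post_process) interface.
    def model_provider(pre_process=True, post_process=True):
        return biencoder_model_provider(
            only_context_model=only_context_model,
            biencoder_shared_query_context_model=(
                biencoder_shared_query_context_model),
            pre_process=pre_process,
            post_process=post_process)
    return model_provider
```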
@@ -103,12 +89,7 @@ class IndexBuilder(object):
        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module

-        #counter = 0
-        #start_time = time.time()
-        #cur_time = start_time
        while True:
-            #start_time = time.time()
-            #t1 = time.time()
            try:
                # batch also has query_tokens and query_pad_data
                row_id, context_tokens, context_mask, context_types, \
@@ -117,8 +98,6 @@ class IndexBuilder(object):
            except (StopIteration, IndexError):
                break

-            #print_rank_0("get batch time {}".format(cur_time - time.time()))
-            #t2 = time.time()
            # TODO: can we add with torch.no_grad() to reduce memory usage
            # detach, separate fields and add to BlockData
            assert context_mask.dtype == torch.bool
@@ -128,18 +107,10 @@ class IndexBuilder(object):
            context_logits = detach(context_logits)
            row_id = detach(row_id)
-            #print_rank_0("embed text {}".format(cur_time - time.time()))
-            #t3 = time.time()
            self.evidence_embedder_obj.add_block_data(row_id, context_logits)
            self.track_and_report_progress(batch_size=len(row_id))
-            #print_rank_0("add block time {}".format(cur_time - time.time()))
-            #t4 = time.time()
-            #counter += 1
-            #if counter % 1000 == 0:
-            # print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
-            # print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3))
-            # cur_time = time.time()

        # This process signals to finalize its shard and then synchronize with
        # the other processes
        self.evidence_embedder_obj.save_shard()
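On the TODO above: the indexer only runs inference here, so wrapping the embedding call in `torch.no_grad()` should cut memory by discarding the autograd graph. A hedged sketch; the actual `embed_text` call is collapsed out of this diff, so the argument list below is assumed from the batch fields shown:

```python
import torch

# Inference only: torch.no_grad() keeps autograd from recording the
# graph, so activations are freed immediately and memory usage drops.
with torch.no_grad():
    context_logits = unwrapped_model.embed_text(
        context_tokens, context_mask, context_types)  # assumed signature
```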
...
@@ -18,7 +18,6 @@
import math

from megatron import print_rank_0
-from megatron import get_args

class AnnealingLR(object):
    """Anneals the learning rate."""
@@ -60,7 +59,6 @@ class AnnealingLR(object):
        """Learning rate decay functions from:
          https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

-        #print_rank_0("self.warmup_steps {} self.num_steps {} self.decay_steps {} self.min_lr {} self.maxlr {}".format(self.warmup_steps, self.num_steps, self.decay_steps, self.min_lr, self.max_lr))
        # Use linear warmup for the initial part.
        if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
            return self.max_lr * float(self.num_steps) / \
@@ -90,20 +88,6 @@ class AnnealingLR(object):
            raise Exception('{} decay style is not supported.'.format(
                self.decay_style))

-        args = get_args()
-        if args.override_lr_new:
-            mod_num_steps_ = min(self.num_steps, self.decay_steps - self.warmup_steps)
-            mod_num_steps_ = mod_num_steps_ - self.warmup_steps
-            use_lr = delta_lr * float(self.decay_steps - mod_num_steps_) / float(self.decay_steps)
-            should_use_lr = self.min_lr + coeff * delta_lr
-            print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} should_use_lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, should_use_lr, self.warmup_steps, self.num_steps, self.decay_steps))
-        else:
-            use_lr = self.min_lr + coeff * delta_lr
-            print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, self.warmup_steps, self.num_steps, self.decay_steps))
-        return use_lr

        return self.min_lr + coeff * delta_lr
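With the `override_lr_new` branch gone, the method again ends at the single return above. For reference, a standalone sketch of the schedule that return implements in the cosine case (the other decay styles are collapsed out of this hunk); the `num_steps_`/`decay_steps_` names mirror the deleted debug prints:

```python
import math

def annealed_lr(num_steps, warmup_steps, decay_steps, max_lr, min_lr):
    # Linear warmup for the first warmup_steps updates.
    if warmup_steps > 0 and num_steps <= warmup_steps:
        return max_lr * float(num_steps) / float(warmup_steps)
    # Cosine decay from max_lr down to min_lr over the remaining steps.
    num_steps_ = num_steps - warmup_steps
    decay_steps_ = decay_steps - warmup_steps
    decay_ratio = float(num_steps_) / float(decay_steps_)
    coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
    delta_lr = max_lr - min_lr
    return min_lr + coeff * delta_lr
```

At `num_steps == decay_steps` the coefficient reaches 0.0 and the schedule bottoms out at `min_lr`, which is exactly the clipping the `--min-lr` help text describes.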
...
@@ -114,7 +114,8 @@ def _build_infinite_size_dataloader(dataloader):
    iterator = dataloader.__iter__()


-def _build_train_valid_dataloaders(train_dataset, valid_dataset, task_collate_fn=None):
+def _build_train_valid_dataloaders(train_dataset, valid_dataset,
+                                   task_collate_fn=None):
    """Traing and validation dataloaders."""
    args = get_args()
...
@@ -44,20 +44,9 @@ class ORQAEvaluator(object):
        if args.biencoder_shared_query_context_model:
            only_query_model = False

-        #args.only_query_model = only_query_model
-        #args.only_context_model = False
        model = get_model(get_model_provider(only_query_model=only_query_model,
            biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))
-        #model = get_model(lambda: biencoder_model_provider(only_query_model=\
-        #model = get_model(lambda: biencoder_model_provider(only_query_model=\
-        #    only_query_model, biencoder_shared_query_context_model=\
-        #    args.biencoder_shared_query_context_model,
-        #    pre_process=True, post_process=True))
-        #model = get_model(biencoder_model_provider)

        self.model = load_biencoder_checkpoint(model,
            only_query_model=only_query_model)
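This mirrors the indexer change earlier in the commit: the same provider-factory pattern, now building only the query tower unless the biencoder shares query and context weights, so the `get_model_provider` sketch above applies with `only_query_model` in place of `only_context_model`.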
...