"tests/vscode:/vscode.git/clone" did not exist on "12a53b43833b7bea279a205e313f2bd3f0cdfd99"
Commit e919dd8e authored by Mostofa Patwary

cleared the commented codes

parent bfc20ecf
......@@ -606,8 +606,8 @@ def _add_biencoder_args(parser):
# faiss index
group.add_argument('--faiss-use-gpu', action='store_true',
help='Whether create the FaissMIPSIndex on GPU')
#group.add_argument('--block-data-path', type=str, default=None,
# help='Where to save/load BlockData to/from')
group.add_argument('--block-data-path', type=str, default=None,
help='Where to save/load BlockData to/from')
# indexer
group.add_argument('--indexer-batch-size', type=int, default=128,
......
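As context for the hunk above, these biencoder/faiss flags are ordinary argparse options; a minimal standalone sketch of defining and reading them follows (the parser object and the sample command line are illustrative, not Megatron's actual argument plumbing):

import argparse

# Illustrative parser exposing the flags from the hunk above.
parser = argparse.ArgumentParser(description='biencoder args (sketch)')
group = parser.add_argument_group(title='biencoder')
group.add_argument('--faiss-use-gpu', action='store_true',
                   help='Whether create the FaissMIPSIndex on GPU')
group.add_argument('--block-data-path', type=str, default=None,
                   help='Where to save/load BlockData to/from')
group.add_argument('--indexer-batch-size', type=int, default=128)

args = parser.parse_args(['--faiss-use-gpu', '--block-data-path', '/tmp/block_data'])
print(args.faiss_use_gpu, args.block_data_path, args.indexer_batch_size)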
......@@ -59,12 +59,6 @@ class AnnealingLR(object):
"""Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
#print_rank_0(self.warmup_steps)
#print_rank_0(self.num_steps)
#print_rank_0(self.warmup_steps)
#print_rank_0(self.max_lr)
#print_rank_0(self.max_lr * float(self.num_steps) / float(self.warmup_steps))
# Use linear warmup for the initial part.
if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
return self.max_lr * float(self.num_steps) / \
......@@ -103,7 +97,7 @@ class AnnealingLR(object):
new_lr = self.get_lr()
for group in self.optimizer.param_groups:
group['lr'] = new_lr
#print_rank_0(new_lr)
def state_dict(self):
state_dict = {
......
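The warmup branch shown above ramps the learning rate linearly from 0 to max_lr over the first warmup_steps steps; a minimal standalone sketch of that formula (the post-warmup decay lives outside this hunk, so a constant rate is used purely as a placeholder):

def linear_warmup_lr(max_lr, num_steps, warmup_steps):
    # Linear warmup as in the hunk above: lr ramps from 0 to max_lr over
    # the first warmup_steps steps.
    if warmup_steps > 0 and num_steps <= warmup_steps:
        return max_lr * float(num_steps) / float(warmup_steps)
    # The decay applied after warmup is not part of this hunk; constant
    # here only as a placeholder.
    return max_lr

# e.g. linear_warmup_lr(1e-4, 50, 100) == 5e-05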
......@@ -27,7 +27,7 @@ def biencoder_model_provider(only_query_model=False,
print_rank_0('building BiEncoderModel...')
# simpler to just keep using 2 tokentypes since
# the LM we initialize with has 2 tokentypes
model = BiEncoderModel(
num_tokentypes=2,
......@@ -78,7 +78,7 @@ class BiEncoderModel(MegatronModule):
def forward(self, query_tokens, query_attention_mask, query_types,
context_tokens, context_attention_mask, context_types):
"""Run a forward pass for each of the models and
"""Run a forward pass for each of the models and
return the respective embeddings."""
if self.use_query_model:
......@@ -145,7 +145,7 @@ class BiEncoderModel(MegatronModule):
state_dict[self._context_key], strict=strict)
def init_state_dict_from_bert(self):
"""Initialize the state from a pretrained BERT model
"""Initialize the state from a pretrained BERT model
on iteration zero of ICT pretraining"""
args = get_args()
......@@ -160,11 +160,6 @@ class BiEncoderModel(MegatronModule):
iteration = int(f.read().strip())
assert iteration > 0
#for param in self.query_model.language_model.parameters():
# print(param.data)
#break
#sys.exit()
checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading BERT checkpoint {}'.format(
......@@ -193,17 +188,13 @@ class BiEncoderModel(MegatronModule):
if self.query_model is not None and self.projection_dim > 0:
self.context_model.projection_enc.load_state_dict\
(query_proj_state_dict)
#for param in self.query_model.language_model.parameters():
# print(param.data)
# #sys.exit()
class PretrainedBertModel(MegatronModule):
"""BERT-based encoder for queries or contexts used for
"""BERT-based encoder for queries or contexts used for
learned information retrieval."""
def __init__(self, num_tokentypes=2,
parallel_output=True):
super(PretrainedBertModel, self).__init__()
......@@ -242,7 +233,7 @@ class PretrainedBertModel(MegatronModule):
tokentype_ids=tokentype_ids)
# This mask will be used in average-pooling and max-pooling
pool_mask = (input_ids == self.pad_id).unsqueeze(2)
# Taking the representation of the [CLS] token of BERT
if self.pool_type == "cls-token":
pooled_output = lm_output[:, 0, :]
......@@ -256,7 +247,7 @@ class PretrainedBertModel(MegatronModule):
# Converting to float16 dtype
pooled_output = pooled_output.to(lm_output.dtype)
# Output.
if self.projection_dim:
pooled_output = self.projection_enc(pooled_output)
......
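A minimal standalone sketch of the cls-token pooling and optional projection path above; the `projection` callable stands in for self.projection_enc, and the average-pooling branch is an assumed illustration rather than Megatron's exact code:

import torch

def pool_and_project(lm_output, input_ids, pad_id, pool_type='cls-token',
                     projection=None):
    # Mirrors the pooling logic above; `projection` is a stand-in for
    # self.projection_enc and the non-cls branch is an assumed illustration.
    pool_mask = (input_ids == pad_id).unsqueeze(2)   # used by avg/max pooling
    if pool_type == 'cls-token':
        # Take the representation of the [CLS] token of BERT.
        pooled_output = lm_output[:, 0, :]
    else:
        masked = lm_output.masked_fill(pool_mask, 0.0)
        pooled_output = masked.sum(dim=1) / (~pool_mask).sum(dim=1).clamp(min=1)
    pooled_output = pooled_output.to(lm_output.dtype)
    if projection is not None:
        pooled_output = projection(pooled_output)
    return pooled_output

# e.g. pool_and_project(torch.randn(2, 5, 8), torch.zeros(2, 5, dtype=torch.long), pad_id=0)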
......@@ -48,7 +48,7 @@ from megatron.model import get_params_for_weight_decay_optimization
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
from megatron.data.data_loaders import build_pretraining_data_loader
from megatron.utils import report_memory, params_grad_norm, params_global_norm, print_model, print_grads
from megatron.utils import report_memory
def print_datetime(string):
......@@ -648,7 +648,6 @@ def train_step(forward_step_func, data_iterator,
if args.fp16:
optimizer.update_master_grads()
timers('backward-master-grad').stop()
grad_norm_local = None
# Clipping gradients helps prevent the exploding gradient.
timers('backward-clip-grad').start()
......@@ -663,30 +662,14 @@ def train_step(forward_step_func, data_iterator,
mpu.clip_grad_norm(parameters, args.clip_grad,
parameter_names=parameter_names)
else:
grad_norm_local = optimizer.clip_master_grads(args.clip_grad)
optimizer.clip_master_grads(args.clip_grad)
timers('backward-clip-grad').stop()
#print_rank_0("print-grad_norm_local {}".format(grad_norm_local))
#print_rank_0("after backward")
#print_grads(model)
#print_model(model)
#print_rank_0(params_global_norm(model))
#print_rank_0(params_grad_norm(model))
# Update parameters.
timers('optimizer').start()
optimizer.step()
timers('optimizer').stop()
#print_rank_0("after optimizer")
#print_model(model)
#print_rank_0(params_global_norm(model))
#print_rank_0(params_grad_norm(model))
#sys.exit()
#print_rank_0("print-optimizer.overflow {}".format(optimizer.overflow))
# Update learning rate.
skipped_iter = 0
if not (args.fp16 and optimizer.overflow):
......@@ -861,10 +844,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
# Iterations.
iteration = args.iteration
#print_rank_0("Check betas before iterations")
#for group in optimizer.optimizer.param_groups:
# print_rank_0("betas {} lr {} weight_decay {} eps {}".format(group['betas'], group['lr'], group['weight_decay'], group['eps']))
timers('interval time').start()
print_datetime('before the start of training step')
report_memory_flag = True
......
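A simplified standalone sketch of the clip-then-step pattern in the train_step hunks above, using plain PyTorch gradient clipping instead of Megatron's fp16 master-grad path; the overflow handling is reduced to a placeholder:

import torch

def train_step_sketch(model, optimizer, loss, clip_grad=1.0):
    # Simplified clip-then-step pattern: backward, clip, step.
    optimizer.zero_grad()
    loss.backward()
    # Clipping gradients helps prevent the exploding gradient.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
    optimizer.step()
    # With fp16 loss scaling, an overflow would skip the update instead;
    # that bookkeeping is reduced to a constant here.
    skipped_iter = 0
    return grad_norm, skipped_iter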
......@@ -150,40 +150,4 @@ def get_ltor_masks_and_position_ids(data,
return attention_mask, loss_mask, position_ids
def params_grad_norm(model):
print_rank_0("params_grad_norm")
norm2 = torch.cuda.FloatTensor([0.0])
for param in model.parameters():
if param.grad is None:
continue
norm = torch.norm(param.grad.data.float(), 2)
norm2 += norm * norm
torch.distributed.all_reduce(norm2)
norm = norm2 ** 0.5
return norm.item()
def params_global_norm(model):
print_rank_0("params_global_norm")
norm2 = torch.cuda.FloatTensor([0.0])
for param in model.parameters():
norm = torch.norm(param.data.float(), 2)
norm2 += norm * norm
torch.distributed.all_reduce(norm2)
norm = norm2 ** 0.5
return norm.item()
def print_model(model):
print_rank_0("print-model")
for name, param in model.named_parameters():
if param.requires_grad:
#print("{} {}".format(name, param.data), flush=True)
print_rank_0("{} {}".format(name, param.data))
return
def print_grads(model):
print_rank_0("print-grads")
for name, param in model.named_parameters():
if param.grad is None:
continue
print_rank_0("{} {}".format(name, param.grad))
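For clarity, a self-contained restatement of what the removed params_grad_norm helper computes: the squared L2 norms of all gradients are summed, all-reduced across ranks, and the square root is taken (the single-process fallback here is an illustrative addition):

import torch
import torch.distributed as dist

def global_grad_norm(model):
    # Sum squared per-parameter gradient norms, all-reduce, take the root.
    grads = [p.grad.data.float() for p in model.parameters() if p.grad is not None]
    if not grads:
        return 0.0
    norm2 = sum(torch.norm(g, 2) ** 2 for g in grads)
    if dist.is_initialized():  # single-process fallback is an addition here
        dist.all_reduce(norm2)
    return (norm2 ** 0.5).item()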
......@@ -91,31 +91,16 @@ def forward_step(data_iterator, model, input_tensor):
query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
#print_rank_0(query_tokens)
#print_rank_0(context_tokens)
#print_rank_0(torch.sum(query_types))
#print_rank_0(torch.sum(query_mask))
#print_rank_0(torch.sum(context_types))
#print_rank_0(torch.sum(context_mask))
#print_rank_0(params_global_norm(model))
#print_rank_0(params_grad_norm(model))
# Forward model.
query_logits, context_logits = model(query_tokens, query_mask,
query_types, context_tokens,
context_mask, context_types)
#print_rank_0(query_logits)
#print_rank_0(context_logits)
micro_batch_size = query_logits.shape[0]
# recall we assert that tensor_model_parallel_size == 1
global_batch_size = dist.get_world_size() * micro_batch_size
all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
#global_batch_size = micro_batch_size
#all_query_logits = query_logits
#all_context_logits = context_logits
# scores are inner products between query and context embeddings
retrieval_scores = torch.matmul(all_query_logits,
......@@ -141,17 +126,10 @@ def forward_step(data_iterator, model, input_tensor):
# Scale the retrieval loss
loss = loss * mpu.get_data_parallel_world_size()
#retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
#retrieval_loss = retrieval_loss.float()
#averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])
# create stats_dict with retrieval loss and all specified top-k accuracies
topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
zip(args.report_topk_accuracies, reduced_losses[1:])}
stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
#print_rank_0(loss)
#print_rank_0(stats_dict)
#sys.exit()
return loss, stats_dict
......
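A minimal standalone sketch of the in-batch retrieval objective in forward_step above: scores are inner products between query and context embeddings, and the i-th query's positive context sits at index i, so the targets are simply arange(batch). The data-parallel allgather and the world-size loss scaling from the patch are omitted:

import torch
import torch.nn.functional as F

def retrieval_loss_sketch(query_logits, context_logits):
    # Inner-product scores between all queries and all contexts in the batch;
    # the diagonal entries are the positive pairs.
    retrieval_scores = torch.matmul(query_logits, context_logits.t())
    targets = torch.arange(retrieval_scores.size(0), device=retrieval_scores.device)
    return F.cross_entropy(retrieval_scores, targets)

# e.g. retrieval_loss_sketch(torch.randn(8, 128), torch.randn(8, 128))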