Commit bfc20ecf authored by Mostofa Patwary

Fixed issue with initializing ICT from a pretrained BERT model

parent 0295bb89
@@ -320,6 +320,8 @@ def setup_model_and_optimizer(model_provider_func):
                                        'init_state_dict_from_bert'):
         print("Initializing ICT from pretrained BERT model", flush=True)
         unwrapped_model.init_state_dict_from_bert()
+        if args.fp16:
+            optimizer._model_params_to_master_params()
     return model, optimizer, lr_scheduler
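
Why the two added lines matter: init_state_dict_from_bert() overwrites the fp16 model weights after the fp16 optimizer has already cached fp32 master copies of the original (random) parameters, so without a re-sync the first optimizer step would write the stale master values back into the model. A minimal sketch of that re-sync, assuming an apex-style setup where the optimizer keeps separate fp32 master parameters (the helper name below is illustrative, not the repo's FP16_Optimizer API):

import torch

def sync_master_params(fp16_params, fp32_master_params):
    # Assumed helper: copy the freshly loaded fp16 model weights into the
    # fp32 master copies that the optimizer actually updates.
    with torch.no_grad():
        for model_p, master_p in zip(fp16_params, fp32_master_params):
            master_p.copy_(model_p.float())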
@@ -646,6 +648,7 @@ def train_step(forward_step_func, data_iterator,
     if args.fp16:
         optimizer.update_master_grads()
     timers('backward-master-grad').stop()
+    grad_norm_local = None
     # Clipping gradients helps prevent the exploding gradient.
     timers('backward-clip-grad').start()
@@ -660,16 +663,16 @@
             mpu.clip_grad_norm(parameters, args.clip_grad,
                                parameter_names=parameter_names)
         else:
-            optimizer.clip_master_grads(args.clip_grad)
+            grad_norm_local = optimizer.clip_master_grads(args.clip_grad)
     timers('backward-clip-grad').stop()
+    #print_rank_0("print-grad_norm_local {}".format(grad_norm_local))
     #print_rank_0("after backward")
     #print_grads(model)
-    print_model(model)
-    print_rank_0(params_global_norm(model))
-    print_rank_0(params_grad_norm(model))
+    #print_model(model)
+    #print_rank_0(params_global_norm(model))
+    #print_rank_0(params_grad_norm(model))
     # Update parameters.
     timers('optimizer').start()
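
The train_step changes capture the value returned by optimizer.clip_master_grads(), i.e. the global gradient norm computed during clipping, in grad_norm_local so it can be printed or logged. A rough equivalent with plain PyTorch clipping, where clip_grad_norm_ likewise returns the total (pre-clipping) norm (the wrapper function name is illustrative):

import torch

def clip_and_report(parameters, max_norm):
    # clip_grad_norm_ clips in place and returns the total gradient norm,
    # which can then be handed to print_rank_0 or a logger.
    grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm)
    return float(grad_norm)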
@@ -678,9 +681,11 @@
     #print_rank_0("after optimizer")
     #print_model(model)
-    print_rank_0(params_global_norm(model))
+    #print_rank_0(params_global_norm(model))
     #print_rank_0(params_grad_norm(model))
     #sys.exit()
+    #print_rank_0("print-optimizer.overflow {}".format(optimizer.overflow))
     # Update learning rate.
     skipped_iter = 0
@@ -856,6 +861,10 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     # Iterations.
     iteration = args.iteration
+    #print_rank_0("Check betas before iterations")
+    #for group in optimizer.optimizer.param_groups:
+    #    print_rank_0("betas {} lr {} weight_decay {} eps {}".format(group['betas'], group['lr'], group['weight_decay'], group['eps']))
     timers('interval time').start()
     print_datetime('before the start of training step')
     report_memory_flag = True
@@ -109,13 +109,13 @@ def forward_step(data_iterator, model, input_tensor):
     micro_batch_size = query_logits.shape[0]
     # recall we assert that tensor_model_parallel_size == 1
-    #global_batch_size = dist.get_world_size() * micro_batch_size
-    #all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
-    #all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
+    global_batch_size = dist.get_world_size() * micro_batch_size
+    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
+    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
-    global_batch_size = micro_batch_size
-    all_query_logits = query_logits
-    all_context_logits = context_logits
+    #global_batch_size = micro_batch_size
+    #all_query_logits = query_logits
+    #all_context_logits = context_logits
     # scores are inner products between query and context embeddings
     retrieval_scores = torch.matmul(all_query_logits,
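
The forward_step change switches the ICT retrieval loss from per-rank logits back to logits gathered across all data-parallel ranks, so the softmax over retrieval_scores sees in-batch negatives from the whole global batch rather than just the local micro-batch. A conceptual sketch of what an all-gather autograd region such as AllgatherFromDataParallelRegion provides, written with plain torch.distributed (simplified: gradients here only flow into the local slice, and the default process group is assumed):

import torch
import torch.distributed as dist

def gather_across_data_parallel_ranks(local_logits):
    # Gather [micro_batch, hidden] embeddings from every rank and stack them
    # into a [global_batch, hidden] tensor for global in-batch negatives.
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(gathered, local_logits.contiguous())
    # all_gather does not propagate gradients, so keep the local tensor in
    # its own slot to let autograd reach this rank's parameters.
    gathered[dist.get_rank()] = local_logits
    return torch.cat(gathered, dim=0)

# usage sketch:
# all_query_logits = gather_across_data_parallel_ranks(query_logits)
# all_context_logits = gather_across_data_parallel_ranks(context_logits)
# retrieval_scores = torch.matmul(all_query_logits, all_context_logits.t())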