Commit bfc20ecf authored by Mostofa Patwary

Fixed issue with initializing ICT from pretrained BERT model

parent 0295bb89
@@ -320,6 +320,8 @@ def setup_model_and_optimizer(model_provider_func):
                                        'init_state_dict_from_bert'):
         print("Initializing ICT from pretrained BERT model", flush=True)
         unwrapped_model.init_state_dict_from_bert()
+        if args.fp16:
+            optimizer._model_params_to_master_params()
 
     return model, optimizer, lr_scheduler
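The added fp16 branch matters because the FP16 optimizer wrapper keeps FP32 master copies of the model parameters: overwriting the half-precision model weights with the pretrained BERT state alone leaves the optimizer stepping from stale master values, which would undo the initialization on the first update. A minimal, self-contained sketch of that failure mode and the fix in plain PyTorch (illustrative only, not the Megatron/apex _model_params_to_master_params implementation):

import torch

# fp16 "model" weight and its fp32 "master" copy, as an FP16 optimizer keeps them.
model_param = torch.nn.Parameter(torch.zeros(4, dtype=torch.half))
master_param = model_param.detach().clone().float()

# Simulate "Initializing ICT from pretrained BERT": overwrite the model weight in place.
with torch.no_grad():
    model_param.copy_(torch.ones(4, dtype=torch.half))

# Without a refresh, the master copy still holds the old values, so an optimizer
# step built on it would discard the freshly loaded weights.
assert not torch.allclose(master_param, model_param.float())

# The fix mirrors optimizer._model_params_to_master_params(): copy model -> master.
with torch.no_grad():
    master_param.copy_(model_param.float())
assert torch.allclose(master_param, model_param.float())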
@@ -646,6 +648,7 @@ def train_step(forward_step_func, data_iterator,
     if args.fp16:
         optimizer.update_master_grads()
     timers('backward-master-grad').stop()
+    grad_norm_local = None
 
     # Clipping gradients helps prevent the exploding gradient.
     timers('backward-clip-grad').start()
@@ -660,16 +663,16 @@ def train_step(forward_step_func, data_iterator,
             mpu.clip_grad_norm(parameters, args.clip_grad,
                                parameter_names=parameter_names)
         else:
-            optimizer.clip_master_grads(args.clip_grad)
+            grad_norm_local = optimizer.clip_master_grads(args.clip_grad)
     timers('backward-clip-grad').stop()
 
+    #print_rank_0("print-grad_norm_local {}".format(grad_norm_local))
     #print_rank_0("after backward")
     #print_grads(model)
-    print_model(model)
-    print_rank_0(params_global_norm(model))
-    print_rank_0(params_grad_norm(model))
+    #print_model(model)
+    #print_rank_0(params_global_norm(model))
+    #print_rank_0(params_grad_norm(model))
 
     # Update parameters.
     timers('optimizer').start()
@@ -678,9 +681,11 @@ def train_step(forward_step_func, data_iterator,
     #print_rank_0("after optimizer")
     #print_model(model)
-    print_rank_0(params_global_norm(model))
+    #print_rank_0(params_global_norm(model))
     #print_rank_0(params_grad_norm(model))
     #sys.exit()
+    #print_rank_0("print-optimizer.overflow {}".format(optimizer.overflow))
 
     # Update learning rate.
     skipped_iter = 0
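The new grad_norm_local captures the return value of optimizer.clip_master_grads(args.clip_grad) so it can be logged for debugging (the added print is left commented out); initializing it to None earlier keeps that print safe when clipping is skipped. For reference, a standalone sketch of the same pattern with plain PyTorch, whose clipping utility likewise returns the total gradient norm it measured before rescaling (illustrative, not the Megatron code path):

import torch

layer = torch.nn.Linear(8, 8)
layer(torch.randn(2, 8)).sum().backward()

# clip_grad_norm_ rescales the gradients in place and returns the pre-clip global norm,
# which is the kind of value grad_norm_local is meant to expose for debugging.
grad_norm_local = torch.nn.utils.clip_grad_norm_(layer.parameters(), max_norm=1.0)
print("pre-clip global grad norm: {:.4f}".format(float(grad_norm_local)))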
@@ -856,6 +861,10 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     # Iterations.
     iteration = args.iteration
+    #print_rank_0("Check betas before iterations")
+    #for group in optimizer.optimizer.param_groups:
+    #    print_rank_0("betas {} lr {} weight_decay {} eps {}".format(group['betas'], group['lr'], group['weight_decay'], group['eps']))
 
     timers('interval time').start()
     print_datetime('before the start of training step')
     report_memory_flag = True
...
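The commented-out block added in train is a hyperparameter sanity check: every optimizer param group carries its own betas, lr, weight_decay, and eps, and the inner optimizer is reached through optimizer.optimizer, presumably because of the FP16 wrapper. A standalone sketch of the same inspection against a plain torch.optim.Adam (hyperparameter values below are illustrative):

import torch

opt = torch.optim.Adam(torch.nn.Linear(4, 4).parameters(),
                       lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01, eps=1e-8)

# Each param group exposes the settings the commented-out debug print reports.
for group in opt.param_groups:
    print("betas {} lr {} weight_decay {} eps {}".format(
        group['betas'], group['lr'], group['weight_decay'], group['eps']))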
@@ -109,13 +109,13 @@ def forward_step(data_iterator, model, input_tensor):
     micro_batch_size = query_logits.shape[0]
     # recall we assert that tensor_model_parallel_size == 1
-    #global_batch_size = dist.get_world_size() * micro_batch_size
-    #all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
-    #all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
+    global_batch_size = dist.get_world_size() * micro_batch_size
+    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
+    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
 
-    global_batch_size = micro_batch_size
-    all_query_logits = query_logits
-    all_context_logits = context_logits
+    #global_batch_size = micro_batch_size
+    #all_query_logits = query_logits
+    #all_context_logits = context_logits
 
     # scores are inner products between query and context embeddings
     retrieval_scores = torch.matmul(all_query_logits,
...
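This hunk restores the cross-rank gather for ICT retrieval: query and context embeddings from all data-parallel ranks are collected with AllgatherFromDataParallelRegion so the inner-product score matrix, and therefore the set of in-batch negatives, spans the global batch instead of only the local micro-batch. A single-process sketch of the resulting shapes, with the all-gather simulated by concatenation (illustrative only, not the actual distributed implementation):

import torch

world_size, micro_batch_size, hidden = 4, 2, 8

# Stand-ins for the per-rank query/context logits produced by the biencoder.
per_rank_query = [torch.randn(micro_batch_size, hidden) for _ in range(world_size)]
per_rank_context = [torch.randn(micro_batch_size, hidden) for _ in range(world_size)]

# All-gather across data-parallel ranks, simulated here by concatenation.
all_query_logits = torch.cat(per_rank_query, dim=0)      # [world_size * mbs, hidden]
all_context_logits = torch.cat(per_rank_context, dim=0)  # [world_size * mbs, hidden]
global_batch_size = world_size * micro_batch_size

# Scores are inner products between every query and every context embedding.
retrieval_scores = torch.matmul(all_query_logits, all_context_logits.t())
assert retrieval_scores.shape == (global_batch_size, global_batch_size)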