"src/vscode:/vscode.git/clone" did not exist on "7a065a9c56438dad282a0709df1b45bf2bd3bd6a"
Commit 9d4c735a authored by Deepak Narayanan

Improve names of identifiers used for timing in main training loop

parent 8fb2bc8c
@@ -275,13 +275,11 @@ def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_g
         input_tensor.retain_grad()
 
     # Backward pass.
-    timers('backward-backward').start()
     if args.fp16:
         optimizer.backward(output_tensor, update_master_grads=False,
                            output_tensor_grad=output_tensor_grad)
     else:
         torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
-    timers('backward-backward').stop()
 
     # Collect the grad of the input_tensor.
     input_tensor_grad = None
@@ -409,10 +407,10 @@ def train_step(forward_step_func, data_iterator,
 
     # All-reduce if needed.
     if args.DDP_impl == 'local':
-        timers('allreduce').start()
+        timers('backward-params-all-reduce').start()
         model.allreduce_params(reduce_after=False,
                                fp32_allreduce=args.fp32_allreduce)
-        timers('allreduce').stop()
+        timers('backward-params-all-reduce').stop()
 
     # All-reduce word_embeddings' grad across first and last stages to ensure
     # that word_embeddings parameters stay in sync.
@@ -512,9 +510,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
            timers_to_log.append(name)
 
    add_to_logging('forward')
    add_to_logging('backward')
-   add_to_logging('backward-backward')
-   add_to_logging('backward-allreduce')
    add_to_logging('backward-master-grad')
+   add_to_logging('backward-params-all-reduce')
    add_to_logging('backward-embedding-all-reduce')
    add_to_logging('backward-clip-grad')
    add_to_logging('optimizer')
...
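The renaming gives each phase of the backward pass a 'backward-'-prefixed timer name ('backward-params-all-reduce', 'backward-embedding-all-reduce', and so on), so the per-phase breakdown in the training logs is self-describing. Below is a minimal sketch of the named-timer pattern these identifiers plug into, using a simplified, hypothetical Timers class written only for illustration; Megatron-LM's actual timers are more elaborate (e.g. CUDA synchronization and distributed logging), so this is not the project's real implementation.

import time


class _Timer:
    """Accumulates wall-clock time for one named region (sketch only)."""

    def __init__(self, name):
        self.name = name
        self._elapsed = 0.0
        self._start_time = None

    def start(self):
        self._start_time = time.time()

    def stop(self):
        self._elapsed += time.time() - self._start_time
        self._start_time = None

    def elapsed(self):
        return self._elapsed


class Timers:
    """Hypothetical, simplified registry of named timers."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        # Create the timer on first use so call sites stay one-liners.
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def log(self, names):
        # Print elapsed time in milliseconds for each requested timer.
        for name in names:
            if name in self.timers:
                print('{}: {:.2f} ms'.format(
                    name, self.timers[name].elapsed() * 1000.0))


timers = Timers()

# The parameter-gradient all-reduce is now timed under
# 'backward-params-all-reduce' instead of the old 'allreduce'.
timers('backward-params-all-reduce').start()
# ... model.allreduce_params(...) would run here ...
timers('backward-params-all-reduce').stop()

timers.log(['forward', 'backward', 'backward-params-all-reduce'])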