Commit 9b558566 authored by Deepak Narayanan

Bugfix in main training loop: Update master_grads only after grads are correctly accumulated

parent 767e6e92
@@ -400,12 +400,6 @@ def train_step(forward_step_func, data_iterator,
                                fp32_allreduce=args.fp32_allreduce)
     timers('allreduce').stop()
 
-    # Update master gradients.
-    timers('backward-master-grad').start()
-    if args.fp16:
-        optimizer.update_master_grads()
-    timers('backward-master-grad').stop()
-
     # All-reduce across first and last stages.
     timers('backward-embedding-all-reduce').start()
     if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
@@ -419,6 +413,12 @@ def train_step(forward_step_func, data_iterator,
                                 group=mpu.get_embedding_group())
     timers('backward-embedding-all-reduce').stop()
 
+    # Update master gradients.
+    timers('backward-master-grad').start()
+    if args.fp16:
+        optimizer.update_master_grads()
+    timers('backward-master-grad').stop()
+
     # Clipping gradients helps prevent the exploding gradient.
     timers('backward-clip-grad').start()
     if args.clip_grad > 0.:
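In short, the change moves the fp16-to-fp32 master-gradient copy so that it runs only after all gradient accumulation has finished, including the all-reduce of the tied embedding gradients between the first and last pipeline stages. Below is a minimal sketch of the corrected ordering; `finish_backward`, `all_reduce_data_parallel_grads`, and `all_reduce_embedding_grads` are hypothetical stand-ins for code elided by the diff, while `timers`, `args.fp16`, `args.fp32_allreduce`, and `optimizer.update_master_grads()` follow the names in the diff above.

```python
# Sketch of the corrected ordering in the training step. The two
# all_reduce_* callables are hypothetical stand-ins for the all-reduce
# calls that the diff truncates; the timer names match the diff.
def finish_backward(optimizer, args, timers,
                    all_reduce_data_parallel_grads,
                    all_reduce_embedding_grads):
    # 1) All-reduce gradients across data-parallel replicas.
    timers('allreduce').start()
    all_reduce_data_parallel_grads(fp32_allreduce=args.fp32_allreduce)
    timers('allreduce').stop()

    # 2) All-reduce the shared embedding gradients across the first and
    #    last pipeline stages, so the fp16 gradients are fully accumulated.
    timers('backward-embedding-all-reduce').start()
    all_reduce_embedding_grads()
    timers('backward-embedding-all-reduce').stop()

    # 3) Only now copy the accumulated fp16 gradients into the fp32 master
    #    copies. Doing this before step 2, as the old code did, left the
    #    embedding all-reduce contribution out of the optimizer step.
    timers('backward-master-grad').start()
    if args.fp16:
        optimizer.update_master_grads()
    timers('backward-master-grad').stop()
```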