Commit df6e3cd7 authored by mshoeybi's avatar mshoeybi
Browse files

set grads to none for the contig. buffers

parent 68797d90
...@@ -354,8 +354,7 @@ def train_step(forward_step_func, data_iterator, ...@@ -354,8 +354,7 @@ def train_step(forward_step_func, data_iterator,
if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp: if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
for partition in model: for partition in model:
partition.zero_grad_buffer() partition.zero_grad_buffer()
else: optimizer.zero_grad()
optimizer.zero_grad()
forward_backward_func = get_forward_backward_func() forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func( losses_reduced = forward_backward_func(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment