Commit 18f2eaee authored by Michael Carilli

Clarifying documentation on gradient accumulation

parent 656d14b0
@@ -150,18 +150,15 @@ gradient clipping via the `instructions above`_::
     # will be averaged over that many iterations):
     loss = loss/iters_to_accumulate
 
+    with amp.scale_loss(loss, optimizer) as scaled_loss:
+        scaled_loss.backward()
+
+    # Every iters_to_accumulate iterations, call step() and reset gradients:
     if iter%iters_to_accumulate == 0:
-        # Every iters_to_accumulate iterations, unscale and step
-        with amp.scale_loss(loss, optimizer) as scaled_loss:
-            scaled_loss.backward()
         # Gradient clipping if desired:
         # torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm)
         optimizer.step()
         optimizer.zero_grad()
-    else:
-        # Otherwise, accumulate gradients, don't unscale or step.
-        with amp.scale_loss(loss, optimizer) as scaled_loss:
-            scaled_loss.backward()
 
 As a minor performance optimization, you can pass ``delay_unscale=True``
 to ``amp.scale_loss`` until you're ready to ``step()``. You should only attempt ``delay_unscale=True``
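
For reference, here is a minimal, self-contained sketch (not part of this commit) of how the accumulation pattern above might sit inside a full training loop. The toy ``torch.nn.Linear`` model, the random data, ``opt_level="O1"``, and ``iters_to_accumulate = 4`` are illustrative assumptions, not values taken from the diff::

    import torch
    from apex import amp

    model = torch.nn.Linear(128, 10).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # Register the model and optimizer with Amp before the training loop.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    criterion = torch.nn.CrossEntropyLoss()
    iters_to_accumulate = 4

    for iteration in range(1, 101):
        # Placeholder batch; a real loop would pull from a DataLoader.
        inputs = torch.randn(32, 128, device="cuda")
        targets = torch.randint(0, 10, (32,), device="cuda")
        loss = criterion(model(inputs), targets)

        # Divide so the accumulated gradient is an average over the window.
        loss = loss / iters_to_accumulate

        # Backward through Amp's scaled loss every iteration; gradients
        # accumulate in the parameters' .grad attributes.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

        # Step and reset gradients only every iters_to_accumulate iterations.
        if iteration % iters_to_accumulate == 0:
            # Optional clipping on the FP32 master params:
            # torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

If the ``delay_unscale=True`` optimization mentioned above were used, the iterations that do not call ``step()`` would instead wrap the backward pass in ``with amp.scale_loss(loss, optimizer, delay_unscale=True) as scaled_loss:``, so the FP32 master gradients are only unscaled on the iteration that actually steps.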