Commit 895c5167 authored by Michael Carilli

Clarifying docs on gradient clipping

parent 42180bd9

@@ -15,31 +15,43 @@ is under construction.

Gradient clipping
-----------------

Amp calls the params owned directly by the optimizer's ``param_groups`` the "master params."

These master params may be fully or partially distinct from ``model.parameters()``.
For example, with `opt_level="O2"`_, ``amp.initialize`` casts most model params to FP16,
creates an FP32 master param outside the model for each newly-FP16 model param,
and updates the optimizer's ``param_groups`` to point to these FP32 params.

The master params owned by the optimizer's ``param_groups`` may also fully coincide with the
model params, which is typically true for ``opt_level``\s ``O0``, ``O1``, and ``O3``.
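
As an illustration of the ``O2`` case, here is a minimal sketch, assuming apex is installed and
a GPU is available; the tiny ``Linear`` model and SGD optimizer below are placeholders chosen
for the example, not part of Amp's API::

    import torch
    from apex import amp

    # Placeholder model and optimizer for the sketch.
    model = torch.nn.Linear(10, 10).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    # Model params are now FP16 (all of them, for this batchnorm-free model)...
    print({p.dtype for p in model.parameters()})
    # ...while the optimizer's param_groups point at the FP32 master params.
    print({p.dtype for group in optimizer.param_groups for p in group["params"]})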

In all cases, correct practice is to clip the gradients of the params that are guaranteed to be
owned **by the optimizer's** ``param_groups``, instead of those retrieved via ``model.parameters()``.

Also, if Amp uses loss scaling, gradients must be clipped after they have been unscaled
(which occurs during exit from the ``amp.scale_loss`` context manager).

The following pattern should be correct for any ``opt_level``::

    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    # Gradients are unscaled during context manager exit.
    # Now it's safe to clip.  Replace
    # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    # with
    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm)
    # or
    torch.nn.utils.clip_grad_value_(amp.master_params(optimizer), max_)

Note the use of the utility function ``amp.master_params(optimizer)``,
which returns a generator expression that iterates over the
params in the optimizer's ``param_groups``.

Also note that ``clip_grad_norm_(amp.master_params(optimizer), max_norm)`` is invoked
*instead of*, not *in addition to*, ``clip_grad_norm_(model.parameters(), max_norm)``.
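
As a rough mental model only (an approximation for intuition, not Amp's actual implementation),
iterating the optimizer's ``param_groups`` by hand looks like the sketch below; in practice,
prefer ``amp.master_params(optimizer)``::

    # Roughly: every param in every param_group owned by the optimizer.
    master_params = (p for group in optimizer.param_groups for p in group["params"])
    torch.nn.utils.clip_grad_norm_(master_params, max_norm)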

.. _`opt_level="O2"`:
    https://nvidia.github.io/apex/amp.html#o2-fast-mixed-precision

Custom/user-defined autograd functions
--------------------------------------

@@ -124,6 +124,9 @@ are performed in FP32. ``O1`` also uses dynamic loss scaling, unless overridden

``O2`` casts the model weights to FP16,
patches the model's ``forward`` method to cast input
data to FP16, keeps batchnorms in FP32, maintains FP32 master weights,
updates the optimizer's ``param_groups`` so that ``optimizer.step()``
acts directly on the FP32 weights (followed by FP32 master weight->FP16 model weight
copies if necessary),
and implements dynamic loss scaling (unless overridden).
Unlike ``O1``, ``O2`` does not patch Torch functions or Tensor methods.
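
Putting the pieces together, the ``O2`` update flow can be sketched as a single training step.
This is a schematic sketch, assuming ``model``, ``optimizer``, ``criterion``, ``loader``, and
``max_norm`` are already defined; they are placeholders for the example, not part of Amp's API::

    model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    for input, target in loader:
        optimizer.zero_grad()
        # The patched forward casts the input data to FP16.
        loss = criterion(model(input), target)
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        # Clip the master grads owned by the optimizer, not model.parameters().
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm)
        # step() acts directly on the FP32 master weights; Amp then performs the
        # FP32 master weight -> FP16 model weight copies if necessary.
        optimizer.step()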