Commit bd0db55e authored by Michael Carilli

Adding documentation on custom batch casting

parent ee69ab64
@@ -117,3 +117,43 @@ Pass ``delay_unscale=True`` to ``amp.scale_loss`` until you're ready to ``step()``
# Otherwise, just accumulate gradients, don't unscale or step.
with amp.scale_loss(loss, optimizer, delay_unscale=True) as scaled_loss:
scaled_loss.backward()

Custom data batch types
-----------------------

The intention of Amp is that you never need to cast your input data manually, regardless of
``opt_level``. Amp accomplishes this by patching each model's ``forward`` method to cast
incoming data appropriately for the ``opt_level``. But to cast incoming data,
Amp needs to know how. The patched ``forward`` will recognize and cast floating-point Tensors
(non-floating-point Tensors like IntTensors are not touched) and
Python containers of floating-point Tensors. However, if you wrap your Tensors in a custom class,
the casting logic doesn't know how to drill
through the tough custom shell to access and cast the juicy Tensor meat within. You need to tell
Amp how to cast your custom batch class by assigning it a ``to`` method that accepts a ``torch.dtype``
(e.g., ``torch.float16`` or ``torch.float32``) and returns an instance of the custom batch cast to
``dtype``. The patched ``forward`` checks for the presence of your ``to`` method, and will
invoke it with the correct type for the ``opt_level``.

Example::

    class CustomData(object):
        def __init__(self):
            self.tensor = torch.cuda.FloatTensor([1,2,3])

        def to(self, dtype):
            self.tensor = self.tensor.to(dtype)
            return self
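
To show how this plugs into a training script, here is a minimal sketch that reuses the ``CustomData``
class above.  The ``MyModel`` class, the optimizer, and the choice of ``opt_level="O2"`` are illustrative
assumptions for this sketch, not requirements of Amp::

    import torch
    from apex import amp

    class MyModel(torch.nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, batch):
            # By the time this runs, the patched wrapper has already invoked
            # batch.to(dtype) with the dtype appropriate for the opt_level.
            return self.linear(batch.tensor)

    model = MyModel().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    batch = CustomData()
    # With O2, the patched forward casts batch (via its to method) to the
    # model's dtype before your forward sees it.
    output = model(batch)
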
.. warning::
    Amp also forwards numpy ndarrays without casting them.  If you send input data as a raw, unwrapped
    ndarray, then later use it to create a Tensor within your ``model.forward``, this Tensor's type will
    not depend on the ``opt_level``, and may or may not be correct.  Users are encouraged to pass
    castable data inputs (Tensors, collections of Tensors, or custom classes with a ``to`` method)
    wherever possible.
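    For example, rather than passing a raw ndarray, a minimal sketch (the array shape and ``model``
    below are illustrative placeholders) would convert it to a Tensor on the host side first::

        import numpy as np
        import torch

        np_batch = np.random.rand(8, 3).astype(np.float32)

        # Convert before calling the model, so the patched forward can cast
        # the Tensor appropriately for the current opt_level.
        tensor_batch = torch.from_numpy(np_batch).cuda()
        output = model(tensor_batch)
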
.. note::
    Amp does not call ``.cuda()`` on any Tensors for you.  Amp assumes that your original script
    is already set up to move Tensors from the host to the device as needed.
@@ -180,7 +180,7 @@ Advanced use cases
The unified Amp API supports gradient accumulation across iterations,
multiple backward passes per iteration, multiple models/optimizers,
-and custom/user-defined autograd functions. Gradient clipping and GANs also
+custom/user-defined autograd functions, and custom data batch classes. Gradient clipping and GANs also
require special treatment, but this treatment does not need to change
for different ``opt_level``\ s. Further details can be found here: