Enhanced Activation Checkpointing
=================================

Activation checkpointing is a technique used to reduce GPU memory usage during training. It avoids
storing the intermediate activation tensors produced during the forward pass; instead, the forward
pass is recomputed during the backward pass from the saved original inputs. This adds a modest
amount of computation (roughly 33%), but not having to store large activation tensors allows us to
increase the batch size and thereby the net throughput of the model.
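
As a rough illustration of this trade-off, the standard `torch.utils.checkpoint.checkpoint` API can
be used to compare peak GPU memory with and without checkpointing. The model and sizes below are
hypothetical and the sketch assumes a CUDA device is available; the actual savings depend entirely
on the model.

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    # A hypothetical stack of feed-forward blocks; sizes are illustrative only.
    blocks = nn.ModuleList(
        [nn.Sequential(nn.Linear(4096, 4096), nn.ReLU()) for _ in range(8)]
    ).cuda()

    def run(x, use_checkpointing):
        for block in blocks:
            # With checkpointing, the block's intermediate activations are not
            # stored; they are recomputed from the block input during backward.
            x = checkpoint(block, x) if use_checkpointing else block(x)
        return x

    for use_checkpointing in (False, True):
        torch.cuda.reset_peak_memory_stats()
        x = torch.randn(8192, 4096, device="cuda", requires_grad=True)
        run(x, use_checkpointing).sum().backward()
        peak_gib = torch.cuda.max_memory_allocated() / 2**30
        print(f"checkpointing={use_checkpointing}: peak memory {peak_gib:.2f} GiB")
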


Activation checkpointing is implemented by overriding `torch.autograd.Function`. In the `forward`
function, which handles the forward pass of the module, we run under `no_grad` so that the autograd
graph is not built and intermediate activation tensors are not materialized and kept alive until the
backward pass. Instead, during the backward pass, the forward pass is executed again, followed by
the actual backward computation. The inputs to the forward pass are saved using a context object
that is then accessed during the backward pass to retrieve the original inputs. We also save the
Random Number Generator (RNG) state for the forward and backward passes, as required for Dropout layers.
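
A minimal sketch of this mechanism is shown below. It is a simplified illustration, not FairScale's
actual implementation: it omits the RNG-state handling, keyword arguments and non-tensor
inputs/outputs, and the class name `CheckpointFunction` is used here only for illustration.

.. code-block:: python

    import torch

    class CheckpointFunction(torch.autograd.Function):
        """Simplified activation checkpointing: save inputs, not activations."""

        @staticmethod
        def forward(ctx, run_function, *inputs):
            ctx.run_function = run_function
            # Only the inputs are saved; running the function under no_grad means
            # no autograd graph is built and no intermediate activations are kept.
            ctx.save_for_backward(*inputs)
            with torch.no_grad():
                outputs = run_function(*inputs)
            return outputs

        @staticmethod
        def backward(ctx, *grad_outputs):
            # Recompute the forward pass on detached copies of the saved inputs,
            # this time building the autograd graph, then backprop through it.
            inputs = [x.detach().requires_grad_(x.requires_grad) for x in ctx.saved_tensors]
            with torch.enable_grad():
                outputs = ctx.run_function(*inputs)
            if isinstance(outputs, torch.Tensor):
                outputs = (outputs,)
            torch.autograd.backward(outputs, grad_outputs)
            # No gradient for the run_function argument itself.
            return (None,) + tuple(x.grad for x in inputs)

    # Usage: checkpoint a module (or any callable) instead of calling it directly.
    layer = torch.nn.Linear(1024, 1024)
    x = torch.randn(8, 1024, requires_grad=True)
    out = CheckpointFunction.apply(layer, x)
    out.sum().backward()
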

The above functionality is already available through the `torch.utils.checkpoint.checkpoint`
API, which can be used to wrap individual modules in the forward pass. The wrapper in FairScale
offers functionality beyond that provided by the PyTorch API: specifically, you can use
`fairscale.nn.checkpoint.checkpoint_wrapper` to wrap an `nn.Module`, handle kwargs in the forward
pass, offload intermediate activations to the CPU, and handle non-tensor outputs returned from the
forward function.
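
A typical use looks like the sketch below; the model and sizes are illustrative only, and the
`offload_to_cpu` argument reflects the `checkpoint_wrapper` signature at the time of writing.

.. code-block:: python

    import torch
    import torch.nn as nn
    from fairscale.nn.checkpoint import checkpoint_wrapper

    # Hypothetical feed-forward blocks; sizes are illustrative only.
    model = nn.Sequential(
        # Activations inside the wrapped block are recomputed during the backward
        # pass instead of being stored for the whole duration of the forward pass.
        checkpoint_wrapper(
            nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Dropout(0.1))
        ),
        # offload_to_cpu additionally moves the tensors saved for recomputation
        # to CPU memory between the forward and backward passes.
        checkpoint_wrapper(
            nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()), offload_to_cpu=True
        ),
        nn.Linear(1024, 10),
    )

    x = torch.randn(8, 1024, requires_grad=True)
    model(x).sum().backward()
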

Best practices for `fairscale.nn.checkpoint.checkpoint_wrapper`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

1. Memory savings depend entirely on the model and the segmentation of the checkpoint wrapping.
With checkpointing, each backward pass consists of several mini forward and backward passes, and
the gain depends on the memory footprint of each wrapped segment's activations.

2. When using BatchNormalization you may need to freeze the calculation of the running statistics,
since the forward pass is run twice (see the sketch after this list).

3. Ensure that the input tensor's `requires_grad` field is set to True. The `backward` function is
only triggered if the output has this field set; setting it on the input tensor ensures that it is
propagated to the output (see the sketch after this list).
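
The BatchNorm and `requires_grad` points above can be handled as in the sketch below. The helper
`freeze_batchnorm_stats` is a hypothetical name, not part of the FairScale API, and freezing via
`momentum = 0.0` is just one simple way to stop the running statistics from being updated twice.

.. code-block:: python

    import torch
    import torch.nn as nn
    from fairscale.nn.checkpoint import checkpoint_wrapper

    def freeze_batchnorm_stats(module):
        # Hypothetical helper: with momentum set to 0.0, BatchNorm layers keep
        # their current running statistics, so the recomputed forward pass does
        # not update them a second time.
        for m in module.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                m.momentum = 0.0
        return module

    block = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
    model = checkpoint_wrapper(freeze_batchnorm_stats(block))

    # Setting requires_grad on the input ensures the output requires grad as well,
    # so the checkpointed backward function is actually triggered.
    x = torch.randn(2, 16, 32, 32, requires_grad=True)
    model(x).sum().backward()
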