activation_checkpointing.rst 2.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
Enhanced Activation Checkpointing
=================================

Activation checkpointing is a technique used to reduce GPU memory usage during training. This is 
done by avoiding the need to store intermediate activation tensors during the forward pass. Instead, 
the forward pass is recomputed by keeping track of the original input during the backward pass. 
There is a slight increase in computation cost (about 33%) but this reduces the need to store 
large activation tensors which allows us to increase the batch size and thereby the net throughput 
of the model.


Activation checkpointing is implemented by overriding `torch.autograd.Function`. In the `forward` 
function which handles the forward pass of the module, using `no_grad`, we can prevent the creation 
of the forward graph and materialization of intermediate activation tensors for a long period of 
time (i.e till the backward pass). Instead, during the backward pass, the forward pass is executed 
again followed by the backward pass. The inputs to the forward pass are saved using a context object 
that is then accessed in the backward pass to retrieve the original inputs. We also save the 
Random Number Generator(RNG) state for the forward and backward passes as required for Dropout layers.

The above functionality is already implemented as part of the `torch.utils.checkpoint.checkpoint_wrapper` 
API whereby different modules in the forward pass can be wrapped. The wrapper in FairScale offers 
functionality beyond that provided by the PyTorch API specifically you can use 
`fairscale.nn.checkpoint.checkpoint_wrapper` to wrap a `nn.Module`, handle kwargs in the forward 
pass, offload intermediate activations to the CPU and handle non-tensor outputs returned from the 
forward function.

Best practices for `fairscale.nn.checkpoint.checkpoint_wrapper`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

1. Memory savings depends entirely on the model and the segmentation of checkpoint wrapping. 
Each backprop consists of several mini-forward and backprop passes. The gain is entirely dependent 
on the memory footprint of the layer’s activations. 

2. When using BatchNormalization you may need to freeze the calculation of statistics since we run 
the forward pass twice.

3. Ensure that the input tensor’s `requires_grad` field is set to True. In order to trigger the 
backward function, the output needs to have this field set. By setting it on the input tensor we 
ensure that this is propagated to the output and the `backward` function is triggered.