Efficient memory usage using Activation Checkpointing
=====================================================

Adapted from `torch.utils.checkpoint`, this is a friendlier wrapper for performing activation checkpointing.

Compared to the PyTorch version, this version wraps an `nn.Module` and allows all subsequent calls to be
checkpointed.

.. code-block:: python


    import torch
    import torch.nn as nn

    from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper


    class CheckpointModel(nn.Module):

        def __init__(self, **kwargs):
            super().__init__()
            torch.manual_seed(0)  # make sure weights are deterministic.
            self.ffn_module = nn.Sequential(
                nn.Linear(32, 128),
                nn.Dropout(p=0.5),
                nn.Linear(128, 32),
            )

            # Wrap the FFN block so its intermediate activations are recomputed
            # during the backward pass instead of being kept in memory.
            self.ffn_module = checkpoint_wrapper(self.ffn_module, **kwargs)
            self.last_linear = nn.Linear(32, 1)

        def forward(self, input):
            output = self.ffn_module(input)
            return self.last_linear(output)
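
The wrapped model is used like any other `nn.Module`. The sketch below shows how it might be
exercised; the input shape, batch size, and loss are illustrative choices and not part of the
fairscale API. During the backward pass, the activations inside the checkpointed `ffn_module`
are recomputed rather than read from memory, which is where the savings come from.

.. code-block:: python

    model = CheckpointModel()

    # Illustrative input: a batch of 8 samples matching the 32-feature input layer.
    # requires_grad=True ensures gradients flow back through the checkpointed block.
    inputs = torch.rand(8, 32, requires_grad=True)

    # Forward pass: intermediate activations of ffn_module are not stored.
    loss = model(inputs).sum()

    # Backward pass: the discarded activations are recomputed on the fly before
    # gradients are propagated through ffn_module.
    loss.backward()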