test_offload.py 5.57 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
Testing Offload Module
"""

import contextlib
import copy

import numpy as np
import pytest
import torch

from fairscale.experimental.nn.offload import OffloadModel
18
19
20
21
from fairscale.utils.testing import skip_if_no_cuda, torch_version

if torch_version() >= (1, 8, 0):
    from fairscale.experimental.nn.auto_shard import shard_model
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37


def _init():
    """Pin CUDA device 0, seed the torch and numpy RNGs, and return the devices.

    Returns:
        ``(device, offload_device)``: the CUDA compute device and the CPU
        device used as the offload target.
    """
    torch.cuda.set_device(0)
    torch.manual_seed(0)
    np.random.seed(0)
    return torch.device("cuda"), torch.device("cpu")


@skip_if_no_cuda
def test_single_run():
    """Check that activation checkpointing does not increase peak GPU memory.

    Trains one step of ``OffloadModel`` with ``checkpoint_activation`` on and
    off (each run wrapping its own copy of the same model) and compares the
    peak bytes allocated on cuda:0 during each run.
    """
    device, offload_device = _init()
    model = _get_model()
    loss_fn = torch.nn.MSELoss(reduction="sum")

    peak_mem = {}
    for checkpoint_activation in [True, False]:
        # `allocated_bytes.all.peak` is a running maximum since the last reset.
        # Without resetting, the second measurement always includes the first,
        # so the comparison at the end could never fail.
        torch.cuda.reset_peak_memory_stats(0)

        offload_model = OffloadModel(
            model=copy.deepcopy(model),
            device=device,
            offload_device=offload_device,
            num_slices=2,
            checkpoint_activation=checkpoint_activation,
        )
        offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)

        inputs = torch.ones(1000, 2).to(device)
        labels = torch.ones(1000, 2).to(device)
        offload_model.train()
        pred = offload_model(inputs)
        loss = loss_fn(pred, labels)
        loss.backward()
        offload_optimizer.step()

        key = "ca_" + str(checkpoint_activation)
        peak_mem[key] = torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]
        print(
            "Peak allocated bytes on cuda:0 for checkpoint_activation "
            + str(checkpoint_activation)
            + ": {:.2f}".format(peak_mem[key])
        )
        # Free this run's tensors so they are not still allocated (and thus
        # counted) when the next run's peak counter is reset.
        del offload_model, offload_optimizer, inputs, labels, pred, loss

    # TODO(anj-s): We need a better requirement since this fails on CircleCI right now.
    assert peak_mem["ca_True"] <= peak_mem["ca_False"]

def _get_model(num_inputs=2, num_hidden=20, num_layers=10, num_outputs=2):
    """Build the plain stack of ``Linear`` layers used as the model under test.

    Args:
        num_inputs: input feature size of the first layer.
        num_hidden: width of every hidden layer.
        num_layers: number of hidden-to-hidden ``Linear`` layers placed between
            the input and output projections.
        num_outputs: output feature size of the last layer.

    Returns:
        A ``torch.nn.Sequential`` of ``num_layers + 2`` ``Linear`` layers with
        no activations between them.
    """
    # Layers are constructed left to right (input, hidden..., output), which
    # fixes the order in which they draw from the torch RNG.
    model = torch.nn.Sequential(
        torch.nn.Linear(num_inputs, num_hidden),
        *[torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers)],
        torch.nn.Linear(num_hidden, num_outputs),
    )
    return model


def _check_parity(rmodel, omodel, ropt, oopt, rloss, oloss):
    """Assert that an offloaded model matches its vanilla reference after training.

    Compares, each with ``atol=1e-2``:
      * the two models' parameters, pairwise in ``parameters()`` order;
      * the tensors held in each optimizer's ``param_groups``, pairwise;
      * the two models' buffers, pairwise in ``buffers()`` order.
    Pairing uses ``zip``, so trailing entries of a longer sequence are not checked.

    Args:
        rmodel: reference (vanilla) model.
        omodel: offloaded model.
        ropt: optimizer of the reference model.
        oopt: optimizer of the offloaded model.
        rloss: reference loss; accepted but not currently compared.
        oloss: offloaded loss; accepted but not currently compared.

    Raises:
        AssertionError: if any compared pair of tensors differs beyond tolerance.
    """
    for oparams, rparams in zip(omodel.parameters(), rmodel.parameters()):
        assert torch.allclose(oparams, rparams, atol=1e-2), f"Model params are different {oparams} {rparams}"

    for o_group, reg_group in zip(oopt.param_groups, ropt.param_groups):
        # Distinct names for the per-tensor loop so the group variables are not shadowed.
        for o_param, reg_param in zip(o_group["params"], reg_group["params"]):
            assert torch.allclose(
                o_param, reg_param, atol=1e-2
            ), f"Model parameters differ in between Offload and Vanilla {o_param} {reg_param}"

    # Buffers belong to the models, not to optimizer groups, so check them once.
    for o_buf, reg_buf in zip(omodel.buffers(), rmodel.buffers()):
        assert torch.allclose(o_buf, reg_buf, atol=1e-2), "Model buffers differ in between Offload and Vanilla."


def _get_fp16_context(use_fp16=False):
    """Return the context manager the forward pass should run under.

    ``torch.cuda.amp.autocast()`` when ``use_fp16`` is truthy, otherwise a
    no-op ``contextlib.nullcontext()``.
    """
    return torch.cuda.amp.autocast() if use_fp16 else contextlib.nullcontext()


def _train(model, optimizer, use_fp16, device):
    """Run one optimizer step of ``model`` on a fixed all-ones batch.

    Args:
        model: module to train; called on a ``(32, 2)`` all-ones input.
        optimizer: optimizer over ``model``'s parameters; stepped once.
        use_fp16: if truthy, the forward pass and loss run under CUDA autocast.
        device: device the input and label tensors are moved to.

    Returns:
        ``(model, optimizer, loss)``, where ``loss`` is the summed MSE against
        all-ones labels.
    """
    inputs = torch.ones(32, 2).to(device)
    labels = torch.ones(32, 2).to(device)
    loss_fn = torch.nn.MSELoss(reduction="sum")
    model.train()
    # Autocast should wrap only the forward pass and loss computation. Per the
    # PyTorch AMP docs, backward should not run under autocast; its ops already
    # use the dtypes autocast chose for the corresponding forward ops.
    with _get_fp16_context(use_fp16):
        pred = model(inputs)
        loss = loss_fn(pred, labels)
    loss.backward()
    optimizer.step()
    return model, optimizer, loss


def _train_reg_model(model, device, offload_device, use_fp16=False):
    """Train a deep copy of ``model`` for one step without offloading.

    The copy is moved to ``device``, so the caller's ``model`` is left untouched.
    ``offload_device`` is not used here; it appears to be kept so the signature
    mirrors ``_train_offload_model``.

    Returns:
        ``(model, optimizer, loss)`` as produced by ``_train``.
    """
    # Use the requested device rather than hard-coding `.cuda()`, consistent
    # with how `_train` places its input tensors.
    reg_model = copy.deepcopy(model).to(device)
    reg_optimizer = torch.optim.SGD(reg_model.parameters(), lr=0.001)
    return _train(reg_model, reg_optimizer, use_fp16, device)


def _train_offload_model(
    model, device, offload_device, use_fp16=False, checkpoint_activation=False, num_microbatches=1
):
    """Wrap a deep copy of ``model`` in ``OffloadModel`` and train it for one step.

    Args:
        model: module to copy and wrap (the caller's object is not modified).
        device: compute device passed to ``OffloadModel``.
        offload_device: offload device passed to ``OffloadModel``.
        use_fp16: if truthy, the forward pass runs under CUDA autocast.
        checkpoint_activation: forwarded to ``OffloadModel``.
        num_microbatches: forwarded to ``OffloadModel``.

    Returns:
        ``(offload_model, optimizer, loss)`` as produced by ``_train``.
    """
    omodel = copy.deepcopy(model)
    offload_model = OffloadModel(
        model=omodel,
        device=device,
        offload_device=offload_device,
        num_slices=2,
        checkpoint_activation=checkpoint_activation,
        num_microbatches=num_microbatches,
    )
    offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)
    return _train(offload_model, offload_optimizer, use_fp16, device)


@skip_if_no_cuda
@pytest.mark.parametrize("use_fp16", [True, False])
@pytest.mark.parametrize("checkpoint_activation", [True, False])
@pytest.mark.parametrize("num_microbatches", [1, 5])
@pytest.mark.parametrize("use_auto_shard", [True, False])
def test_correctness(use_fp16, checkpoint_activation, num_microbatches, use_auto_shard):
    """Check that one training step of ``OffloadModel`` matches a vanilla model.

    The reference model is trained with the default ``use_fp16=False``. The
    offloaded model (optionally auto-sharded first) is trained with the
    parametrized settings. Parameters, optimizer tensors and buffers are then
    compared by ``_check_parity``.
    """
    if use_auto_shard and torch_version() < (1, 8, 0):
        pytest.skip("auto_shard requires torch version >= 1.8.0")

    if (use_fp16 or checkpoint_activation) and not hasattr(torch.cuda.amp, "custom_fwd"):
        pytest.skip(f"AMP APIs are not supported in torch version {torch.__version__}")

    if not checkpoint_activation and num_microbatches > 1:
        pytest.skip("We only support microbatches with activation offloading.")

    device, offload_device = _init()
    model = _get_model()

    # `shard_model` is imported at module level only for torch >= 1.8; the
    # first skip above guarantees that whenever `use_auto_shard` is set.
    if use_auto_shard:
        offload_model = shard_model(model)
    else:
        offload_model = model

    rmodel, ropt, rloss = _train_reg_model(model, device, offload_device)
    omodel, oopt, oloss = _train_offload_model(
        offload_model,
        device,
        offload_device,
        use_fp16=use_fp16,
        checkpoint_activation=checkpoint_activation,
        num_microbatches=num_microbatches,
    )
    _check_parity(rmodel.cpu(), omodel.cpu(), ropt, oopt, rloss, oloss)