import copy

import pytest
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.fx import GraphModule

import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.utils import free_port

# ActivationCheckpointCodeGen requires the FX codegen API introduced in
# torch 1.12; on older torch we fall back to the string-based code path.
# Catch ImportError specifically (a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit and unrelated bugs).
try:
    from colossalai.fx.codegen import ActivationCheckpointCodeGen
    with_codegen = True
except ImportError:
    # fall back to older pytorch version
    from colossalai.fx.codegen import python_code_with_activation_checkpoint
    with_codegen = False


class MyNet(torch.nn.Module):
    """A stack of seven 4x4 linear layers used as the offload-codegen test model.

    The submodule names ``linear0`` .. ``linear6`` are load-bearing: the
    codegen tests below look up traced FX nodes by exactly these names.
    """

    def __init__(self) -> None:
        super().__init__()
        # Register linear0 .. linear6 in order; sequential construction keeps
        # RNG-driven weight initialization identical to an explicit listing.
        for idx in range(7):
            setattr(self, f"linear{idx}", torch.nn.Linear(4, 4))

    def forward(self, x):
        # Calls are kept explicit (not a loop over getattr) so FX tracing
        # yields one stably-named call_module node per layer.
        x = self.linear0(x)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.linear4(x)
        x = self.linear5(x)
        x = self.linear6(x)
        return x


def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool:
    for m_p, gm_p in zip(m.parameters(), gm.parameters()):
        if not torch.allclose(m_p.grad, gm_p.grad):
            return False
    return True


def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor):
    """Assert that ``gm`` matches ``model`` on both forward output and gradients."""
    # Forward pass: the traced module must reproduce the original output exactly.
    reference_out = model(data)
    traced_out = gm(data)
    assert torch.equal(reference_out, traced_out), "fx_out doesn't comply with original output"

    # Backward pass: a sum loss on each output must yield matching gradients.
    reference_out.sum().backward()
    traced_out.sum().backward()
    assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"


def _run_offload_codegen(rank):
    """Worker: trace MyNet with offload annotations using the torch>=1.12
    ActivationCheckpointCodeGen, then verify the generated source and the
    forward/backward equivalence against the eager model."""
    # launch colossalai so colossalai.utils.checkpoint executes correctly
    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # build model and input
    model = MyNet().cuda()
    data = torch.rand(4, 4).cuda()

    # trace the module and swap in the activation-checkpoint codegen
    tracer = ColoTracer(trace_act_ckpt=True)
    graph = tracer.trace(model)
    graph.set_codegen(ActivationCheckpointCodeGen())

    # Annotate the activation-offload regions. linear5 additionally carries an
    # activation_checkpoint annotation so both flavours of input offload are
    # exercised.
    offload_plan = {
        "linear0": [0, True, False],
        "linear1": [0, True, False],
        "linear2": [1, True, True],
        "linear4": [2, False, True],
    }
    for node in graph.nodes:
        if node.name in offload_plan:
            node.meta['activation_offload'] = offload_plan[node.name]
        elif node.name == "linear5":
            node.meta['activation_checkpoint'] = [0]
            node.meta['activation_offload'] = True

    gm = ColoGraphModule(copy.deepcopy(model), graph)
    gm.recompile()

    # every expected construct must appear in the generated source
    code = graph.python_code("self").src
    expected_snippets = (
        "def pack_hook_input(self, x):",
        "def unpack_hook(self, packed):",
        "def pack_hook_no_input(self, x):",
        "setattr(x, 'offload', True)",
        "setattr(linear3, 'offload', False)",
        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):",
        "with torch.autograd.graph.save_on_cpu(pin_memory=True):",
        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):",
        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)",
    )
    assert all(snippet in code for snippet in expected_snippets)

    _test_fwd_and_bwd(model, gm, data)
    gpc.destroy()


@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_act_ckpt_codegen():
    """Run the codegen-based offload check in a single spawned worker process."""
    mp.spawn(_run_offload_codegen, args=(), nprocs=1)


def _run_offload_codegen_torch11(rank):
    """Worker: same offload-codegen check as ``_run_offload_codegen`` but for
    torch<1.12, where codegen is replaced by rebinding ``graph._python_code``
    to the string-based fallback instead of using set_codegen."""
    # launch colossalai so colossalai.utils.checkpoint executes correctly
    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # build model and input
    model = MyNet().cuda()
    data = torch.rand(4, 4).cuda()

    # trace the module
    tracer = ColoTracer(trace_act_ckpt=True)
    graph = tracer.trace(model)

    # replace a bound method of an object with the fallback code generator
    graph._python_code = python_code_with_activation_checkpoint.__get__(graph)

    # Annotate the activation-offload regions. linear5 additionally carries an
    # activation_checkpoint annotation so both flavours of input offload are
    # exercised.
    offload_plan = {
        "linear0": [0, True, False],
        "linear1": [0, True, False],
        "linear2": [1, True, True],
        "linear4": [2, False, True],
    }
    for node in graph.nodes:
        if node.name in offload_plan:
            node.meta['activation_offload'] = offload_plan[node.name]
        elif node.name == "linear5":
            node.meta['activation_checkpoint'] = [0]
            node.meta['activation_offload'] = True

    gm = ColoGraphModule(copy.deepcopy(model), graph)
    gm.recompile()

    # every expected construct must appear in the generated source
    code = graph.python_code("self").src
    expected_snippets = (
        "def pack_hook_input(self, x):",
        "def unpack_hook(self, packed):",
        "def pack_hook_no_input(self, x):",
        "setattr(x, 'offload', True)",
        "setattr(linear3, 'offload', False)",
        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):",
        "with torch.autograd.graph.save_on_cpu(pin_memory=True):",
        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):",
        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)",
    )
    assert all(snippet in code for snippet in expected_snippets)

    _test_fwd_and_bwd(model, gm, data)
    gpc.destroy()


@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented")
def test_act_ckpt_python_code_torch11():
    """Run the torch<1.12 fallback offload check in a single spawned worker."""
    mp.spawn(_run_offload_codegen_torch11, args=(), nprocs=1)


if __name__ == "__main__":
    # Debug entry point: run the rank-0 worker directly, without mp.spawn.
    _run_offload_codegen(0)