"tests/test_zero/test_init_context.py" did not exist on "104cbbb313348e04cf83bda9f2dbfbe3b0f369fb"
test_offload_codegen.py 6.48 KB
Newer Older
Boyuan Yao's avatar
Boyuan Yao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
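"""Tests for ColossalAI's activation offload code generation.

The traced MyNet modules below are annotated with activation_offload /
activation_checkpoint attributes, and the generated code is checked for the
expected offloading constructs before running forward and backward.
"""
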
import copy
import torch
import torch.nn.functional as F
import pytest
import torch.multiprocessing as mp
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule

try:
    from colossalai.fx.codegen import ActivationCheckpointCodeGen
    with_codegen = True
except ImportError:
    # fall back to the codegen shipped with older PyTorch versions
    from colossalai.fx.codegen import python_code_with_activation_checkpoint
    with_codegen = False
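
# NOTE: ActivationCheckpointCodeGen builds on the torch.fx codegen API
# introduced in PyTorch 1.12; on older versions the tests fall back to
# python_code_with_activation_checkpoint, which is bound directly onto the
# traced graph (see _run_offload_codegen_torch11 below).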


class MyNet(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.linear0 = torch.nn.Linear(4, 4)
        self.linear1 = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)
        self.linear3 = torch.nn.Linear(4, 4)
        self.linear4 = torch.nn.Linear(4, 4)
        self.linear5 = torch.nn.Linear(4, 4)
        self.linear6 = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = self.linear0(x)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.linear4(x)
        x = self.linear5(x)
        x = self.linear6(x)
        return x


def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool:
    for m_p, gm_p in zip(m.parameters(), gm.parameters()):
        if not torch.allclose(m_p.grad, gm_p.grad):
            return False
    return True


def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor):

    # test forward
    non_fx_out = model(data)
    fx_out = gm(data)
    assert torch.equal(non_fx_out, fx_out), "fx_out does not match the output of the original model"

    # test backward
    loss0 = non_fx_out.sum()
    loss0.backward()
    loss1 = fx_out.sum()
    loss1.backward()
    assert _is_all_gradient_close(model, gm), "gm does not produce the same gradients as the original model"
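

# For reference: a minimal, self-contained sketch (not exercised by the tests)
# of the two PyTorch offloading mechanisms that the generated code is asserted
# to use below. The pack/unpack pair mirrors what the generated
# pack_hook_input/unpack_hook are expected to do: move saved activations to
# the CPU during forward and bring them back for backward.
def _offload_mechanisms_sketch():
    model = torch.nn.Linear(4, 4).cuda()
    data = torch.rand(4, 4).cuda()

    # save_on_cpu transparently keeps tensors saved for backward in CPU memory,
    # optionally pinned for faster device transfers
    with torch.autograd.graph.save_on_cpu(pin_memory=True):
        out = model(data)
    out.sum().backward()

    # saved_tensors_hooks installs custom pack/unpack callbacks around the
    # saved-tensor mechanism
    def pack(x):
        return x.to("cpu")

    def unpack(packed):
        return packed.cuda()

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        out = model(data)
    out.sum().backward()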


def _run_offload_codegen(rank):
    # launch colossalai to make sure colossalai.utils.checkpoint executes correctly
    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # build model and input
    model = MyNet().cuda()
    data = torch.rand(4, 4).cuda()

    # trace the module and replace codegen
    tracer = ColoTracer(trace_act_ckpt=True)
    graph = tracer.trace(model)
    codegen = ActivationCheckpointCodeGen()
    graph.set_codegen(codegen)

    # annotate the activation offload regions;
    # also annotate activation_checkpoint so that both types
    # of input offload are covered
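    # NOTE (assumption): in the three-element form the first entry appears to
    # index the offload region and the booleans select which saved tensors are
    # offloaded; the exact field semantics are defined by
    # ActivationCheckpointCodeGen. The bare True on linear5 requests
    # offloading of the input of its checkpoint region.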
    for node in graph.nodes:
        if node.name == "linear0":
            setattr(node, "activation_offload", [0, True, False])
        if node.name == "linear1":
            setattr(node, "activation_offload", [0, True, False])
        if node.name == "linear2":
            setattr(node, "activation_offload", [1, True, True])
        if node.name == "linear4":
            setattr(node, "activation_offload", [2, False, True])
        if node.name == "linear5":
            setattr(node, "activation_checkpoint", [0])
            setattr(node, "activation_offload", True)

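    # build the graph module from a deep copy of the model so the original
    # keeps independent parameters for the gradient comparison afterwards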
    gm = ColoGraphModule(copy.deepcopy(model), graph)
    gm.recompile()

    # assert we have all the components
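    # (the input-offload pack/unpack hooks, the per-tensor 'offload'
    # attributes, the saved_tensors_hooks / save_on_cpu context managers, and
    # the non-reentrant checkpoint call for the linear5 region annotated above)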
    code = graph.python_code("self").src
    assert "def pack_hook_input(self, x):" in code and \
    "def unpack_hook(self, packed):" in code and \
    "def pack_hook_no_input(self, x):" in code and \
    "setattr(x, 'offload', True)" in code and \
    "setattr(linear3, 'offload', False)" in code and \
    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
    "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
    "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code

    _test_fwd_and_bwd(model, gm, data)
    gpc.destroy()


@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_act_ckpt_codegen():
    mp.spawn(_run_offload_codegen, nprocs=1)


def _run_offload_codegen_torch11(rank):
    # launch colossalai to make sure colossalai.utils.checkpoint executes correctly
    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

    # build model and input
    model = MyNet().cuda()
    data = torch.rand(4, 4).cuda()

    # trace the module and replace codegen
    tracer = ColoTracer(trace_act_ckpt=True)
    graph = tracer.trace(model)

    # replace the graph's code generation with the activation-checkpoint-aware
    # version by binding the standalone function as a method of this graph
    graph._python_code = python_code_with_activation_checkpoint.__get__(graph)

    # annotate the activation offload regions;
    # also annotate activation_checkpoint so that both types
    # of input offload are covered
    for node in graph.nodes:
        if node.name == "linear0":
            setattr(node, "activation_offload", [0, True, False])
        if node.name == "linear1":
            setattr(node, "activation_offload", [0, True, False])
        if node.name == "linear2":
            setattr(node, "activation_offload", [1, True, True])
        if node.name == "linear4":
            setattr(node, "activation_offload", [2, False, True])
        if node.name == "linear5":
            setattr(node, "activation_checkpoint", [0])
            setattr(node, "activation_offload", True)

    gm = ColoGraphModule(copy.deepcopy(model), graph)
    gm.recompile()

    # assert we have all the components
    code = graph.python_code("self").src
    assert "def pack_hook_input(self, x):" in code and \
    "def unpack_hook(self, packed):" in code and \
    "def pack_hook_no_input(self, x):" in code and \
    "setattr(x, 'offload', True)" in code and \
    "setattr(linear3, 'offload', False)" in code and \
    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
    "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
    "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code

    _test_fwd_and_bwd(model, gm, data)
    gpc.destroy()


@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented")
def test_act_ckpt_python_code_torch11():
    mp.spawn(_run_offload_codegen_torch11, nprocs=1)


if __name__ == "__main__":
    _run_offload_codegen(0)