test_autochunk_codegen.py 3.93 KB
Newer Older
oahzxl's avatar
oahzxl committed
1
2
from functools import partial

oahzxl's avatar
init  
oahzxl committed
3
import pytest
oahzxl's avatar
oahzxl committed
4
import torch
oahzxl's avatar
oahzxl committed
5
import torch.fx
oahzxl's avatar
init  
oahzxl committed
6
import torch.multiprocessing as mp
oahzxl's avatar
oahzxl committed
7

oahzxl's avatar
init  
oahzxl committed
8
9
import colossalai
from colossalai.core import global_context as gpc
oahzxl's avatar
oahzxl committed
10
from colossalai.fx import ColoTracer
oahzxl's avatar
oahzxl committed
11
from colossalai.fx._compatibility import is_compatible_with_meta
oahzxl's avatar
oahzxl committed
12
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
oahzxl's avatar
init  
oahzxl committed
13
from colossalai.fx.graph_module import ColoGraphModule
oahzxl's avatar
oahzxl committed
14
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
oahzxl's avatar
oahzxl committed
15
from colossalai.fx.profiler import MetaTensor
oahzxl's avatar
oahzxl committed
16
from colossalai.utils import free_port
oahzxl's avatar
oahzxl committed
17
from tests.test_autochunk.evoformer.evoformer import evoformer_base
oahzxl's avatar
oahzxl committed
18

oahzxl's avatar
oahzxl committed
19
if CODEGEN_AVAILABLE and is_compatible_with_meta():
oahzxl's avatar
oahzxl committed
20
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
oahzxl's avatar
oahzxl committed
21
    from colossalai.fx.profiler import MetaTensor
oahzxl's avatar
oahzxl committed
22

oahzxl's avatar
oahzxl committed
23
24

def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
oahzxl's avatar
oahzxl committed
25
26
27
28
29
30
31
32
33
34
35
36
37
    # for memory test
    # torch.cuda.reset_peak_memory_stats()
    # now_mem = torch.cuda.memory_allocated() / 1024**2
    # with torch.no_grad():
    #     node1 = node.clone()
    #     pair1 = pair.clone()
    #     gm(node1, pair1)
    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    # print(
    #     "autochunk now mem:%.2f max mem:%.2f"
    #     % (new_now_mem - now_mem, new_max_mem - now_mem)
    # )
oahzxl's avatar
oahzxl committed
38

oahzxl's avatar
init  
oahzxl committed
39
    # test forward
40
41
42
    with torch.no_grad():
        non_fx_out = model(node, pair)
        fx_out = gm(node, pair)
oahzxl's avatar
oahzxl committed
43

oahzxl's avatar
oahzxl committed
44
45
46
47
48
49
50
51
52
53
    assert torch.allclose(
        non_fx_out[0], fx_out[0], atol=1e-4
    ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
        torch.abs(non_fx_out[0] - fx_out[0])
    )
    assert torch.allclose(
        non_fx_out[1], fx_out[1], atol=1e-4
    ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
        torch.abs(non_fx_out[1] - fx_out[1])
    )
oahzxl's avatar
init  
oahzxl committed
54
55


oahzxl's avatar
oahzxl committed
56
def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
    """Trace the evoformer, recompile it with ``AutoChunkCodeGen`` and check its outputs.

    The worker first launches a single-process colossalai context. It then
    builds an ``evoformer_base`` model with random CUDA inputs and traces it
    with ``ColoTracer``. Meta info is propagated with ``MetaInfoProp``, and the
    graph's codegen is replaced by ``AutoChunkCodeGen``. Finally it asserts
    that chunking code was emitted and that outputs match the eager model.

    Args:
        rank: process rank passed to ``colossalai.launch`` (world size is 1).
        msa_len: size of dim 1 of the ``node`` input, shaped
            ``(1, msa_len, pair_len, 256)``.
        pair_len: size of the pair dims; ``pair`` is shaped
            ``(1, pair_len, pair_len, 128)``.
        max_memory: forwarded unchanged to ``AutoChunkCodeGen`` (``None`` is
            one of the tested values; unit presumably MB — TODO confirm).
    """
    # launch colossalai to make sure we could execute colossalai.utils.checkpoint correctly
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )
    # Everything after a successful launch runs under try/finally so the global
    # context is torn down even when one of the assertions below fails.
    try:
        # build model and input
        model = evoformer_base().cuda()
        node = torch.randn(1, msa_len, pair_len, 256).cuda()
        pair = torch.randn(1, pair_len, pair_len, 128).cuda()

        # trace the module and replace codegen
        graph = ColoTracer().trace(
            model,
            meta_args={
                "node": node.to(torch.device("meta")),
                "pair": pair.to(torch.device("meta")),
            },
        )
        gm_prop = torch.fx.symbolic_trace(model)  # must use symbolic_trace
        interp = MetaInfoProp(gm_prop)
        interp.propagate(
            MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
        )

        # now run it twice to get meta info in graph module, not necessary
        gm = torch.fx.GraphModule(model, graph)
        interp = MetaInfoProp(gm)
        interp.propagate(
            MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
        )

        # the chunk search runs on the symbolically traced module, while the
        # generated code is attached to the ColoTracer graph
        codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
        graph.set_codegen(codegen)
        gm = ColoGraphModule(model, graph)
        gm.recompile()

        # assert we have inserted chunk
        code = graph.python_code("self").src
        assert "chunk_size" in code, "AutoChunkCodeGen did not insert any chunk region"

        _test_fwd(model, gm, node, pair)
    finally:
        gpc.destroy()


oahzxl's avatar
oahzxl committed
107
@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason='torch version is lower than 1.12.0')
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_autochunk_codegen(msa_len, pair_len, max_memory):
    """Run the autochunk codegen check for one parameter combination in a spawned worker.

    ``mp.spawn`` passes the process index as the first positional argument, which
    becomes ``rank`` of ``_test_autochunk_codegen``. The remaining settings are
    bound by keyword beforehand.
    """
    worker_kwargs = dict(msa_len=msa_len, pair_len=pair_len, max_memory=max_memory)
    worker = partial(_test_autochunk_codegen, **worker_kwargs)
    mp.spawn(worker, nprocs=1)
oahzxl's avatar
init  
oahzxl committed
119
120
121


if __name__ == "__main__":
    # Run one configuration directly as rank 0, bypassing pytest and mp.spawn.
    _test_autochunk_codegen(rank=0, msa_len=32, pair_len=64, max_memory=25)