test_autochunk_codegen.py 3.57 KB
Newer Older
oahzxl's avatar
oahzxl committed
1
2
from functools import partial

oahzxl's avatar
init  
oahzxl committed
3
import pytest
oahzxl's avatar
oahzxl committed
4
import torch
oahzxl's avatar
oahzxl committed
5
import torch.fx
oahzxl's avatar
init  
oahzxl committed
6
import torch.multiprocessing as mp
oahzxl's avatar
oahzxl committed
7

oahzxl's avatar
init  
oahzxl committed
8
import colossalai
oahzxl's avatar
oahzxl committed
9
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
oahzxl's avatar
init  
oahzxl committed
10
from colossalai.core import global_context as gpc
oahzxl's avatar
oahzxl committed
11
from colossalai.fx import ColoTracer
oahzxl's avatar
init  
oahzxl committed
12
from colossalai.fx.graph_module import ColoGraphModule
oahzxl's avatar
oahzxl committed
13
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
oahzxl's avatar
oahzxl committed
14
from colossalai.fx.profiler import MetaTensor
oahzxl's avatar
oahzxl committed
15
from colossalai.utils import free_port
oahzxl's avatar
oahzxl committed
16
from tests.test_autochunk.evoformer.evoformer import evoformer_base
oahzxl's avatar
oahzxl committed
17
18
19


def _test_fwd(model: torch.nn.Module, gm: "ColoGraphModule", node, pair):
    """Check that the chunked graph module reproduces the eager model's outputs.

    ``model`` and ``gm`` are both called positionally with ``(node, pair)``
    under ``torch.no_grad()``. Every output returned by ``gm`` must match the
    output at the same position returned by ``model`` within ``atol=1e-4``.

    Args:
        model: The original (eager) module.
        gm: The traced graph module with the autochunk codegen installed.
        node: First input passed to both callables.
        pair: Second input passed to both callables.

    Raises:
        AssertionError: If the two callables return a different number of
            outputs, or if any pair of outputs is not ``allclose``. The
            message reports the mean absolute difference.
    """
    with torch.no_grad():
        non_fx_out = model(node, pair)
        fx_out = gm(node, pair)

    # Both callables are expected to return a sequence of tensors. Compare
    # every position rather than hard-coding indices 0 and 1.
    assert len(non_fx_out) == len(fx_out), (
        f"fx_out has {len(fx_out)} outputs, original output has {len(non_fx_out)}"
    )
    for ref, out in zip(non_fx_out, fx_out):
        assert torch.allclose(ref, out, atol=1e-4), (
            "fx_out doesn't comply with original output, diff is %.2e"
            % torch.mean(torch.abs(ref - out))
        )
oahzxl's avatar
init  
oahzxl committed
49
50


oahzxl's avatar
oahzxl committed
51
def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
    """Trace an Evoformer model, install AutoChunkCodeGen and check its output.

    Meant to run as a single worker process. The pytest entry point launches
    it through ``mp.spawn``, which supplies ``rank`` as the first argument.

    Args:
        rank: Rank passed to ``colossalai.launch`` (``world_size`` is 1).
        msa_len: Size of dim 1 of the random ``node`` input.
        pair_len: Size used for the spatial dims of ``node`` and ``pair``.
        max_memory: Forwarded unchanged to ``AutoChunkCodeGen``.
    """
    # launch colossalai to make sure we could execute colossalai.utils.checkpoint correctly
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )
    try:
        # build model and input
        model = evoformer_base().cuda()
        node = torch.randn(1, msa_len, pair_len, 256).cuda()
        pair = torch.randn(1, pair_len, pair_len, 128).cuda()

        # trace the module and replace codegen
        graph = ColoTracer().trace(
            model,
            meta_args={
                "node": node.to(torch.device("meta")),
                "pair": pair.to(torch.device("meta")),
            },
        )
        gm_prop = torch.fx.symbolic_trace(model)  # must use symbolic_trace
        interp = MetaInfoProp(gm_prop)
        interp.propagate(
            MetaTensor(node, fake_device="cuda:0"),
            MetaTensor(pair, fake_device="cuda:0"),
        )

        # Propagate meta info a second time, into the ColoTracer graph. The
        # original author marked this step as not strictly necessary.
        gm = torch.fx.GraphModule(model, graph)
        interp = MetaInfoProp(gm)
        interp.propagate(
            MetaTensor(node, fake_device="cuda:0"),
            MetaTensor(pair, fake_device="cuda:0"),
        )

        # The chunk search runs on gm_prop; the resulting codegen is then
        # attached to the ColoTracer graph and recompiled.
        codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
        graph.set_codegen(codegen)
        gm = ColoGraphModule(model, graph)
        gm.recompile()

        # assert we have inserted chunk
        code = graph.python_code("self").src
        assert "chunk_size" in code, "no chunked region was emitted by AutoChunkCodeGen"

        _test_fwd(model, gm, node, pair)
    finally:
        # Tear down the global context even if tracing or an assertion failed.
        gpc.destroy()


oahzxl's avatar
oahzxl committed
102
@pytest.mark.skipif(not torch.cuda.is_available(), reason="autochunk codegen test requires CUDA")
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_autochunk_codegen(msa_len, pair_len, max_memory):
    """Run ``_test_autochunk_codegen`` in one spawned process per parameter set.

    Skipped when no CUDA device is available. The worker moves the model and
    inputs to CUDA and launches colossalai with the ``nccl`` backend.
    """
    run_func = partial(
        _test_autochunk_codegen,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
    # mp.spawn passes the process index as the first positional arg (rank).
    mp.spawn(run_func, nprocs=1)
oahzxl's avatar
init  
oahzxl committed
113
114
115


if __name__ == "__main__":
    # Run one configuration directly in this process, bypassing pytest/mp.spawn.
    _test_autochunk_codegen(rank=0, msa_len=32, pair_len=64, max_memory=None)