test_autochunk_codegen.py 3.76 KB
Newer Older
oahzxl's avatar
oahzxl committed
1
2
from functools import partial

oahzxl's avatar
init  
oahzxl committed
3
import pytest
oahzxl's avatar
oahzxl committed
4
import torch
oahzxl's avatar
oahzxl committed
5
import torch.fx
oahzxl's avatar
init  
oahzxl committed
6
import torch.multiprocessing as mp
oahzxl's avatar
oahzxl committed
7

oahzxl's avatar
init  
oahzxl committed
8
9
import colossalai
from colossalai.core import global_context as gpc
oahzxl's avatar
oahzxl committed
10
from colossalai.fx import ColoTracer
oahzxl's avatar
oahzxl committed
11
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
oahzxl's avatar
init  
oahzxl committed
12
from colossalai.fx.graph_module import ColoGraphModule
oahzxl's avatar
oahzxl committed
13
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
oahzxl's avatar
oahzxl committed
14
from colossalai.fx.profiler import MetaTensor
oahzxl's avatar
oahzxl committed
15
from colossalai.utils import free_port
oahzxl's avatar
oahzxl committed
16
from tests.test_autochunk.evoformer.evoformer import evoformer_base
oahzxl's avatar
oahzxl committed
17

oahzxl's avatar
oahzxl committed
18
19
20
if CODEGEN_AVAILABLE:
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen

oahzxl's avatar
oahzxl committed
21
22

def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
    """Check that the autochunked graph module reproduces the eager model's outputs.

    ``model`` and ``gm`` are both called with ``(node, pair)`` under
    ``torch.no_grad()``. The first two elements of their outputs are compared
    with ``torch.allclose(..., atol=1e-4)``.

    Args:
        model: the original (non-traced) module used as the reference.
        gm: the traced graph module with the autochunk codegen applied.
        node: first input forwarded to both modules.
        pair: second input forwarded to both modules.

    Raises:
        AssertionError: if a compared output differs beyond the tolerance. The
            message names the output index and its mean absolute difference.
    """
    # test forward
    with torch.no_grad():
        non_fx_out = model(node, pair)
        fx_out = gm(node, pair)

    # Only outputs 0 and 1 are compared (presumably the updated node and pair
    # tensors; TODO confirm against evoformer_base's return value).
    for idx in range(2):
        expected, actual = non_fx_out[idx], fx_out[idx]
        assert torch.allclose(expected, actual, atol=1e-4), (
            "fx_out[%d] doesn't comply with original output, diff is %.2e"
            % (idx, torch.mean(torch.abs(expected - actual)).item())
        )
oahzxl's avatar
init  
oahzxl committed
52
53


oahzxl's avatar
oahzxl committed
54
def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
    """Trace an evoformer block, apply AutoChunkCodeGen and verify its forward pass.

    Called once per spawned process by ``test_autochunk_codegen``, or directly
    from ``__main__``.

    Args:
        rank: process rank passed to ``colossalai.launch`` (world size is 1).
        msa_len: size of dim 1 of the ``node`` input.
        pair_len: size of dim 2 of ``node`` and dims 1-2 of ``pair``.
        max_memory: forwarded as ``max_memory`` to ``AutoChunkCodeGen``. The
            test parametrizes it with ``None`` and small integers; their exact
            meaning (units, "no budget") is defined there, TODO confirm.
    """
    # launch colossalai to make sure we could execute colossalai.utils.checkpoint correctly
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )
    # Always tear down the global context, even when an assertion below
    # fails, so a failing case does not leave colossalai state behind.
    try:
        # build model and input
        model = evoformer_base().cuda()
        node = torch.randn(1, msa_len, pair_len, 256).cuda()
        pair = torch.randn(1, pair_len, pair_len, 128).cuda()

        # trace the module and replace codegen
        graph = ColoTracer().trace(
            model,
            meta_args={
                "node": node.to(torch.device("meta")),
                "pair": pair.to(torch.device("meta")),
            },
        )
        gm_prop = torch.fx.symbolic_trace(model)  # must use symbolic_trace
        interp = MetaInfoProp(gm_prop)
        interp.propagate(
            MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
        )

        # now run it twice to get meta info in graph module, not necessary
        gm = torch.fx.GraphModule(model, graph)
        interp = MetaInfoProp(gm)
        interp.propagate(
            MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
        )

        # The codegen is built from the symbolic_trace module (gm_prop) but
        # installed on the ColoTracer graph.
        codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
        graph.set_codegen(codegen)
        gm = ColoGraphModule(model, graph)
        gm.recompile()

        # assert we have inserted chunk
        code = graph.python_code("self").src
        assert "chunk_size" in code, "AutoChunkCodeGen did not insert any chunk into the generated code"

        _test_fwd(model, gm, node, pair)
    finally:
        gpc.destroy()


oahzxl's avatar
oahzxl committed
105
@pytest.mark.skipif(not CODEGEN_AVAILABLE, reason='torch version is lower than 1.12.0')
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_autochunk_codegen(msa_len, pair_len, max_memory):
    """Run ``_test_autochunk_codegen`` in a single spawned process.

    ``mp.spawn`` passes the process index as the first positional argument,
    which the worker receives as ``rank``. The remaining arguments are bound
    by keyword.
    """
    worker_kwargs = dict(msa_len=msa_len, pair_len=pair_len, max_memory=max_memory)
    mp.spawn(partial(_test_autochunk_codegen, **worker_kwargs), nprocs=1)
oahzxl's avatar
init  
oahzxl committed
117
118
119


if __name__ == "__main__":
    # Run one configuration directly in this process, without mp.spawn.
    _test_autochunk_codegen(rank=0, msa_len=32, pair_len=64, max_memory=25)