from functools import partial

import pytest
import torch
import torch.fx
import torch.multiprocessing as mp

import colossalai
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
from colossalai.core import global_context as gpc
from colossalai.fx import ColoTracer
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import MetaTensor
from colossalai.utils import free_port
from tests.test_autochunk.evoformer.evoformer import evoformer_base


def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
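    # run the eager model and the chunked GraphModule on the same inputs and check that their outputs match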
    # for memory test
    # torch.cuda.reset_peak_memory_stats()
    # now_mem = torch.cuda.memory_allocated() / 1024**2
    # with torch.no_grad():
    #     node1 = node.clone()
    #     pair1 = pair.clone()
    #     gm(node1, pair1)
    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    # print(
    #     "autochunk now mem:%.2f max mem:%.2f"
    #     % (new_now_mem - now_mem, new_max_mem - now_mem)
    # )

    # test forward
    with torch.no_grad():
        non_fx_out = model(node, pair)
        fx_out = gm(node, pair)

    assert torch.allclose(
        non_fx_out[0], fx_out[0], atol=1e-4
    ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
        torch.abs(non_fx_out[0] - fx_out[0])
    )
    assert torch.allclose(
        non_fx_out[1], fx_out[1], atol=1e-4
    ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
        torch.abs(non_fx_out[1] - fx_out[1])
    )


def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
    # launch colossalai to make sure we can execute colossalai.utils.checkpoint correctly
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    model = evoformer_base().cuda()
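    # Evoformer takes an MSA (node) representation and a pair representation as inputs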
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()

    # trace the module and replace codegen
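    # meta_args keeps the example inputs on the meta device, so tracing does not allocate real activation memory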
    graph = ColoTracer().trace(
        model,
        meta_args={
            "node": node.to(torch.device("meta")),
            "pair": pair.to(torch.device("meta")),
        },
    )
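    # MetaInfoProp annotates each node of the symbolically traced module with shape and memory info, which AutoChunkCodeGen relies on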
    gm_prop = torch.fx.symbolic_trace(model)  # must use symbolic_trace
    interp = MetaInfoProp(gm_prop)
    interp.propagate(
        MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
    )

    # run MetaInfoProp a second time on the rebuilt graph module so it also carries meta info (not strictly necessary)
    gm = torch.fx.GraphModule(model, graph)
    interp = MetaInfoProp(gm)
    interp.propagate(
        MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
    )

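    # AutoChunkCodeGen searches the profiled graph for chunkable regions and generates forward code that executes them chunk by chunk to stay within max_memory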
    codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph)
    gm.recompile()

    # assert that chunk regions have been inserted into the generated code
    code = graph.python_code("self").src
    assert "chunk_size" in code
    # print(code)

    _test_fwd(model, gm, node, pair)
    gpc.destroy()


@pytest.mark.skipif(not CODEGEN_AVAILABLE, reason='torch version is lower than 1.12.0')
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_autochunk_codegen(msa_len, pair_len, max_memory):
    run_func = partial(
        _test_autochunk_codegen,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
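    # spawn a single worker process so that colossalai.launch can initialize the distributed context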
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    _test_autochunk_codegen(0, 32, 64, None)