test_evoformer_codegen.py 5.11 KB
Newer Older
oahzxl's avatar
oahzxl committed
1
2
from functools import partial

oahzxl's avatar
init  
oahzxl committed
3
import pytest
oahzxl's avatar
oahzxl committed
4
import torch
oahzxl's avatar
oahzxl committed
5
import torch.fx
oahzxl's avatar
init  
oahzxl committed
6
import torch.multiprocessing as mp
oahzxl's avatar
oahzxl committed
7

8
9
10
11
12
13
# fastfold is an optional dependency: when it cannot be imported, HAS_REPO is
# False and the codegen test below is skipped instead of failing collection.
try:
    from fastfold.model.nn.evoformer import EvoformerBlock
    HAS_REPO = True
except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit propagate
    HAS_REPO = False

oahzxl's avatar
init  
oahzxl committed
14
15
import colossalai
from colossalai.core import global_context as gpc
oahzxl's avatar
oahzxl committed
16
from colossalai.fx._compatibility import is_compatible_with_meta
oahzxl's avatar
oahzxl committed
17
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
oahzxl's avatar
init  
oahzxl committed
18
from colossalai.fx.graph_module import ColoGraphModule
oahzxl's avatar
oahzxl committed
19
20
21
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port

oahzxl's avatar
oahzxl committed
22
if CODEGEN_AVAILABLE and is_compatible_with_meta():
oahzxl's avatar
oahzxl committed
23
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
oahzxl's avatar
oahzxl committed
24
    from colossalai.fx.profiler import MetaTensor
25
    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
oahzxl's avatar
oahzxl committed
26

oahzxl's avatar
oahzxl committed
27

28
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
oahzxl's avatar
oahzxl committed
29
30
31
32
33
34
35
36
37
38
39
40
41
    # for memory test
    # torch.cuda.reset_peak_memory_stats()
    # now_mem = torch.cuda.memory_allocated() / 1024**2
    # with torch.no_grad():
    #     node1 = node.clone()
    #     pair1 = pair.clone()
    #     gm(node1, pair1)
    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    # print(
    #     "autochunk now mem:%.2f max mem:%.2f"
    #     % (new_now_mem - now_mem, new_max_mem - now_mem)
    # )
oahzxl's avatar
oahzxl committed
42

oahzxl's avatar
init  
oahzxl committed
43
    # test forward
44
    model = model.cuda()
45
    with torch.no_grad():
46
47
        non_fx_out = model(node, pair, node_mask, pair_mask)
        fx_out = gm(node, pair, node_mask, pair_mask)
oahzxl's avatar
oahzxl committed
48

oahzxl's avatar
oahzxl committed
49
50
51
52
53
54
    assert torch.allclose(non_fx_out[0], fx_out[0],
                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[0] - fx_out[0]))
    assert torch.allclose(non_fx_out[1], fx_out[1],
                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[1] - fx_out[1]))
oahzxl's avatar
init  
oahzxl committed
55
56


57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def _build_openfold():
    """Create the ``EvoformerBlock`` under test, in eval mode and on CUDA.

    The channel sizes ``c_m=256`` / ``c_z=128`` line up with the last
    dimension of the msa / pair inputs built by the codegen test (presumably
    they must match - TODO confirm against fastfold).
    """
    block_config = dict(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        no_heads_msa=8,
        no_heads_pair=4,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.15,
        inf=1e4,
        eps=1e-4,
        is_multimer=False,
    )
    block = EvoformerBlock(**block_config)
    # eval() first so dropout is disabled, then move to the GPU.
    evaluated = block.eval()
    return evaluated.cuda()


def _evoformer_trace_kwargs(node, pair, node_mask, pair_mask):
    """Return fresh ``meta_args`` / ``concrete_args`` for tracing ``EvoformerBlock``.

    New dicts and new meta tensors are built on every call, so the two tracers
    that use them never share mutable state.
    """
    meta_args = {
        "m": node.to(torch.device("meta")),
        "z": pair.to(torch.device("meta")),
        "msa_mask": node_mask.to(torch.device("meta")),
        "pair_mask": pair_mask.to(torch.device("meta")),
    }
    concrete_args = {
        "chunk_size": None,
        "_mask_trans": True,
    }
    return {"meta_args": meta_args, "concrete_args": concrete_args}


def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
    """Trace an EvoformerBlock, apply AutoChunk codegen, and check forward parity.

    Args:
        rank: process index supplied by ``mp.spawn`` (single process, world_size=1).
        msa_len: second dimension of the msa input / mask.
        pair_len: residue dimension of the msa and pair inputs.
        max_memory: budget forwarded to ``AutoChunkCodeGen``. ``None`` is one
            of the parametrized values. Units assumed MB - TODO confirm.

    Raises:
        AssertionError: if the generated code does not mention ``chunk_size``
            or the traced module's output diverges from the eager model.
    """
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )
    try:
        # build model and inputs
        # NOTE(review): masks are drawn from randn rather than {0, 1}; this is
        # enough for an eager-vs-traced parity check - confirm it is intended.
        model = _build_openfold()
        node = torch.randn(1, msa_len, pair_len, 256).cuda()
        node_mask = torch.randn(1, msa_len, pair_len).cuda()
        pair = torch.randn(1, pair_len, pair_len, 128).cuda()
        pair_mask = torch.randn(1, pair_len, pair_len).cuda()

        # trace the meta graph, propagate shape/memory info, and set up codegen
        meta_graph = symbolic_trace(model, **_evoformer_trace_kwargs(node, pair, node_mask, pair_mask))
        interp = MetaInfoProp(meta_graph)
        interp.propagate(
            MetaTensor(node, fake_device="cuda:0"),
            MetaTensor(pair, fake_device="cuda:0"),
            MetaTensor(node_mask, fake_device="cuda:0"),
            MetaTensor(pair_mask, fake_device="cuda:0"),
        )
        # Without this codegen, max_memory is unused and no chunking is applied.
        codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)

        # trace and recompile
        # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
        graph = ColoTracer().trace(model, **_evoformer_trace_kwargs(node, pair, node_mask, pair_mask))
        graph.set_codegen(codegen)
        gm = ColoGraphModule(model, graph)
        gm.recompile()

        # assert we have inserted chunk
        # NOTE(review): chunk_size is also a concrete arg, so it may appear in the
        # traced signature regardless - a stronger check may be needed.
        code = graph.python_code("self").src
        assert "chunk_size" in code

        _test_fwd(model, gm, node, pair, node_mask, pair_mask)
    finally:
        # Tear down the colossalai context even when an assertion fails.
        gpc.destroy()


146
147
148
149
@pytest.mark.skipif(
    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
    reason="requires torch >= 1.12.0 with meta tensor support and the fastfold package",
)
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_evoformer_codegen(msa_len, pair_len, max_memory):
    """Run ``_test_evoformer_codegen`` in a single spawned worker process.

    Skipped when codegen / meta support is unavailable, or when fastfold
    could not be imported.
    """
    run_func = partial(
        _test_evoformer_codegen,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
    # mp.spawn passes the process index as the first positional arg (``rank``).
    # Presumably spawned so the colossalai/NCCL setup lives in a fresh process.
    mp.spawn(run_func, nprocs=1)
oahzxl's avatar
init  
oahzxl committed
161
162
163


if __name__ == "__main__":
    # Direct run without pytest/spawn: a single configuration on rank 0.
    _test_evoformer_codegen(rank=0, msa_len=32, pair_len=64, max_memory=25)