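"""Test AutoChunkCodeGen on an OpenFold Evoformer block: trace the model with
meta inputs, propagate shape/memory info, generate chunked code under a memory
budget, and check that the forward output matches the original model."""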
from functools import partial

import pytest
import torch
import torch.fx
import torch.multiprocessing as mp

try:
    from fastfold.model.nn.evoformer import EvoformerBlock
    HAS_REPO = True
except Exception:
    HAS_REPO = False

import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port

if CODEGEN_AVAILABLE and is_compatible_with_meta():
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
    from colossalai.fx.profiler import MetaTensor
    from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace


def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
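    """Check that the chunked GraphModule's forward output matches the original model's."""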
    # for memory test
    # model = model.cuda()
    # torch.cuda.reset_peak_memory_stats()
    # now_mem = torch.cuda.memory_allocated() / 1024**2
    # with torch.no_grad():
    #     node1 = node.clone()
    #     pair1 = pair.clone()
    #     node_mask1 = node_mask.clone()
    #     pair_mask1 = pair_mask.clone()
    #     gm(node1, pair1, node_mask1, pair_mask1)
    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
    # print("autochunk max mem: %.2f" % (new_max_mem - now_mem))

    # test forward
    model = model.cuda()
    with torch.no_grad():
        non_fx_out = model(node, pair, node_mask, pair_mask)
        fx_out = gm(node, pair, node_mask, pair_mask)

    assert torch.allclose(non_fx_out[0], fx_out[0],
                          atol=1e-4), "fx_out does not match the original output, mean diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[0] - fx_out[0]))
    assert torch.allclose(non_fx_out[1], fx_out[1],
                          atol=1e-4), "fx_out does not match the original output, mean diff is %.2e" % torch.mean(
                              torch.abs(non_fx_out[1] - fx_out[1]))


def _build_openfold():
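    """Build a single OpenFold EvoformerBlock in eval mode on CUDA."""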
    model = EvoformerBlock(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        no_heads_msa=8,
        no_heads_pair=4,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.15,
        inf=1e4,
        eps=1e-4,
        is_multimer=False,
    ).eval().cuda()
    return model


def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
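    """Worker entry point: launch colossalai, build the model, run AutoChunk codegen and verify the output."""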
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    # build model and input
    model = _build_openfold()
    node = torch.randn(1, msa_len, pair_len, 256).cuda()
    node_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_mask = torch.randn(1, pair_len, pair_len).cuda()
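    # note: masks are random floats rather than 0/1; this test only checks
    # numerical parity between the original and the chunked module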

    # trace the meta graph and setup codegen
    meta_graph = symbolic_trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_mask_trans": True,
        },
    )
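    # propagate meta info (shapes, memory estimates) through the traced graph;
    # AutoChunkCodeGen relies on this to decide where to insert chunks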
    interp = MetaInfoProp(meta_graph)
    interp.propagate(
        MetaTensor(node, fake_device="cuda:0"),
        MetaTensor(pair, fake_device="cuda:0"),
        MetaTensor(node_mask, fake_device="cuda:0"),
        MetaTensor(pair_mask, fake_device="cuda:0"),
    )
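    # max_memory is the activation memory budget (MB) for the generated chunked code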
    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)

    # trace and recompile
    # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
    graph = ColoTracer().trace(
        model,
        meta_args={
            "m": node.to(torch.device("meta")),
            "z": pair.to(torch.device("meta")),
            "msa_mask": node_mask.to(torch.device("meta")),
            "pair_mask": pair_mask.to(torch.device("meta")),
        },
        concrete_args={
            "chunk_size": None,
            "_mask_trans": True,
        },
    )
    graph.set_codegen(codegen)
    gm = ColoGraphModule(model, graph)
    gm.recompile()

    # assert that chunk code has been inserted
    code = graph.python_code("self").src
    # print(code)
    assert "chunk_result = None;  chunk_size = None;" in code

    _test_fwd(model, gm, node, pair, node_mask, pair_mask)
    gpc.destroy()


@pytest.mark.skipif(
    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
    reason="torch version is lower than 1.12.0 or fastfold is not installed",
)
@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_evoformer_codegen(msa_len, pair_len, max_memory):
    run_func = partial(
        _test_evoformer_codegen,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
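    # mp.spawn passes the process rank as the first positional argument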
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    _test_evoformer_codegen(0, 32, 64, 24)