test_simple_evoformer_search.py 3.28 KB
Newer Older
oahzxl's avatar
oahzxl committed
1
2
3
4
5
6
7
from functools import partial

import pytest
import torch
import torch.fx
import torch.multiprocessing as mp

8
9
10
11
12
13
# ``simple_evoformer`` is an optional module. When it cannot be imported, HAS_REPO
# is False and the test below is skipped by its ``skipif`` condition.
# Only ImportError is caught: a bare ``except:`` would also hide KeyboardInterrupt,
# SystemExit and real errors raised while importing the module.
try:
    from simple_evoformer import base_evoformer
    HAS_REPO = True
except ImportError:
    HAS_REPO = False

oahzxl's avatar
oahzxl committed
14
15
import colossalai
from colossalai.core import global_context as gpc
16
from colossalai.fx import symbolic_trace
oahzxl's avatar
oahzxl committed
17
from colossalai.fx._compatibility import is_compatible_with_meta
oahzxl's avatar
oahzxl committed
18
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
oahzxl's avatar
oahzxl committed
19
20
21
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port

oahzxl's avatar
oahzxl committed
22
if CODEGEN_AVAILABLE and is_compatible_with_meta():
oahzxl's avatar
oahzxl committed
23
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
oahzxl's avatar
oahzxl committed
24
    from colossalai.fx.profiler import MetaTensor
oahzxl's avatar
oahzxl committed
25

oahzxl's avatar
oahzxl committed
26
27
28
29
30
31

def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
    """Assert that the chunk search selected exactly the recorded regions.

    Args:
        chunk_infos: iterable of mappings from the chunk search; only each
            entry's ``"region"`` value is compared.
        max_memory: memory budget the search ran with. Expectations are only
            recorded for ``None``, 20, 25 and 30.
        msa_len: MSA length of the traced input. Only 32 is supported.
        pair_len: pair length of the traced input. Only 64 is supported.

    Raises:
        NotImplementedError: no expectations are recorded for this
            ``(msa_len, pair_len, max_memory)`` combination.
        AssertionError: the found regions differ from the expected list
            (the comparison is order-sensitive).
    """
    found = [info["region"] for info in chunk_infos]

    if not (msa_len == 32 and pair_len == 64):
        raise NotImplementedError()

    # Expected regions for msa_len=32, pair_len=64, checked in this order.
    # A ``None`` budget is matched by identity, the numeric budgets by equality.
    expectations = (
        (None, [(142, 154), (366, 373), (234, 283), (302, 351), (127, 134), (211, 228), (174, 191),
                (161, 166), (198, 203), (7, 57)]),
        (20, [(142, 154), (369, 373), (235, 269), (303, 351), (130, 131)]),
        (25, [(144, 154), (369, 370)]),
        (30, [(144, 154)]),
    )
    for budget, regions in expectations:
        hit = (max_memory is None) if budget is None else (max_memory == budget)
        if hit:
            expected = regions
            break
    else:
        raise NotImplementedError()

    assert found == expected, "found regions %s doesn't equal target regions %s" % (
        str(found),
        str(expected),
    )
oahzxl's avatar
oahzxl committed
49
50


51
def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory):
    """Run the AutoChunk region search on ``base_evoformer`` and check its result.

    Launches a single-rank colossalai context, traces the model with meta
    arguments, propagates meta info, builds ``AutoChunkCodeGen`` with the given
    memory budget and compares its ``chunk_infos`` via ``assert_chunk_infos``.
    The global context is destroyed even if any step after launch fails.

    Args:
        rank: rank passed to ``colossalai.launch`` (world size is fixed at 1).
        msa_len: size of dim 1 of the random ``node`` input (1, msa_len, pair_len, 256).
        pair_len: size of dims 1 and 2 of ``pair`` (1, pair_len, pair_len, 128)
            and of dim 2 of ``node``.
        max_memory: forwarded as ``AutoChunkCodeGen(..., max_memory=max_memory)``.
    """
    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
        world_size=1,
        host="localhost",
        port=free_port(),
        backend="nccl",
    )

    try:
        # build model and input
        model = base_evoformer().cuda()
        node = torch.randn(1, msa_len, pair_len, 256).cuda()
        pair = torch.randn(1, pair_len, pair_len, 128).cuda()

        meta_graph = symbolic_trace(model,
                                    meta_args={
                                        "node": node.to(torch.device("meta")),
                                        "pair": pair.to(torch.device("meta")),
                                    })    # must use symbolic_trace
        interp = MetaInfoProp(meta_graph)
        interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
        codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)

        assert_chunk_infos(codegen.chunk_infos, max_memory, msa_len, pair_len)
    finally:
        # Previously a failing trace/search/assertion skipped this teardown and
        # left the global context initialised; always destroy it.
        gpc.destroy()


81
82
@pytest.mark.skipif(
    not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
    # The condition also skips when simple_evoformer is missing (HAS_REPO False),
    # so the reason names that case too instead of blaming only the torch version.
    reason="torch version is lower than 1.12.0 or simple_evoformer is not available",
)
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
@pytest.mark.parametrize("msa_len", [32])
@pytest.mark.parametrize("pair_len", [64])
def test_simple_evoformer_search(msa_len, pair_len, max_memory):
    """Run ``_test_simple_evoformer_search`` for one configuration in a spawned process.

    Args:
        msa_len: MSA length forwarded to the worker.
        pair_len: pair length forwarded to the worker.
        max_memory: memory budget forwarded to the worker.
    """
    run_func = partial(
        _test_simple_evoformer_search,
        msa_len=msa_len,
        pair_len=pair_len,
        max_memory=max_memory,
    )
    # mp.spawn calls run_func(process_index, ...); the index fills ``rank``.
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    # Run a single configuration directly in this process (no pytest, no mp.spawn):
    # rank 0, msa_len 32, pair_len 64, memory budget 20.
    _test_simple_evoformer_search(rank=0, msa_len=32, pair_len=64, max_memory=20)