"applications/ColossalChat/coati/ray/lora_constructor.py" did not exist on "16bf4c022150fea303d23437b0190c46204e722c"
test_extramsa_block.py 2.45 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from functools import partial
from typing import Dict, List, Tuple

import pytest
import torch
import torch.fx
import torch.multiprocessing as mp

# fastfold is an optional dependency: HAS_REPO records whether ExtraMSABlock
# could be imported, and feeds the skipif condition on the test below.
try:
    from fastfold.model.nn.evoformer import ExtraMSABlock
    HAS_REPO = True
except Exception:    # not a bare except, so KeyboardInterrupt/SystemExit still propagate;
                     # broader than ImportError since importing fastfold may presumably fail in other ways
    HAS_REPO = False
from test_alphafold_utils import run_test

from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE


def get_model():
    """Build the fastfold ``ExtraMSABlock`` under test, in eval mode on the GPU.

    The hyper-parameters are fixed; note that ``c_m=256`` and ``c_z=128``
    match the trailing dimensions of the ``m`` and ``z`` tensors created by
    ``get_data``.

    Returns:
        The module returned by ``.eval().cuda()`` on a freshly built block.
    """
    block_config = dict(
        c_m=256,
        c_z=128,
        c_hidden_msa_att=32,
        c_hidden_opm=32,
        c_hidden_mul=128,
        c_hidden_pair_att=32,
        no_heads_msa=8,
        no_heads_pair=4,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.15,
        inf=1e4,
        eps=1e-4,
        ckpt=False,
        is_multimer=False,
    )
    block = ExtraMSABlock(**block_config)
    # eval() presumably so the configured dropout does not make the
    # chunked/unchunked comparison non-deterministic — confirm in run_test.
    block = block.eval()
    block = block.cuda()
    return block


def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
    """Create random CUDA inputs for the ``ExtraMSABlock`` test.

    Args:
        msa_len: size of dim 1 of ``m`` and ``msa_mask``.
        pair_len: size of dim 2 of ``m``/``msa_mask`` and of dims 1-2 of
            ``z``/``pair_mask``.

    Returns:
        ``(meta_args, concrete_args)``, two lists of ``(name, value)`` pairs.
        ``meta_args`` holds the four tensor inputs (presumably keyword
        arguments of the block's forward — confirm); ``concrete_args`` fixes
        ``chunk_size`` to ``None`` and ``_chunk_logits`` to ``1024``.
    """
    # The randn calls stay in this order so the same RNG state yields the
    # same tensors.
    msa_act = torch.randn(1, msa_len, pair_len, 256).cuda()
    msa_act_mask = torch.randn(1, msa_len, pair_len).cuda()
    pair_act = torch.randn(1, pair_len, pair_len, 128).cuda()
    pair_act_mask = torch.randn(1, pair_len, pair_len).cuda()
    # NOTE(review): the masks are normal-distributed rather than 0/1 —
    # presumably acceptable because this test only compares two executions
    # on identical inputs; confirm.

    tensor_inputs = [
        ("m", msa_act),
        ("z", pair_act),
        ("msa_mask", msa_act_mask),
        ("pair_mask", pair_act_mask),
    ]
    fixed_inputs = [("chunk_size", None), ("_chunk_logits", 1024)]
    return tensor_inputs, fixed_inputs


def get_chunk_target() -> Dict:
    """Return the expected chunk regions, keyed by ``max_memory`` setting.

    The keys (``None``, ``20``, ``24``) are the ``max_memory`` values used in
    this file; each value is a list of ``(start, end)`` pairs, presumably
    node-index ranges in the traced graph — confirm against ``run_test``.
    A new dict is built on every call.
    """
    no_limit = [
        (126, 131), (227, 245), (272, 297), (310, 317), (105, 112),
        (152, 160), (193, 201), (249, 250), (33, 46),
    ]
    limit_20 = [(126, 131), (238, 245), (283, 290), (313, 314), (108, 109), (35, 46)]
    limit_24 = [(126, 131)]
    return {None: no_limit, 20: limit_20, 24: limit_24}


@pytest.mark.skipif(
    not (AUTOCHUNK_AVAILABLE and HAS_REPO),
    # The condition also skips when fastfold could not be imported, so the
    # reason must mention both causes.
    reason="autochunk is unavailable (torch version is lower than 1.12.0) or fastfold is not installed",
)
@pytest.mark.parametrize("max_memory", [None, 20, 24])
@pytest.mark.parametrize("data_args", [(32, 64)])    # (msa_len, pair_len)
def test_extramsa_block(data_args, max_memory):
    """Run the autochunk check for ``ExtraMSABlock`` in a spawned process.

    Args:
        data_args: ``(msa_len, pair_len)`` forwarded to ``get_data``.
        max_memory: memory limit forwarded to ``run_test`` (``None`` = no limit).

    ``mp.spawn`` calls ``run_func(i)`` with the process index ``i`` as the
    first positional argument, which fills ``run_test``'s ``rank``.
    """
    # NOTE(review): unlike the __main__ entry point, get_chunk_target is not
    # passed here, so the found chunk regions are not compared — confirm
    # whether that is intentional.
    run_func = partial(
        run_test,
        data_args=data_args,
        max_memory=max_memory,
        get_model=get_model,
        get_data=get_data,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
    mp.spawn(run_func, nprocs=1)


if __name__ == "__main__":
    # Run one configuration in-process as rank 0 (no mp.spawn). Unlike the
    # pytest entry point, this also passes get_chunk_target.
    main_run_kwargs = dict(
        data_args=(32, 64),    # (msa_len, pair_len)
        max_memory=20,
        get_model=get_model,
        get_data=get_data,
        get_chunk_target=get_chunk_target,
        print_code=False,
        print_mem=False,
        print_progress=False,
    )
    run_test(rank=0, **main_run_kwargs)