test_msa_att_row.py
import torch
import pytest
import os
import copy
import torch.multiprocessing as mp
from functools import partial
import fastfold
from fastfold.config import model_config
from fastfold.model.fastnn.ops import set_chunk_size
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.utils.test_utils import get_param_path
from fastfold.distributed.comm import gather, scatter


@pytest.fixture(scope="module")
def get_openfold_module_and_data():
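    # Build the reference output once per module: run OpenFold's
    # MSARowAttentionWithPairBias (plus its dropout layer) on random inputs and
    # keep the residual-added result as ground truth for the FastNN comparison.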
    with torch.no_grad():
        config = model_config('model_1')
        config.globals.inplace = False
        target_module = AlphaFold(config)
        import_jax_weights_(target_module, get_param_path())

        fast_module = copy.deepcopy(target_module)
        fast_module = inject_fastnn(fast_module)
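        # Keep only the first Evoformer block's row attention: the FastNN version
        # and the matching OpenFold attention + dropout modules.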
        fast_module = fast_module.evoformer.blocks[0].msa.MSARowAttentionWithPairBias.eval().cuda()
        target_module1 = target_module.evoformer.blocks[0].msa_att_row.eval().cuda()
        target_module2 = target_module.evoformer.blocks[0].msa_dropout_layer.eval().cuda()
        
        # Random MSA (m) and pair (z) activations with all-ones masks.
        msa_len = 300
        seq_len = 300
        m = torch.randn((msa_len, seq_len, 256)).cuda()
        m_mask = torch.ones((msa_len, seq_len)).cuda().to(dtype=m.dtype)
        z = torch.randn((seq_len, seq_len, 128)).cuda()
        z_mask = torch.ones((seq_len, seq_len)).cuda().to(dtype=z.dtype)
        # Reference output: OpenFold row attention + dropout, added back to m as a residual.
        m_out = m + target_module2(target_module1(m, z=z, mask=m_mask, chunk_size=None))

    return m_out, m, z, m_mask, z_mask, fast_module


@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 32])
def test_msa_att_row(world_size, chunk_size, get_openfold_module_and_data):
    run_func = partial(_test_msa_att_row, world_size=world_size, chunk_size=chunk_size,
                       get_openfold_module_and_data=get_openfold_module_and_data)
    mp.spawn(run_func, nprocs=world_size)


def _test_msa_att_row(rank, world_size, chunk_size, get_openfold_module_and_data):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    # init distributed for Dynamic Axial Parallelism
    fastfold.distributed.init_dap()

    m_out, m, z, m_mask, z_mask, fast_module = get_openfold_module_and_data
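    # Give this rank its own CUDA copy of the FastNN row-attention module.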
    fast_module = copy.deepcopy(fast_module).cuda()

    # Pad the residue dimension so it divides evenly by the DAP world size,
    # then scatter the MSA tensor across ranks along its row (MSA) dimension
    # and the pair tensor along its first residue dimension.
    fast_m = copy.deepcopy(m.cuda()).unsqueeze(0)
    fast_z = copy.deepcopy(z.cuda()).unsqueeze(0)
    dap_size = gpc.get_world_size(ParallelMode.TENSOR)
    seq_length = z_mask.cuda().size(-1)
    padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
    fast_m = torch.nn.functional.pad(fast_m, (0, 0, 0, padding_size))
    fast_z = torch.nn.functional.pad(fast_z, (0, 0, 0, padding_size, 0, padding_size))
    fast_m = scatter(fast_m, dim=1)
    fast_z = scatter(fast_z, dim=1)
    fast_m_mask = copy.deepcopy(m_mask.cuda()).unsqueeze(0)
    fast_m_mask = torch.nn.functional.pad(fast_m_mask, (0, padding_size))
    
    with torch.no_grad():
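        # Run the FastNN module with the requested chunk size on this rank's shard,
        # then gather the sharded output back onto every rank.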
        set_chunk_size(chunk_size)
        fast_m_mask = scatter(fast_m_mask.cuda(), dim=1)
        m_fast = fast_module(fast_m.cuda(), fast_z.cuda(), fast_m_mask)
        m_fast = m_fast.squeeze(0)
        m_fast = gather(m_fast, dim=0)
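        # padding_size is always >= 1 here, so this slice only strips the padded residues.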
        m_fast = m_fast[:, :-padding_size, :]

    error = torch.max(torch.abs(m_out.cuda() - m_fast))
    assert error < 5e-5, f"MSA row attention test failed at chunk size {chunk_size}: max abs diff {error}"