test_template_embedder.py 3.42 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import pytest
import pickle
import os
import copy
import torch.multiprocessing as mp
from functools import partial
import fastfold
from fastfold.config import model_config
from fastfold.model.fastnn.ops import set_chunk_size
from fastfold.model.hub import AlphaFold
from fastfold.utils.inject_fastnn import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_
from fastfold.utils.tensor_utils import tensor_tree_map
from fastfold.utils.test_utils import get_param_path, get_data_path


@pytest.fixture(scope="module")
def get_openfold_module_and_data():
    """Build the FastNN template embedder and a reference output to compare against.

    Loads ``model_1`` JAX weights into an ``AlphaFold`` model. It deep-copies that
    model and injects the FastNN kernels into the copy. Only the
    ``template_embedder`` submodule of each model is kept, in eval mode on CUDA.
    The reference (non-FastNN) embedder is then run once on random pair
    activations ``z`` to produce the expected result.

    Returns:
        tuple: ``(fast_module, z_out, feats, z, z_mask)``. The elements are:

        - ``fast_module``: the FastNN template embedder.
        - ``z_out``: ``z`` plus the reference ``"template_pair_embedding"``.
        - ``feats``: only the ``template_*`` features, already on CUDA.
        - ``z``: random activations of shape ``(33, 33, 128)``.
        - ``z_mask``: an all-ones mask with the same dtype as ``z``.
    """
    with torch.no_grad():
        config = model_config('model_1')
        config.globals.inplace = False
        target_module = AlphaFold(config)
        import_jax_weights_(target_module, get_param_path())

        # Inject FastNN into a copy so the reference model stays untouched.
        fast_module = copy.deepcopy(target_module)
        fast_module = inject_fastnn(fast_module)
        fast_module = fast_module.template_embedder
        fast_module = fast_module.eval().cuda()

        target_module = target_module.template_embedder
        target_module = target_module.eval().cuda()

        # Use a context manager so the data file is closed deterministically.
        # NOTE: pickle can execute arbitrary code on load; only use trusted test data.
        with open(get_data_path(), 'rb') as data_file:
            batch = pickle.load(data_file)
        # Keep index 0 of the last dimension of every tensor in the batch
        # (presumably the first recycling iteration -- TODO confirm).
        feats = tensor_tree_map(lambda t: t[..., 0], batch)
        feats = {k: v.cuda() for k, v in feats.items() if k.startswith("template_")}

        seq_len = 33
        z = torch.randn((seq_len, seq_len, 128)).cuda()
        # Cast once here; z_mask already has z's dtype for every later use.
        z_mask = torch.ones((seq_len, seq_len)).cuda().to(dtype=z.dtype)

        # Pass a copy of feats so that any in-place change the reference embedder
        # might make cannot leak into the features shared with the workers.
        template_embeds = target_module(copy.deepcopy(feats), z, z_mask, 0, None)
        z_out = z + template_embeds["template_pair_embedding"]
    return fast_module, z_out, feats, z, z_mask


@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 32])
@pytest.mark.parametrize('inplace', [False, True])
def test_state_dict(world_size, chunk_size, inplace, get_openfold_module_and_data):
    """Check the FastNN template embedder against the reference for one configuration.

    Spawns ``world_size`` processes. Each process runs ``_test_template_embedder``
    with the given ``chunk_size`` and ``inplace`` mode. The test is skipped when
    fewer CUDA devices are visible than ``world_size``.
    """
    # NOTE(review): despite its name, this test compares template-embedder outputs,
    # not state dicts. The name is kept so existing test selectors keep matching.
    # Each worker uses its rank as LOCAL_RANK, so this assumes one GPU per rank.
    available_gpus = torch.cuda.device_count()
    if available_gpus < world_size:
        pytest.skip(f"requires {world_size} CUDA devices, found {available_gpus}")

    run_func = partial(_test_template_embedder, world_size=world_size, chunk_size=chunk_size,
                       inplace=inplace, get_openfold_module_and_data=get_openfold_module_and_data)
    mp.spawn(run_func, nprocs=world_size)


def _test_template_embedder(rank, world_size, chunk_size, inplace, get_openfold_module_and_data):
    """Run the FastNN template embedder on one rank and compare it to the reference.

    Sets up the process environment and initialises Dynamic Axial Parallelism.
    It then runs a private copy of the FastNN embedder on the shared features and
    asserts that the mean absolute difference from ``z_out`` is below ``5e-4``.

    Args:
        rank: Process index supplied by ``mp.spawn``; used for both RANK and LOCAL_RANK.
        world_size: Total number of spawned processes.
        chunk_size: Chunk size passed to ``set_chunk_size`` and to the embedder call.
        inplace: When true, the embedder is called with ``inplace=True`` and its
            output is compared directly against ``z_out``.
        get_openfold_module_and_data: Tuple produced by the module fixture.
    """
    process_env = {'RANK': rank, 'LOCAL_RANK': rank, 'WORLD_SIZE': world_size}
    for env_name, env_value in process_env.items():
        os.environ[env_name] = str(env_value)
    # init distributed for Dynamic Axial Parallelism
    fastfold.distributed.init_dap()

    shared_module, z_out, feats, z, z_mask = get_openfold_module_and_data

    # Work on private copies so the fixture's objects are never modified.
    module = copy.deepcopy(shared_module).cuda()
    feats_copy = copy.deepcopy(feats)
    gpu_feats = {name: tensor.cuda() for name, tensor in feats_copy.items()}

    with torch.no_grad():
        set_chunk_size(chunk_size)
        # Only the in-place path receives the keyword; the default path omits it entirely.
        call_kwargs = {'inplace': inplace} if inplace else {}
        embeds = module(
            copy.deepcopy(gpu_feats),
            copy.deepcopy(z).cuda(),
            z_mask.to(dtype=z.dtype).cuda(),
            0,
            chunk_size,
            **call_kwargs,
        )
        pair_embedding = embeds["template_pair_embedding"]
        # The in-place result is treated as already containing z; otherwise add z back.
        z_fast = pair_embedding if inplace else z.cuda() + pair_embedding

    error = torch.mean(torch.abs(z_out.cuda() - z_fast))
    assert error < 5e-4, f"Test z failed at chunk size: {chunk_size}, inplace: {inplace}. The position dif is {error}"