test_inference.py 2.77 KB
Newer Older
LuGY's avatar
LuGY committed
1
2
3
import os
import copy
import pytest
4
import torch
LuGY's avatar
LuGY committed
5
6
7
import pickle
import torch.multiprocessing as mp
from functools import partial
8
9
10
11
12

import fastfold
from fastfold.model.hub import AlphaFold
from fastfold.config import model_config
from fastfold.model.fastnn import set_chunk_size
13
from fastfold.utils.inject_fastnn import inject_fastnn
LuGY's avatar
LuGY committed
14
from fastfold.utils.import_weights import import_jax_weights_
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from fastfold.utils.test_utils import get_data_path, get_param_path


@pytest.fixture(scope="module")
def get_openfold_module_and_data():
    """Build the reference model, its output on the test batch, and a FastNN copy.

    Module-scoped, so weight loading and the reference forward pass run once
    per test module.

    Returns:
        tuple: ``(model, out, fastmodel)`` where ``model`` is an ``AlphaFold``
        built from ``model_config('model_1')`` with JAX weights loaded,
        ``out`` is its output dict for the batch at ``get_data_path()``
        (computed under ``torch.no_grad()``), and ``fastmodel`` is a deep copy
        of ``model`` passed through ``inject_fastnn``.
    """
    config = model_config('model_1')
    # The reference model is always built with globals.inplace disabled.
    config.globals.inplace = False
    model = AlphaFold(config)
    import_jax_weights_(model, get_param_path())
    model.eval().cuda()

    # Context manager closes the data file as soon as it has been read.
    # NOTE(review): pickle.load is only acceptable on trusted test fixtures.
    with open(get_data_path(), 'rb') as f:
        batch = pickle.load(f)
    batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
    with torch.no_grad():
        out = model(batch)

    # Deep-copy first so whatever inject_fastnn does to the copy cannot
    # affect the reference model.
    fastmodel = copy.deepcopy(model)
    fastmodel = inject_fastnn(fastmodel)
    fastmodel.eval().cuda()
    return model, out, fastmodel
LuGY's avatar
LuGY committed
34

35

LuGY's avatar
LuGY committed
36
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 32])
@pytest.mark.parametrize('inplace', [False, True])
def test_state_dict(world_size, chunk_size, inplace, get_openfold_module_and_data):
    """Launch ``world_size`` worker processes, each running :func:`run_dist`.

    Every combination of world size, chunk size and in-place flag is covered;
    the fixture tuple is handed to each worker through its ``model`` argument.
    """
    worker = partial(
        run_dist,
        world_size=world_size,
        chunk_size=chunk_size,
        inplace=inplace,
        model=get_openfold_module_and_data,
    )
    # Everything except the first argument is bound by keyword above, so the
    # process index supplied by mp.spawn lands in run_dist's `rank`.
    mp.spawn(worker, nprocs=world_size)
42
43


44
def run_dist(rank, world_size, chunk_size, inplace, model):
    """Worker entry point: prepare the distributed environment, then run inference.

    Args:
        rank: index of this process, supplied positionally by ``mp.spawn``.
        world_size: total number of spawned worker processes.
        chunk_size: forwarded unchanged to :func:`inference`.
        inplace: forwarded unchanged to :func:`inference`.
        model: the ``(model, out, fastmodel)`` tuple produced by the
            ``get_openfold_module_and_data`` fixture.
    """
    # Single-node launch: each spawned process is also its own local rank.
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    # An env-variable rendezvous also needs MASTER_ADDR / MASTER_PORT. Supply
    # single-node defaults, but never override values exported by the caller.
    # NOTE(review): assumes init_dap performs an env:// style rendezvous — confirm.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    # init distributed for Dynamic Axial Parallelism
    fastfold.distributed.init_dap()
    inference(chunk_size, inplace, model)
LuGY's avatar
LuGY committed
51

52

53
def inference(chunk_size, inplace, get_openfold_module_and_data):
    """Run the FastNN model on the test batch and compare against the reference.

    Args:
        chunk_size: chunk size passed to ``set_chunk_size`` and stored on the
            FastNN model's ``globals`` (the test uses ``None`` and ``32``).
        inplace: value stored on the FastNN model's ``globals.inplace``.
        get_openfold_module_and_data: the ``(model, out, fastmodel)`` tuple
            produced by the fixture of the same name.

    Raises:
        AssertionError: if the maximum absolute difference between the
            ``final_atom_positions`` of the FastNN output and the reference
            output is not below ``5e-4``.
    """
    _, out, fastmodel = get_openfold_module_and_data

    # Private copy, so the configuration below never leaks into the caller's
    # fastmodel object.
    fastmodel = copy.deepcopy(fastmodel).cuda()

    # Configure the model that is actually executed. fastmodel was produced by
    # copy.deepcopy(model), so it owns a separate `globals` object; settings
    # written to model.globals would never reach it.
    fastmodel.globals.chunk_size = chunk_size
    fastmodel.globals.inplace = inplace

    # Move these structure-module tensors explicitly. Presumably they are not
    # registered buffers (otherwise .cuda() above would already move them) —
    # confirm.
    structure_module = fastmodel.structure_module
    structure_module.default_frames = structure_module.default_frames.cuda()
    structure_module.group_idx = structure_module.group_idx.cuda()
    structure_module.atom_mask = structure_module.atom_mask.cuda()
    structure_module.lit_positions = structure_module.lit_positions.cuda()

    set_chunk_size(chunk_size)

    # Context manager closes the data file as soon as it has been read.
    with open(get_data_path(), 'rb') as f:
        batch = pickle.load(f)
    batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}

    with torch.no_grad():
        fastout = fastmodel(batch)

    pos_dif = torch.max(torch.abs(fastout["final_atom_positions"] - out["final_atom_positions"].cuda()))
    assert pos_dif < 5e-4, f"Test failed at chunk size: {chunk_size}, inplace: {inplace}. The position dif is {pos_dif}"