test_inference.py
import copy
import os
import pickle
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp

import fastfold
from fastfold.config import model_config
from fastfold.model.fastnn import set_chunk_size
from fastfold.model.hub import AlphaFold
from fastfold.utils import inject_fastnn
from fastfold.utils.import_weights import import_jax_weights_


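# Spawn `world_size` processes and check that the FastFold (fastnn) model
# reproduces the vanilla AlphaFold output for every chunk_size/inplace setting.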
@pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.parametrize('chunk_size', [None, 2])
@pytest.mark.parametrize('inplace', [False, True])
def test_state_dict(world_size, chunk_size, inplace):
    run_func = partial(run_dist, world_size=world_size, chunk_size=chunk_size, inplace=inplace)
    mp.spawn(run_func, nprocs=world_size)


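# Per-rank entry point: set the distributed environment variables for this
# process, then run the actual inference comparison.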
def run_dist(rank, world_size, chunk_size, inplace):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    # init distributed for Dynamic Axial Parallelism
    fastfold.distributed.init_dap()
    inference(chunk_size, inplace)


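# Run the original model and the fastnn model on the same batch and compare
# their outputs.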
def inference(chunk_size, inplace):

    config = model_config('model_1')
    config.globals.chunk_size = chunk_size
    config.globals.inplace = False
    model = AlphaFold(config)
    import_jax_weights_(model, '/data/scratch/fastfold/weight.npz')
    model.eval()
    model.cuda()

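    # Build a second copy of the model and convert it to the FastFold fastnn
    # implementation with inject_fastnn().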
    fastmodel = copy.deepcopy(model)
    fastmodel = inject_fastnn(fastmodel)
    fastmodel.eval()
    fastmodel.cuda()

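    # Use the configured chunk size for the fastnn kernels, then load a pickled
    # monomer feature batch onto the GPU; keep an independent copy for the
    # fastnn model so the two forward passes do not share input tensors.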
    set_chunk_size(model.globals.chunk_size)
    with open('/data/scratch/fastfold/mono_batch.pkl', 'rb') as f:
        batch = pickle.load(f)
    batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
    fastbatch = copy.deepcopy(batch)

    with torch.no_grad():
        out = model(batch)
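        # fastmodel holds its own deep copy of the config, so the inplace flag
        # has to be set on the model itself, not only on the shared `config`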
        config.globals.inplace = inplace
        fastmodel.globals.inplace = inplace
        fastout = fastmodel(fastbatch)

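    # The fastnn path is not bit-identical to the reference model, so compare
    # the final atom positions with a loose absolute tolerance.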
    pos_dif = torch.max(torch.abs(fastout["final_atom_positions"] - out["final_atom_positions"]))
    assert pos_dif < 1.5, f"Test failed at chunk size: {chunk_size}, inplace: {inplace}. The position difference is {pos_dif}"