test_binding_affinity.py 1.94 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import shutil
import tempfile

import dgl
import torch
from dgl.data.utils import _get_dgl_url, download, extract_archive

from dgllife.model.model_zoo.acnn import ACNN
from dgllife.utils.complex_to_graph import ACNN_graph_construction_and_featurization
from dgllife.utils.rdkit_utils import load_molecule

def remove_dir(dir):
    """Recursively delete the directory ``dir`` on a best-effort basis.

    Parameters
    ----------
    dir : str
        Path to delete. Paths that are not existing directories (missing
        paths, regular files) are left untouched.

    Notes
    -----
    Any ``OSError`` raised by ``shutil.rmtree`` is swallowed, so a failed
    cleanup never fails the caller. The parameter name shadows the builtin
    ``dir``; it is kept as-is so existing callers are unaffected.
    """
    target = dir
    if not os.path.isdir(target):
        return
    try:
        shutil.rmtree(target)
    except OSError:
        # Cleanup is best-effort: tolerate leftovers rather than raising.
        return

def _load_example_complex():
    """Download the DGL-LifeSci example protein-ligand pair and load it.

    The archive is downloaded and extracted inside a fresh temporary
    directory, which is removed (best-effort) afterwards even if downloading
    or parsing fails, so no directories in the current working directory are
    created or deleted.

    Returns
    -------
    tuple
        ``(ligand_mol, ligand_coords, pocket_mol, pocket_coords)`` as returned
        by ``load_molecule`` for ``example.pdbqt`` (ligand) and
        ``example.pdb`` (pocket), both loaded with ``remove_hs=True``.
    """
    work_dir = tempfile.mkdtemp()
    try:
        archive_path = os.path.join(work_dir, 'example_mols.tar.gz')
        # Extract into a path that does not exist yet. The original test
        # cleared its extraction directory before extracting, presumably
        # because extract_archive may skip existing targets. Confirm against
        # the installed DGL.
        extract_dir = os.path.join(work_dir, 'extracted')
        download(_get_dgl_url('dgllife/example_mols.tar.gz'), path=archive_path)
        extract_archive(archive_path, extract_dir)

        mol_dir = os.path.join(extract_dir, 'example_mols')
        pocket_mol, pocket_coords = load_molecule(
            os.path.join(mol_dir, 'example.pdb'), remove_hs=True)
        ligand_mol, ligand_coords = load_molecule(
            os.path.join(mol_dir, 'example.pdbqt'), remove_hs=True)
    finally:
        remove_dir(work_dir)
    return ligand_mol, ligand_coords, pocket_mol, pocket_coords


def _assert_acnn_output_shapes(model, graph, device):
    """Check ``model`` yields one output row per graph on ``device``.

    Runs ``model`` on ``graph`` alone (expecting shape ``(1, 1)``) and on a
    batch of two copies of it (expecting shape ``(2, 1)``).
    """
    # Reassign the results of .to(): newer DGL releases return a moved copy
    # of the graph instead of moving it in place, so discarding the return
    # value would leave the graph on its original device.
    model = model.to(device)
    graph = graph.to(device)
    assert model(graph).shape == torch.Size([1, 1])

    bg = dgl.batch_hetero([graph, graph]).to(device)
    assert model(bg).shape == torch.Size([2, 1])


def test_acnn():
    """Smoke-test ACNN forward passes on the DGL-LifeSci example complex.

    Both the default ``ACNN()`` and a customised configuration must produce
    outputs of shape ``(1, 1)`` for a single graph and ``(2, 1)`` for a
    batch of two graphs. Uses ``cuda:0`` when CUDA is available, else CPU.
    """
    ligand_mol, ligand_coords, pocket_mol, pocket_coords = _load_example_complex()

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    g1 = ACNN_graph_construction_and_featurization(ligand_mol,
                                                   pocket_mol,
                                                   ligand_coords,
                                                   pocket_coords)

    _assert_acnn_output_shapes(ACNN(), g1, device)
    _assert_acnn_output_shapes(ACNN(hidden_sizes=[1, 2],
                                    weight_init_stddevs=[1, 1],
                                    dropouts=[0.1, 0.],
                                    features_to_use=torch.tensor([6., 8.]),
                                    radial=[[12.0], [0.0, 2.0], [4.0]]),
                               g1, device)


if __name__ == '__main__':
    # Allow running the test directly without pytest.
    test_acnn()