".github/vscode:/vscode.git/clone" did not exist on "e2d496e8e42b08b829f358880368228dc2b74a38"
inference-benchmark.py 2.75 KB
Newer Older
Gao, Xiang's avatar
Gao, Xiang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import argparse
import torchani
import torch
import os
import timeit


path = os.path.dirname(os.path.realpath(__file__))

# Command-line interface: a positional XYZ path and an optional device that
# defaults to 'cuda' when torch reports CUDA as available, otherwise 'cpu'.
parser = argparse.ArgumentParser()
parser.add_argument(
    'filename',
    help='Path to the xyz file.',
)
parser.add_argument(
    '-d', '--device',
    default='cuda' if torch.cuda.is_available() else 'cpu',
    help='Device of modules and tensors',
)
# NOTE: the parsed Namespace replaces the parser object under the same name;
# the script reads parser.filename and parser.device from it.
parser = parser.parse_args()

# Model under test: aev_computer -> buildins.models[0] -> energy_shifter,
# chained with nn.Sequential and moved to the selected device.
device = torch.device(parser.device)
buildins = torchani.neurochem.Buildins()
nnp = torch.nn.Sequential(
    buildins.aev_computer,
    buildins.models[0],
    buildins.energy_shifter,
)
nnp = nnp.to(device)


# load XYZ files
class XYZ:
    """Indexable collection of molecules read from a multi-frame XYZ file.

    Each frame is an atom-count line, one comment line (skipped), and then
    one ``symbol x y z`` line per atom. Frames are stored in ``self.mols``
    as ``(species, coordinates)`` tuples placed on the module-level
    ``device``: ``species`` is ``buildins.consts.species_to_tensor`` applied
    to the frame's symbols, and ``coordinates`` is a ``torch.tensor`` built
    from the per-atom ``[x, y, z]`` floats, i.e. shape ``(atoms, 3)``.
    """

    def __init__(self, filename):
        """Parse every frame of ``filename`` into ``self.mols``.

        Blank lines between frames are ignored. Only the first four tokens
        of each atom line are read, so extra trailing columns are tolerated.

        Raises:
            ValueError: if an atom-count line is not a positive integer, an
                atom line has fewer than four fields or non-numeric
                coordinates, or the file ends in the middle of a frame.
        """
        self.mols = []
        atom_count = None
        species = []
        coordinates = []
        state = 'ready'
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if state == 'ready':
                    if not line:
                        # Skip separator/trailing blank lines instead of
                        # failing on int('').
                        continue
                    atom_count = int(line)
                    if atom_count <= 0:
                        # A non-positive count would never hit the frame
                        # terminator below and would swallow the rest of
                        # the file as atom lines.
                        raise ValueError(
                            'invalid atom count {!r} in {}'.format(
                                line, filename))
                    state = 'comment'
                elif state == 'comment':
                    # The comment line carries nothing we need.
                    state = 'atoms'
                else:
                    # Extended XYZ variants may append more columns after
                    # the coordinates; only symbol + xyz are used here.
                    s, x, y, z = line.split()[:4]
                    species.append(s)
                    coordinates.append([float(x), float(y), float(z)])
                    atom_count -= 1
                    if atom_count == 0:
                        self.mols.append((
                            buildins.consts.species_to_tensor(species)
                            .to(device),
                            torch.tensor(coordinates, device=device),
                        ))
                        species = []
                        coordinates = []
                        state = 'ready'
        if state != 'ready':
            # Report a truncated last frame instead of silently dropping it.
            raise ValueError(
                '{} ends in the middle of a frame'.format(filename))

    def __len__(self):
        """Return the number of parsed frames."""
        return len(self.mols)

    def __getitem__(self, i):
        """Return the ``(species, coordinates)`` tuple of frame ``i``."""
        return self.mols[i]


# Parse the conformations named on the command line and report how many
# were found, followed by an empty line.
xyz = XYZ(parser.filename)
print(len(xyz), 'conformations', end='\n\n')

# test batch mode: one forward pass over all conformations stacked together,
# then one backward pass for the forces.
print('[Batch mode]')
# default_collate stacks the per-molecule tensors; this presumably requires
# every conformation to have the same atom count — TODO confirm.
species, coordinates = torch.utils.data.dataloader.default_collate(list(xyz))
# Fresh leaf copy that tracks gradients. torch.tensor() on an existing tensor
# emits a copy-construct UserWarning; clone().detach() is the recommended form.
coordinates = coordinates.clone().detach().requires_grad_(True)
# CUDA kernels run asynchronously, so synchronize before every timer read;
# otherwise the intervals only measure kernel launches.
on_cuda = device.type == 'cuda'
if on_cuda:
    torch.cuda.synchronize(device)
start = timeit.default_timer()
energies = nnp((species, coordinates))[1]
if on_cuda:
    torch.cuda.synchronize(device)
mid = timeit.default_timer()
print('Energy time:', mid - start)
force = -torch.autograd.grad(energies.sum(), coordinates)[0]
if on_cuda:
    torch.cuda.synchronize(device)
print('Force time:', timeit.default_timer() - mid)
print()

# test single mode: energy + force for each conformation separately, timing
# the whole loop.
print('[Single mode]')
# CUDA kernels run asynchronously; synchronize around the timed region so
# the reported time covers the actual computation.
on_cuda = device.type == 'cuda'
if on_cuda:
    torch.cuda.synchronize(device)
start = timeit.default_timer()
for species, coordinates in xyz:
    # Add a leading batch dimension of size 1.
    species = species.unsqueeze(0)
    # Fresh leaf copy that tracks gradients (torch.tensor() on an existing
    # tensor emits a copy-construct UserWarning).
    coordinates = coordinates.unsqueeze(0).clone().detach().requires_grad_(True)
    energies = nnp((species, coordinates))[1]
    force = -torch.autograd.grad(energies.sum(), coordinates)[0]
if on_cuda:
    torch.cuda.synchronize(device)
print('Time:', timeit.default_timer() - start)