comp6.py 3.84 KB
Newer Older
Gao, Xiang's avatar
Gao, Xiang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import os
import torch
import torchani
from torchani.data._pyanitools import anidataloader
import argparse
import math
import tqdm


# Multiplier applied to every error statistic before it is printed; the name
# indicates a Hartree -> kcal/mol conversion.
HARTREE2KCAL = 627.509

# Command line options.  NOTE: ``parser`` is rebound to the parsed namespace
# on the last line of this section, so the rest of the script reads the
# options as ``parser.dir``, ``parser.batchatoms`` and ``parser.device``.
parser = argparse.ArgumentParser()
parser.add_argument('dir', help='Path to the COMP6 directory')
parser.add_argument(
    '-b', '--batchatoms',
    default=4096,
    type=int,
    help='Maximum number of ATOMs in each batch',
)
parser.add_argument(
    '-d', '--device',
    # prefer the GPU whenever torch reports that CUDA is usable
    default='cuda' if torch.cuda.is_available() else 'cpu',
    help='Device of modules and tensors',
)
parser = parser.parse_args()

# Model under test, moved to the selected device.
ani1x = torchani.models.ANI1x()
ani1x = ani1x.to(parser.device)
Gao, Xiang's avatar
Gao, Xiang committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43


def recursive_h5_files(base):
    """Yield every item of every ``.h5`` file found below ``base``.

    The directory tree is walked depth-first in ``os.listdir`` order.  Each
    regular file whose path ends in ``".h5"`` is opened with
    ``anidataloader`` and the items produced by iterating that loader are
    yielded one by one; all other files are ignored.

    Args:
        base: path of the directory to search.
    """
    # NOTE(review): the loader object is never explicitly closed here --
    # confirm whether anidataloader needs a cleanup call.
    for entry in os.listdir(base):
        candidate = os.path.join(base, entry)
        if os.path.isdir(candidate):
            # descend into sub-directories
            yield from recursive_h5_files(candidate)
            continue
        if os.path.isfile(candidate) and candidate.endswith(".h5"):
            yield from anidataloader(candidate)


def by_batch(species, coordinates, model):
    """Run ``model`` over the conformations in chunks and collect results.

    The inputs are split along their first dimension so that each chunk
    holds at most ``parser.batchatoms`` atoms, counting ``species.shape[1]``
    atoms per conformation (but never fewer than one conformation per
    chunk).

    Args:
        species: 2-D tensor with one row per conformation.
        coordinates: tensor of coordinates with the same first dimension as
            ``species``.  It is cloned, so the caller's tensor is not
            modified.
        model: callable taking a ``(species, coordinates)`` tuple and
            returning a 2-tuple whose second element holds the energies.

    Returns:
        A tuple ``(energies, forces)`` of detached tensors concatenated over
        all chunks.  ``forces`` is the gradient of the summed energies with
        respect to the coordinates, exactly as returned by
        ``torch.autograd.grad`` (no sign change is applied).
    """
    atoms_per_conformation = species.shape[1]
    chunk_size = max(1, parser.batchatoms // atoms_per_conformation)

    # work on a private copy that tracks gradients
    tracked = coordinates.clone().detach().requires_grad_(True)
    chunks = zip(torch.split(species, chunk_size),
                 torch.split(tracked, chunk_size))

    energy_parts = []
    gradient_parts = []
    for species_chunk, coordinates_chunk in chunks:
        _, energy = model((species_chunk, coordinates_chunk))
        (gradient,) = torch.autograd.grad(energy.sum(), coordinates_chunk)
        energy_parts.append(energy)
        gradient_parts.append(gradient)

    all_energies = torch.cat(energy_parts).detach()
    all_gradients = torch.cat(gradient_parts).detach()
    return all_energies, all_gradients


class Averager:
    """Running mean over values supplied in one-dimensional batches.

    ``count`` is the number of values added so far and ``cumsum`` is their
    sum, accumulated as a Python number.
    """

    def __init__(self):
        self.count = 0
        self.cumsum = 0

    def update(self, new):
        """Add every element of the 1-D tensor ``new`` to the running mean.

        Raises:
            ValueError: if ``new`` is not one-dimensional.  The element
                count is taken from ``new.shape[0]`` while the sum covers
                all elements, so any other rank would skew the mean.  This
                is an explicit check rather than an ``assert`` so that it
                still runs under ``python -O``.
        """
        if len(new.shape) != 1:
            raise ValueError(
                "Averager.update expects a 1-D tensor, got shape "
                "{}".format(tuple(new.shape)))
        self.count += new.shape[0]
        self.cumsum += new.sum().item()

    def compute(self):
        """Return the mean of all values added so far.

        Returns ``float('nan')`` when nothing has been added yet, instead of
        failing with ``ZeroDivisionError``, so that one statistic without
        samples does not prevent the others from being reported.
        """
        if self.count == 0:
            return float('nan')
        return self.cumsum / self.count


def relative_energies(energies):
    """Return the differences between all unordered pairs of ``energies``.

    Pairs are enumerated by ``torch.combinations(energies, r=2)``; for each
    pair the second member is subtracted from the first.

    Args:
        energies: 1-D tensor of energies.

    Returns:
        1-D tensor with one difference per pair.
    """
    pairs = torch.combinations(energies, r=2)
    first = pairs[:, 0]
    second = pairs[:, 1]
    return first - second


def do_benchmark(model):
    """Evaluate ``model`` on every dataset entry found under ``parser.dir``.

    Each entry yielded by ``recursive_h5_files`` supplies ``'coordinates'``,
    ``'species'``, ``'energies'`` and ``'forces'``.  The predictions obtained
    through ``by_batch`` are compared with those reference values, and the
    errors of three quantities are accumulated over all entries: energies,
    pairwise energy differences within one entry, and flattened force
    components.

    For each quantity the mean absolute error and the root mean squared
    error are printed, both multiplied by ``HARTREE2KCAL``.

    Args:
        model: callable accepted by ``by_batch`` that also provides
            ``species_to_tensor`` for encoding an entry's species.
    """
    quantities = ('energy', 'relative_energy', 'force')
    abs_errors = {name: Averager() for name in quantities}
    sq_errors = {name: Averager() for name in quantities}

    entries = recursive_h5_files(parser.dir)
    for entry in tqdm.tqdm(entries, position=0, desc="dataset"):
        # reference data of this entry
        coordinates = torch.tensor(entry['coordinates'],
                                   device=parser.device)
        conformations = coordinates.shape[0]
        # the single species row is expanded to one row per conformation
        species = model.species_to_tensor(entry['species'])
        species = species.unsqueeze(0).expand(conformations, -1)
        ref_energies = torch.tensor(entry['energies'], device=parser.device)
        ref_forces = torch.tensor(entry['forces'], device=parser.device)

        # model predictions and their deviation from the reference
        pred_energies, pred_forces = by_batch(species, coordinates, model)
        deviations = {
            'energy': ref_energies - pred_energies,
            'relative_energy': (relative_energies(ref_energies)
                                - relative_energies(pred_energies)),
            'force': ref_forces.flatten() - pred_forces.flatten(),
        }
        for name, deviation in deviations.items():
            abs_errors[name].update(deviation.abs())
            sq_errors[name].update(deviation ** 2)

    def mae(name):
        return abs_errors[name].compute() * HARTREE2KCAL

    def rmse(name):
        return math.sqrt(sq_errors[name].compute()) * HARTREE2KCAL

    # compute every statistic before printing anything
    labelled = (
        ("Energy:", 'energy'),
        ("Relative Energy:", 'relative_energy'),
        ("Forces:", 'force'),
    )
    report = [(label, mae(name), rmse(name)) for label, name in labelled]
    for label, mean_abs, root_mean_sq in report:
        print(label, mean_abs, root_mean_sq)


# Script entry point: benchmark the module-level ``ani1x`` model.
do_benchmark(ani1x)