comp6.py 4.08 KB
Newer Older
Gao, Xiang's avatar
Gao, Xiang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import torch
import torchani
from torchani.data._pyanitools import anidataloader
import argparse
import math
import tqdm


HARTREE2KCAL = 627.509
dtype = torch.float32

# Command-line interface. The parsed namespace is deliberately bound back to
# the name ``parser``: the functions below read ``parser.dir``,
# ``parser.batchatoms`` and ``parser.device`` as module-level globals.
_default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
_cli = argparse.ArgumentParser()
_cli.add_argument('dir', help='Path to the COMP6 directory')
_cli.add_argument('-b', '--batchatoms', type=int, default=512,
                  help='Maximum number of ATOMs in each batch')
_cli.add_argument('-d', '--device', default=_default_device,
                  help='Device of modules and tensors')
parser = _cli.parse_args()

# Build the ANI-1x model in the benchmark dtype on the requested device.
ani1x = torchani.models.ANI1x().to(dtype).to(parser.device)


def recursive_h5_files(base):
    """Yield every record from every ``.h5`` file found under ``base``.

    The directory tree is walked depth-first. Entries are visited in sorted
    name order, because ``os.listdir`` guarantees no particular order and
    repeated runs should see the files in the same sequence.

    Each ``.h5`` file is opened with ``anidataloader`` and its records are
    yielded one by one. The loader is closed once its file is exhausted, or
    when the generator is abandoned. This keeps open HDF5 handles from
    accumulating across a large tree.

    Args:
        base: path of the directory to search.

    Yields:
        The records produced by iterating each ``anidataloader``.
    """
    for name in sorted(os.listdir(base)):
        path = os.path.join(base, name)
        if os.path.isfile(path) and path.endswith(".h5"):
            loader = anidataloader(path)
            try:
                yield from loader
            finally:
                # pyanitools' loader keeps its HDF5 store open until
                # cleanup() is called.
                loader.cleanup()
        elif os.path.isdir(path):
            # A directory whose name ends in ".h5" is searched, not loaded.
            yield from recursive_h5_files(path)


def by_batch(species, coordinates, model, batchatoms=None):
    """Predict energies and forces for a stack of conformations in chunks.

    Conformations are split along dim 0 into chunks of
    ``max(1, batchatoms // species.shape[1])`` conformations. Each model
    call therefore sees roughly ``batchatoms`` atoms.

    Args:
        species: species tensor; ``species.shape[1]`` is used as the atom
            count (presumably shaped ``(conformations, atoms)``).
        coordinates: coordinates tensor split alongside ``species``
            (presumably ``(conformations, atoms, 3)``).
        model: callable taking ``(species, coordinates)`` and returning a
            pair whose second element holds the energies.
        batchatoms: atom budget per chunk. Defaults to the command-line
            ``parser.batchatoms``.

    Returns:
        ``(energies, forces)`` concatenated over all chunks and detached from
        the autograd graph. Forces are the negative gradient of the energy
        with respect to the coordinates.
    """
    if batchatoms is None:
        batchatoms = parser.batchatoms
    shape = species.shape
    batchsize = max(1, batchatoms // shape[1])
    # Take gradients w.r.t. a private leaf copy, never the caller's tensor.
    coordinates = coordinates.clone().detach().requires_grad_(True)
    species = torch.split(species, batchsize)
    coordinates = torch.split(coordinates, batchsize)
    energies = []
    forces = []
    for s, c in tqdm.tqdm(zip(species, coordinates), total=len(species),
                          position=1, desc="batch of {}x{}".format(*shape)):
        _, e = model((s, c))
        grad, = torch.autograd.grad(e.sum(), c)
        # Detach now so the finished batch's graph is not kept alive.
        energies.append(e.detach())
        # Force is the negative energy gradient, F = -dE/dx.
        # NOTE(review): assumes the dataset's 'forces' are true forces
        # (not gradients), as in torchani's force-training example.
        forces.append(-grad)
    return torch.cat(energies), torch.cat(forces)


class Averager:
    """Running mean over the elements of a stream of 1-D tensors.

    ``update`` adds each tensor's element count and element sum.
    ``compute`` returns the mean of everything seen so far.
    """

    def __init__(self):
        self.count = 0      # number of elements accumulated so far
        self.cumsum = 0.0   # sum of those elements, as a Python float

    def update(self, new):
        """Accumulate the elements of the one-dimensional tensor ``new``.

        Raises:
            ValueError: if ``new`` is not one-dimensional.
        """
        # Explicit check rather than ``assert``: asserts vanish under -O.
        if len(new.shape) != 1:
            raise ValueError(
                'Averager.update expects a 1-D tensor, got shape {}'
                .format(tuple(new.shape)))
        self.count += new.shape[0]
        self.cumsum += new.sum().item()

    def compute(self):
        """Return the mean of all accumulated elements.

        Returns NaN when nothing has been accumulated, instead of raising
        ZeroDivisionError. This can happen for pairwise statistics when every
        input held a single value.
        """
        if self.count == 0:
            return float('nan')
        return self.cumsum / self.count


def relative_energies(energies):
    """Return every pairwise difference ``energies[i] - energies[j]``, i < j.

    Pairs are enumerated by ``torch.combinations`` with ``r=2``. The result
    is one-dimensional with one entry per pair. A single input energy
    yields an empty tensor.
    """
    pairs = torch.combinations(energies, r=2)  # one row per (i, j) pair
    return pairs[:, 0] - pairs[:, 1]


def do_benchmark(model):
    """Benchmark ``model`` against every record under ``parser.dir``.

    For each record, the reference coordinates, energies and forces are
    loaded and the model's predictions come from ``by_batch``. Three error
    streams are accumulated:

    * ``energy``: reference minus predicted energies;
    * ``relative_energy``: the same for all pairwise energy differences
      within the record (via ``relative_energies``);
    * ``force``: reference minus predicted forces, element by element.

    Each stream feeds one ``Averager`` of absolute errors (MAE) and one of
    squared errors (RMSE). The results are scaled by ``HARTREE2KCAL`` and
    printed as ``<label> <MAE> <RMSE>``.

    Args:
        model: the model under test. It must provide ``species_to_tensor``
            and be callable the way ``by_batch`` calls it.
    """
    metrics = ('energy', 'relative_energy', 'force')
    abs_error = {name: Averager() for name in metrics}
    sq_error = {name: Averager() for name in metrics}

    def to_tensor(array):
        # Numeric dataset arrays all use the benchmark dtype and device.
        return torch.tensor(array, dtype=dtype, device=parser.device)

    records = recursive_h5_files(parser.dir)
    for record in tqdm.tqdm(records, position=0, desc="dataset"):
        coordinates = to_tensor(record['coordinates'])
        n_conformations = coordinates.shape[0]
        # One species row, expanded to one row per conformation.
        species = model.species_to_tensor(record['species']) \
                       .unsqueeze(0).expand(n_conformations, -1)
        ref_energies = to_tensor(record['energies'])
        ref_forces = to_tensor(record['forces'])

        pred_energies, pred_forces = by_batch(species, coordinates, model)

        errors = {
            'energy': ref_energies - pred_energies,
            'relative_energy': (relative_energies(ref_energies)
                                - relative_energies(pred_energies)),
            'force': ref_forces.flatten() - pred_forces.flatten(),
        }
        for name, err in errors.items():
            abs_error[name].update(err.abs())
            sq_error[name].update(err ** 2)

    # Every metric is computed before anything is printed.
    summary = {
        name: (abs_error[name].compute() * HARTREE2KCAL,
               math.sqrt(sq_error[name].compute()) * HARTREE2KCAL)
        for name in metrics
    }
    for label, name in (("Energy:", 'energy'),
                        ("Relative Energy:", 'relative_energy'),
                        ("Forces:", 'force')):
        print(label, *summary[name])


# Only run the (long) benchmark when executed as a script, not on import.
if __name__ == '__main__':
    do_benchmark(ani1x)