common.py 2.21 KB
Newer Older
Xiang Gao's avatar
Xiang Gao committed
1
2
3
import torchani
import torch
import os
4
import configs
Xiang Gao's avatar
Xiang Gao committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26


class Averager:
    """Running accumulator of an item count and a subtotal.

    ``add`` folds in another (count, subtotal) pair and ``avg`` reports the
    accumulated subtotal divided by the accumulated count. Both start at the
    integer 0, so calling ``avg`` before any ``add`` raises
    ``ZeroDivisionError``.
    """

    def __init__(self):
        self.count, self.subtotal = 0, 0

    def add(self, count, subtotal):
        """Accumulate ``count`` items whose values sum to ``subtotal``."""
        self.count += count
        self.subtotal += subtotal

    def avg(self):
        """Return ``subtotal / count`` over everything added so far."""
        total, n = self.subtotal, self.count
        return total / n


def celu(x, alpha):
    """Continuously differentiable ELU activation, applied elementwise.

    Returns ``x`` where ``x > 0`` and ``alpha * (exp(x / alpha) - 1)``
    elsewhere.

    ``torch.where`` evaluates and back-propagates through BOTH branches, so
    the negative branch is computed on ``min(x, 0)``. Otherwise, for large
    positive ``x``, ``exp(x / alpha)`` overflows to ``inf``. Its gradient is
    masked out by ``where`` with a zero, and ``0 * inf`` yields ``nan``
    gradients. ``expm1`` is also more accurate than ``exp(...) - 1`` when
    ``x / alpha`` is close to zero.
    """
    negative_part = alpha * torch.expm1(torch.clamp(x, max=0) / alpha)
    return torch.where(x > 0, x, negative_part)


class AtomicNetwork(torch.nn.Module):
    """Per-atom fully connected network producing one output per input row.

    Layer sizes are 384 -> 128 -> 128 -> 64 -> 1, with ``celu`` (alpha=0.1)
    after every layer except the last. All parameters are cast to
    ``aev_computer.dtype`` and moved to ``aev_computer.device``.
    """

    def __init__(self, aev_computer):
        super(AtomicNetwork, self).__init__()
        self.aev_computer = aev_computer
        self.output_length = 1

        def dense(in_features, out_features):
            # Linear layer matching the AEV computer's dtype and device.
            linear = torch.nn.Linear(in_features, out_features)
            return linear.type(aev_computer.dtype).to(aev_computer.device)

        # The names layer1..layer4 determine the state_dict keys, which must
        # match checkpoints written by torch.save(model.state_dict(), ...).
        self.layer1 = dense(384, 128)
        self.layer2 = dense(128, 128)
        self.layer3 = dense(128, 64)
        self.layer4 = dense(64, self.output_length)

    def forward(self, aev):
        """Map ``aev`` (last dimension 384) to a last dimension of size 1."""
        hidden = aev
        for linear in (self.layer1, self.layer2, self.layer3):
            hidden = celu(linear(hidden), 0.1)
        return self.layer4(hidden)


52
53
def get_or_create_model(filename, benchmark=False, device=configs.device):
    """Build the C/H/N/O model and load or persist its weights at ``filename``.

    A ``torchani.SortedAEV`` on ``device`` feeds one ``AtomicNetwork`` per
    species (C, H, N, O), and per-atom outputs are combined with
    ``torch.sum``. If ``filename`` exists, its state_dict is loaded into the
    model. Otherwise the freshly initialised weights are saved there, so
    later calls start from the same parameters.

    Args:
        filename: path of the state_dict checkpoint to read or create.
        benchmark: forwarded to both ``SortedAEV`` and ``ModelOnAEV``.
        device: device for the AEV computer, and therefore for the networks.
            Also used as ``map_location`` when loading the checkpoint.

    Returns:
        The ``torchani.ModelOnAEV`` instance.
    """
    aev_computer = torchani.SortedAEV(benchmark=benchmark, device=device)
    # Built in C, H, N, O order, matching the original literal, so parameter
    # initialisation consumes the RNG in the same sequence.
    per_species = {
        species: AtomicNetwork(aev_computer)
        for species in ('C', 'H', 'N', 'O')
    }
    model = torchani.ModelOnAEV(
        aev_computer,
        reducer=torch.sum,
        benchmark=benchmark,
        per_species=per_species)
    if os.path.isfile(filename):
        # map_location lets a checkpoint written on another device (e.g. a
        # CUDA run) be restored here. load_state_dict then copies the tensors
        # into the model's existing parameters.
        model.load_state_dict(torch.load(filename, map_location=device))
    else:
        torch.save(model.state_dict(), filename)
    return model


# Module-level torchani.EnergyShifter shared by evaluate(), which calls its
# add_sae(pred, species) on the squeezed model output before computing the
# loss. Presumably this adds per-species self atomic energies (SAE) back onto
# the predictions; confirm against torchani.
energy_shifter = torchani.EnergyShifter()

# Squared-error loss summed over all elements rather than averaged.
# reduction='sum' is the supported spelling of the deprecated
# size_average=False.
loss = torch.nn.MSELoss(reduction='sum')


def evaluate(model, coordinates, energies, species):
    """Score ``model`` on one batch.

    Args:
        model: called as ``model(coordinates, species)``. Its output is
            squeezed and passed through ``energy_shifter.add_sae``.
        coordinates: tensor whose first dimension is the batch size.
        energies: reference values compared against the shifted predictions.
        species: forwarded to both the model and ``add_sae``.

    Returns:
        ``(count, squared_error)``. ``count`` is ``coordinates.shape[0]``.
        ``squared_error`` is the module-level ``loss`` (MSE with sum
        reduction) of the shifted predictions against ``energies``.
    """
    batch_size = coordinates.shape[0]
    raw_prediction = model(coordinates, species).squeeze()
    shifted = energy_shifter.add_sae(raw_prediction, species)
    return batch_size, loss(shifted, energies)