Unverified Commit 0e992fe5 authored by Gao, Xiang, committed by GitHub

use pytorch's CELU to simplify code (#42)

parent 97e3df07
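For reference, CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)), which for alpha > 0 is exactly the torch.where expression this commit deletes. A quick sanity check, not part of the commit, that PyTorch's built-ins agree with the hand-written version:

import torch

def celu(x, alpha):
    # the hand-written CELU that this commit removes
    return torch.where(x > 0, x, alpha * (torch.exp(x / alpha) - 1))

x = torch.randn(1000)
assert torch.allclose(celu(x, 0.1), torch.nn.CELU(0.1)(x))  # module form
assert torch.allclose(celu(x, 0.1), torch.celu(x, 0.1))     # functional form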
@@ -3,30 +3,18 @@ import torchani
 import os
 
 
-def celu(x, alpha):
-    return torch.where(x > 0, x, alpha * (torch.exp(x/alpha)-1))
-
-
-class AtomicNetwork(torch.nn.Module):
-
-    def __init__(self):
-        super(AtomicNetwork, self).__init__()
-        self.output_length = 1
-        self.layer1 = torch.nn.Linear(384, 128)
-        self.layer2 = torch.nn.Linear(128, 128)
-        self.layer3 = torch.nn.Linear(128, 64)
-        self.layer4 = torch.nn.Linear(64, 1)
-
-    def forward(self, aev):
-        y = aev
-        y = self.layer1(y)
-        y = celu(y, 0.1)
-        y = self.layer2(y)
-        y = celu(y, 0.1)
-        y = self.layer3(y)
-        y = celu(y, 0.1)
-        y = self.layer4(y)
-        return y
+def atomic():
+    model = torch.nn.Sequential(
+        torch.nn.Linear(384, 128),
+        torch.nn.CELU(0.1),
+        torch.nn.Linear(128, 128),
+        torch.nn.CELU(0.1),
+        torch.nn.Linear(128, 64),
+        torch.nn.CELU(0.1),
+        torch.nn.Linear(64, 1)
+    )
+    model.output_length = 1
+    return model
 
 
 def get_or_create_model(filename, benchmark=False,
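The replacement keeps the same architecture (384 → 128 → 128 → 64 → 1, with CELU(0.1) after each hidden layer) but builds it as a plain Sequential instead of a hand-rolled Module. A smoke test of the new factory, not part of the commit; the batch size and the fake 384-wide AEV input are illustrative assumptions:

import torch

def atomic():  # copied from the hunk above
    model = torch.nn.Sequential(
        torch.nn.Linear(384, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 64),
        torch.nn.CELU(0.1),
        torch.nn.Linear(64, 1)
    )
    model.output_length = 1  # attribute preserved from the old AtomicNetwork
    return model

net = atomic()
aev = torch.randn(8, 384)        # 8 fake atomic environment vectors
assert net(aev).shape == (8, 1)  # one scalar output per atom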
@@ -37,10 +25,10 @@ def get_or_create_model(filename, benchmark=False,
         reducer=torch.sum,
         benchmark=benchmark,
         per_species={
-            'C': AtomicNetwork(),
-            'H': AtomicNetwork(),
-            'N': AtomicNetwork(),
-            'O': AtomicNetwork(),
+            'C': atomic(),
+            'H': atomic(),
+            'N': atomic(),
+            'O': atomic(),
         })
 
 
 class Flatten(torch.nn.Module):
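Note that each species entry calls atomic() separately, so the C, H, N, and O networks are four distinct instances with independent weights; reusing a single instance across the dict would silently tie the parameters of all species together.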
@@ -99,7 +99,7 @@ class TestBenchmark(unittest.TestCase):
             dtype=self.dtype, device=self.device)
         prepare = torchani.PrepareInput(aev_computer.species, self.device)
         model = torchani.models.NeuroChemNNP(
-            aev_computer.species, benchmark=True)
+            aev_computer.species, benchmark=True).to(self.device)
         run_module = torch.nn.Sequential(prepare, aev_computer, model)
         self._testModule(run_module, model, ['forward'])
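The test fix chains .to(self.device) onto the freshly constructed model: nn.Module.to() moves every registered parameter and buffer, which must end up on the same device as the tensors the rest of the pipeline produces. A generic illustration of the failure mode being fixed; the module and shapes here are made up, not from the test suite:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.nn.Linear(4, 2)          # parameters are created on the CPU by default
model = model.to(device)               # move weights and biases to the target device
x = torch.randn(3, 4, device=device)   # inputs live on the same device
y = model(x)                           # without .to(device) this raises a device mismatch error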
@@ -187,8 +187,7 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
             self.activation = lambda x: torch.exp(-x*x)
         elif activation == 9:  # CELU
             alpha = 0.1
-            self.activation = lambda x: torch.where(
-                x > 0, x, alpha * (torch.exp(x/alpha)-1))
+            self.activation = lambda x: torch.celu(x, alpha)
         else:
             raise NotImplementedError(
                 'Unexpected activation {}'.format(activation))
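This parser maps NeuroChem's integer activation codes to plain callables, so the CELU branch can swap in torch.celu without touching the dispatch. A condensed sketch of that pattern; only code 9 (CELU) is confirmed by the hunk above, the Gaussian's code number is a hypothetical placeholder:

import torch

def make_activation(code):
    if code == 5:      # hypothetical code for the Gaussian activation
        return lambda x: torch.exp(-x*x)
    elif code == 9:    # CELU, as in the hunk above
        alpha = 0.1
        return lambda x: torch.celu(x, alpha)
    else:
        raise NotImplementedError('Unexpected activation {}'.format(code))

act = make_activation(9)
print(act(torch.tensor([-1.0, 0.0, 1.0])))  # tensor([-0.1000, 0.0000, 1.0000])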