"tests/vscode:/vscode.git/clone" did not exist on "1db4ad4fcc76c2ad87ee9066ad8f7e4ccf4a7290"
Unverified Commit 0e992fe5 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

use pytorch's CELU to simplify code (#42)

parent 97e3df07
...@@ -3,30 +3,18 @@ import torchani ...@@ -3,30 +3,18 @@ import torchani
import os import os
def celu(x, alpha): def atomic():
return torch.where(x > 0, x, alpha * (torch.exp(x/alpha)-1)) model = torch.nn.Sequential(
torch.nn.Linear(384, 128),
torch.nn.CELU(0.1),
class AtomicNetwork(torch.nn.Module): torch.nn.Linear(128, 128),
torch.nn.CELU(0.1),
def __init__(self): torch.nn.Linear(128, 64),
super(AtomicNetwork, self).__init__() torch.nn.CELU(0.1),
self.output_length = 1 torch.nn.Linear(64, 1)
self.layer1 = torch.nn.Linear(384, 128) )
self.layer2 = torch.nn.Linear(128, 128) model.output_length = 1
self.layer3 = torch.nn.Linear(128, 64) return model
self.layer4 = torch.nn.Linear(64, 1)
def forward(self, aev):
y = aev
y = self.layer1(y)
y = celu(y, 0.1)
y = self.layer2(y)
y = celu(y, 0.1)
y = self.layer3(y)
y = celu(y, 0.1)
y = self.layer4(y)
return y
def get_or_create_model(filename, benchmark=False, def get_or_create_model(filename, benchmark=False,
...@@ -37,10 +25,10 @@ def get_or_create_model(filename, benchmark=False, ...@@ -37,10 +25,10 @@ def get_or_create_model(filename, benchmark=False,
reducer=torch.sum, reducer=torch.sum,
benchmark=benchmark, benchmark=benchmark,
per_species={ per_species={
'C': AtomicNetwork(), 'C': atomic(),
'H': AtomicNetwork(), 'H': atomic(),
'N': AtomicNetwork(), 'N': atomic(),
'O': AtomicNetwork(), 'O': atomic(),
}) })
class Flatten(torch.nn.Module): class Flatten(torch.nn.Module):
......
...@@ -99,7 +99,7 @@ class TestBenchmark(unittest.TestCase): ...@@ -99,7 +99,7 @@ class TestBenchmark(unittest.TestCase):
dtype=self.dtype, device=self.device) dtype=self.dtype, device=self.device)
prepare = torchani.PrepareInput(aev_computer.species, self.device) prepare = torchani.PrepareInput(aev_computer.species, self.device)
model = torchani.models.NeuroChemNNP( model = torchani.models.NeuroChemNNP(
aev_computer.species, benchmark=True) aev_computer.species, benchmark=True).to(self.device)
run_module = torch.nn.Sequential(prepare, aev_computer, model) run_module = torch.nn.Sequential(prepare, aev_computer, model)
self._testModule(run_module, model, ['forward']) self._testModule(run_module, model, ['forward'])
......
...@@ -187,8 +187,7 @@ class NeuroChemAtomicNetwork(torch.nn.Module): ...@@ -187,8 +187,7 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
self.activation = lambda x: torch.exp(-x*x) self.activation = lambda x: torch.exp(-x*x)
elif activation == 9: # CELU elif activation == 9: # CELU
alpha = 0.1 alpha = 0.1
self.activation = lambda x: torch.where( self.activation = lambda x: torch.celu(x, alpha)
x > 0, x, alpha * (torch.exp(x/alpha)-1))
else: else:
raise NotImplementedError( raise NotImplementedError(
'Unexpected activation {}'.format(activation)) 'Unexpected activation {}'.format(activation))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment