Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
0e992fe5
Unverified
Commit
0e992fe5
authored
Aug 01, 2018
by
Gao, Xiang
Committed by
GitHub
Aug 01, 2018
Browse files
use pytorch's CELU to simplify code (#42)
parent
97e3df07
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
31 deletions
+18
-31
examples/model.py
examples/model.py
+16
-28
tests/test_benchmark.py
tests/test_benchmark.py
+1
-1
torchani/models/neurochem_atomic_network.py
torchani/models/neurochem_atomic_network.py
+1
-2
No files found.
examples/model.py
View file @
0e992fe5
...
@@ -3,30 +3,18 @@ import torchani
...
@@ -3,30 +3,18 @@ import torchani
import
os
import
os
def
celu
(
x
,
alpha
):
def
atomic
():
return
torch
.
where
(
x
>
0
,
x
,
alpha
*
(
torch
.
exp
(
x
/
alpha
)
-
1
))
model
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Linear
(
384
,
128
),
torch
.
nn
.
CELU
(
0.1
),
class
AtomicNetwork
(
torch
.
nn
.
Module
):
torch
.
nn
.
Linear
(
128
,
128
),
torch
.
nn
.
CELU
(
0.1
),
def
__init__
(
self
):
torch
.
nn
.
Linear
(
128
,
64
),
super
(
AtomicNetwork
,
self
).
__init__
()
torch
.
nn
.
CELU
(
0.1
),
self
.
output_length
=
1
torch
.
nn
.
Linear
(
64
,
1
)
self
.
layer1
=
torch
.
nn
.
Linear
(
384
,
128
)
)
self
.
layer2
=
torch
.
nn
.
Linear
(
128
,
128
)
model
.
output_length
=
1
self
.
layer3
=
torch
.
nn
.
Linear
(
128
,
64
)
return
model
self
.
layer4
=
torch
.
nn
.
Linear
(
64
,
1
)
def
forward
(
self
,
aev
):
y
=
aev
y
=
self
.
layer1
(
y
)
y
=
celu
(
y
,
0.1
)
y
=
self
.
layer2
(
y
)
y
=
celu
(
y
,
0.1
)
y
=
self
.
layer3
(
y
)
y
=
celu
(
y
,
0.1
)
y
=
self
.
layer4
(
y
)
return
y
def
get_or_create_model
(
filename
,
benchmark
=
False
,
def
get_or_create_model
(
filename
,
benchmark
=
False
,
...
@@ -37,10 +25,10 @@ def get_or_create_model(filename, benchmark=False,
...
@@ -37,10 +25,10 @@ def get_or_create_model(filename, benchmark=False,
reducer
=
torch
.
sum
,
reducer
=
torch
.
sum
,
benchmark
=
benchmark
,
benchmark
=
benchmark
,
per_species
=
{
per_species
=
{
'C'
:
A
tomic
Network
(),
'C'
:
a
tomic
(),
'H'
:
A
tomic
Network
(),
'H'
:
a
tomic
(),
'N'
:
A
tomic
Network
(),
'N'
:
a
tomic
(),
'O'
:
A
tomic
Network
(),
'O'
:
a
tomic
(),
})
})
class
Flatten
(
torch
.
nn
.
Module
):
class
Flatten
(
torch
.
nn
.
Module
):
...
...
tests/test_benchmark.py
View file @
0e992fe5
...
@@ -99,7 +99,7 @@ class TestBenchmark(unittest.TestCase):
...
@@ -99,7 +99,7 @@ class TestBenchmark(unittest.TestCase):
dtype
=
self
.
dtype
,
device
=
self
.
device
)
dtype
=
self
.
dtype
,
device
=
self
.
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
self
.
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
self
.
device
)
model
=
torchani
.
models
.
NeuroChemNNP
(
model
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
,
benchmark
=
True
)
aev_computer
.
species
,
benchmark
=
True
)
.
to
(
self
.
device
)
run_module
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
model
)
run_module
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
model
)
self
.
_testModule
(
run_module
,
model
,
[
'forward'
])
self
.
_testModule
(
run_module
,
model
,
[
'forward'
])
...
...
torchani/models/neurochem_atomic_network.py
View file @
0e992fe5
...
@@ -187,8 +187,7 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
...
@@ -187,8 +187,7 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
self
.
activation
=
lambda
x
:
torch
.
exp
(
-
x
*
x
)
self
.
activation
=
lambda
x
:
torch
.
exp
(
-
x
*
x
)
elif
activation
==
9
:
# CELU
elif
activation
==
9
:
# CELU
alpha
=
0.1
alpha
=
0.1
self
.
activation
=
lambda
x
:
torch
.
where
(
self
.
activation
=
lambda
x
:
torch
.
celu
(
x
,
alpha
)
x
>
0
,
x
,
alpha
*
(
torch
.
exp
(
x
/
alpha
)
-
1
))
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
'Unexpected activation {}'
.
format
(
activation
))
'Unexpected activation {}'
.
format
(
activation
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment