Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
0e992fe5
"tests/vscode:/vscode.git/clone" did not exist on "1db4ad4fcc76c2ad87ee9066ad8f7e4ccf4a7290"
Unverified
Commit
0e992fe5
authored
Aug 01, 2018
by
Gao, Xiang
Committed by
GitHub
Aug 01, 2018
Browse files
use pytorch's CELU to simplify code (#42)
parent
97e3df07
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
31 deletions
+18
-31
examples/model.py
examples/model.py
+16
-28
tests/test_benchmark.py
tests/test_benchmark.py
+1
-1
torchani/models/neurochem_atomic_network.py
torchani/models/neurochem_atomic_network.py
+1
-2
No files found.
examples/model.py
View file @
0e992fe5
...
@@ -3,30 +3,18 @@ import torchani
...
@@ -3,30 +3,18 @@ import torchani
import
os
import
os
def
celu
(
x
,
alpha
):
def
atomic
():
return
torch
.
where
(
x
>
0
,
x
,
alpha
*
(
torch
.
exp
(
x
/
alpha
)
-
1
))
model
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Linear
(
384
,
128
),
torch
.
nn
.
CELU
(
0.1
),
class
AtomicNetwork
(
torch
.
nn
.
Module
):
torch
.
nn
.
Linear
(
128
,
128
),
torch
.
nn
.
CELU
(
0.1
),
def
__init__
(
self
):
torch
.
nn
.
Linear
(
128
,
64
),
super
(
AtomicNetwork
,
self
).
__init__
()
torch
.
nn
.
CELU
(
0.1
),
self
.
output_length
=
1
torch
.
nn
.
Linear
(
64
,
1
)
self
.
layer1
=
torch
.
nn
.
Linear
(
384
,
128
)
)
self
.
layer2
=
torch
.
nn
.
Linear
(
128
,
128
)
model
.
output_length
=
1
self
.
layer3
=
torch
.
nn
.
Linear
(
128
,
64
)
return
model
self
.
layer4
=
torch
.
nn
.
Linear
(
64
,
1
)
def
forward
(
self
,
aev
):
y
=
aev
y
=
self
.
layer1
(
y
)
y
=
celu
(
y
,
0.1
)
y
=
self
.
layer2
(
y
)
y
=
celu
(
y
,
0.1
)
y
=
self
.
layer3
(
y
)
y
=
celu
(
y
,
0.1
)
y
=
self
.
layer4
(
y
)
return
y
def
get_or_create_model
(
filename
,
benchmark
=
False
,
def
get_or_create_model
(
filename
,
benchmark
=
False
,
...
@@ -37,10 +25,10 @@ def get_or_create_model(filename, benchmark=False,
...
@@ -37,10 +25,10 @@ def get_or_create_model(filename, benchmark=False,
reducer
=
torch
.
sum
,
reducer
=
torch
.
sum
,
benchmark
=
benchmark
,
benchmark
=
benchmark
,
per_species
=
{
per_species
=
{
'C'
:
A
tomic
Network
(),
'C'
:
a
tomic
(),
'H'
:
A
tomic
Network
(),
'H'
:
a
tomic
(),
'N'
:
A
tomic
Network
(),
'N'
:
a
tomic
(),
'O'
:
A
tomic
Network
(),
'O'
:
a
tomic
(),
})
})
class
Flatten
(
torch
.
nn
.
Module
):
class
Flatten
(
torch
.
nn
.
Module
):
...
...
tests/test_benchmark.py
View file @
0e992fe5
...
@@ -99,7 +99,7 @@ class TestBenchmark(unittest.TestCase):
...
@@ -99,7 +99,7 @@ class TestBenchmark(unittest.TestCase):
dtype
=
self
.
dtype
,
device
=
self
.
device
)
dtype
=
self
.
dtype
,
device
=
self
.
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
self
.
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
self
.
device
)
model
=
torchani
.
models
.
NeuroChemNNP
(
model
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
,
benchmark
=
True
)
aev_computer
.
species
,
benchmark
=
True
)
.
to
(
self
.
device
)
run_module
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
model
)
run_module
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
model
)
self
.
_testModule
(
run_module
,
model
,
[
'forward'
])
self
.
_testModule
(
run_module
,
model
,
[
'forward'
])
...
...
torchani/models/neurochem_atomic_network.py
View file @
0e992fe5
...
@@ -187,8 +187,7 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
...
@@ -187,8 +187,7 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
self
.
activation
=
lambda
x
:
torch
.
exp
(
-
x
*
x
)
self
.
activation
=
lambda
x
:
torch
.
exp
(
-
x
*
x
)
elif
activation
==
9
:
# CELU
elif
activation
==
9
:
# CELU
alpha
=
0.1
alpha
=
0.1
self
.
activation
=
lambda
x
:
torch
.
where
(
self
.
activation
=
lambda
x
:
torch
.
celu
(
x
,
alpha
)
x
>
0
,
x
,
alpha
*
(
torch
.
exp
(
x
/
alpha
)
-
1
))
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
'Unexpected activation {}'
.
format
(
activation
))
'Unexpected activation {}'
.
format
(
activation
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment