Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
eb090700
Unverified
Commit
eb090700
authored
Aug 22, 2018
by
Gao, Xiang
Committed by
GitHub
Aug 22, 2018
Browse files
simplify handling of species to avoid unnecessary arguments (#74)
parent
ce21d224
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
27 additions
and
32 deletions
+27
-32
examples/model.py
examples/model.py
+5
-7
examples/nnp_training.py
examples/nnp_training.py
+1
-1
examples/training-benchmark.py
examples/training-benchmark.py
+2
-1
torchani/aev.py
torchani/aev.py
+10
-9
torchani/models.py
torchani/models.py
+6
-12
torchani/neurochem.py
torchani/neurochem.py
+3
-2
No files found.
examples/model.py
View file @
eb090700
...
...
@@ -3,6 +3,10 @@ import torchani
import
os
consts
=
torchani
.
buildins
.
consts
aev_computer
=
torchani
.
buildins
.
aev_computer
def
atomic
():
model
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Linear
(
384
,
128
),
...
...
@@ -17,13 +21,7 @@ def atomic():
def
get_or_create_model
(
filename
,
device
=
torch
.
device
(
'cpu'
)):
aev_computer
=
torchani
.
buildins
.
aev_computer
model
=
torchani
.
ANIModel
([
(
'C'
,
atomic
()),
(
'H'
,
atomic
()),
(
'N'
,
atomic
()),
(
'O'
,
atomic
()),
])
model
=
torchani
.
ANIModel
([
atomic
()
for
_
in
range
(
4
)])
class
Flatten
(
torch
.
nn
.
Module
):
...
...
examples/nnp_training.py
View file @
eb090700
...
...
@@ -54,7 +54,7 @@ start = timeit.default_timer()
nnp
=
model
.
get_or_create_model
(
parser
.
model_checkpoint
,
device
=
device
)
shift_energy
=
torchani
.
buildins
.
energy_shifter
training
,
validation
,
testing
=
torchani
.
training
.
load_or_create
(
parser
.
dataset_checkpoint
,
parser
.
batch_size
,
nnp
[
0
]
.
species
,
parser
.
dataset_checkpoint
,
parser
.
batch_size
,
model
.
consts
.
species
,
parser
.
dataset_path
,
device
=
device
,
transform
=
[
shift_energy
.
subtract_from_dataset
])
container
=
torchani
.
training
.
Container
({
'energies'
:
nnp
})
...
...
examples/training-benchmark.py
View file @
eb090700
...
...
@@ -24,7 +24,8 @@ device = torch.device(parser.device)
nnp
=
model
.
get_or_create_model
(
'/tmp/model.pt'
,
device
=
device
)
shift_energy
=
torchani
.
buildins
.
energy_shifter
dataset
=
torchani
.
training
.
BatchedANIDataset
(
parser
.
dataset_path
,
nnp
[
0
].
species
,
parser
.
batch_size
,
device
=
device
,
parser
.
dataset_path
,
model
.
consts
.
species
,
parser
.
batch_size
,
device
=
device
,
transform
=
[
shift_energy
.
subtract_from_dataset
])
container
=
torchani
.
training
.
Container
({
'energies'
:
nnp
})
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
...
...
torchani/aev.py
View file @
eb090700
...
...
@@ -42,11 +42,12 @@ class AEVComputer(torch.nn.Module):
The name of the file that stores constant.
Rcr, Rca, EtaR, ShfR, Zeta, ShfZ, EtaA, ShfA : torch.Tensor
Tensor storing constants.
species :
list(str)
Chemical symbols
of supported atom types
num_
species :
int
Number
of supported atom types
"""
def
__init__
(
self
,
Rcr
,
Rca
,
EtaR
,
ShfR
,
EtaA
,
Zeta
,
ShfA
,
ShfZ
,
species
):
def
__init__
(
self
,
Rcr
,
Rca
,
EtaR
,
ShfR
,
EtaA
,
Zeta
,
ShfA
,
ShfZ
,
num_species
):
super
(
AEVComputer
,
self
).
__init__
()
self
.
register_buffer
(
'Rcr'
,
Rcr
)
self
.
register_buffer
(
'Rca'
,
Rca
)
...
...
@@ -60,7 +61,7 @@ class AEVComputer(torch.nn.Module):
self
.
register_buffer
(
'ShfA'
,
ShfA
.
view
(
1
,
1
,
-
1
,
1
))
self
.
register_buffer
(
'ShfZ'
,
ShfZ
.
view
(
1
,
1
,
1
,
-
1
))
self
.
species
=
species
self
.
num_
species
=
num_
species
def
radial_sublength
(
self
):
"""Returns the length of radial subaev of a single species"""
...
...
@@ -68,7 +69,7 @@ class AEVComputer(torch.nn.Module):
def
radial_length
(
self
):
"""Returns the length of full radial aev"""
return
len
(
self
.
species
)
*
self
.
radial_sublength
()
return
self
.
num_
species
*
self
.
radial_sublength
()
def
angular_sublength
(
self
):
"""Returns the length of angular subaev of a single species"""
...
...
@@ -77,8 +78,8 @@ class AEVComputer(torch.nn.Module):
def
angular_length
(
self
):
"""Returns the length of full angular aev"""
s
pecies
=
len
(
self
.
species
)
return
int
((
species
*
(
specie
s
+
1
))
/
2
)
*
self
.
angular_sublength
()
s
=
self
.
num_
species
return
(
s
*
(
s
+
1
))
/
/
2
*
self
.
angular_sublength
()
def
aev_length
(
self
):
"""Returns the length of full aev"""
...
...
@@ -266,7 +267,7 @@ class AEVComputer(torch.nn.Module):
"""Tensor of shape (conformations, atoms, neighbors) storing species
of neighbors."""
mask_r
=
(
species_r
.
unsqueeze
(
-
1
)
==
torch
.
arange
(
len
(
self
.
species
)
,
device
=
self
.
EtaR
.
device
))
torch
.
arange
(
self
.
num_
species
,
device
=
self
.
EtaR
.
device
))
return
mask_r
def
compute_mask_a
(
self
,
species
,
indices_a
,
present_species
):
...
...
@@ -348,7 +349,7 @@ class AEVComputer(torch.nn.Module):
conformations
,
atoms
,
self
.
angular_sublength
(),
dtype
=
self
.
EtaR
.
dtype
,
device
=
self
.
EtaR
.
device
)
for
s1
,
s2
in
itertools
.
combinations_with_replacement
(
range
(
len
(
self
.
species
)
)
,
2
):
range
(
self
.
num_
species
),
2
):
if
s1
in
rev_indices
and
s2
in
rev_indices
:
i1
=
rev_indices
[
s1
]
i2
=
rev_indices
[
s2
]
...
...
torchani/models.py
View file @
eb090700
...
...
@@ -2,15 +2,14 @@ import torch
from
.
import
utils
class
ANIModel
(
torch
.
nn
.
Module
):
class
ANIModel
(
torch
.
nn
.
Module
List
):
def
__init__
(
self
,
mode
l
s
,
reducer
=
torch
.
sum
,
padding_fill
=
0
):
def
__init__
(
self
,
mod
ul
es
,
reducer
=
torch
.
sum
,
padding_fill
=
0
):
"""
Parameters
----------
models : (str, torch.nn.Module)
Models for all species. This must be a mapping where the key is
atomic symbol and the value is a module.
modules : seq(torch.nn.Module)
Modules for all species.
reducer : function
Function of (input, dim)->output that reduce the input tensor along
the given dimension to get an output tensor. This function will be
...
...
@@ -20,12 +19,9 @@ class ANIModel(torch.nn.Module):
padding_fill : float
Default value used to fill padding atoms
"""
super
(
ANIModel
,
self
).
__init__
()
self
.
species
=
[
s
for
s
,
_
in
models
]
super
(
ANIModel
,
self
).
__init__
(
modules
)
self
.
reducer
=
reducer
self
.
padding_fill
=
padding_fill
for
s
,
m
in
models
:
setattr
(
self
,
'model_'
+
s
,
m
)
def
forward
(
self
,
species_aev
):
"""Compute output from aev
...
...
@@ -56,11 +52,9 @@ class ANIModel(torch.nn.Module):
output
=
torch
.
full_like
(
species_
,
self
.
padding_fill
,
dtype
=
aev
.
dtype
)
for
i
in
present_species
:
s
=
self
.
species
[
i
]
model_X
=
getattr
(
self
,
'model_'
+
s
)
mask
=
(
species_
==
i
)
input
=
aev
.
index_select
(
0
,
mask
.
nonzero
().
squeeze
())
output
[
mask
]
=
model_X
(
input
).
squeeze
()
output
[
mask
]
=
self
[
i
]
(
input
).
squeeze
()
output
=
output
.
view_as
(
species
)
return
species
,
self
.
reducer
(
output
,
dim
=
1
)
...
...
torchani/neurochem.py
View file @
eb090700
...
...
@@ -33,6 +33,7 @@ class Constants(Mapping):
self
.
species
=
value
except
Exception
:
raise
ValueError
(
'unable to parse const file'
)
self
.
num_species
=
len
(
self
.
species
)
self
.
rev_species
=
{}
for
i
in
range
(
len
(
self
.
species
)):
s
=
self
.
species
[
i
]
...
...
@@ -47,7 +48,7 @@ class Constants(Mapping):
yield
'Zeta'
yield
'ShfA'
yield
'ShfZ'
yield
'species'
yield
'
num_
species'
def
__len__
(
self
):
return
8
...
...
@@ -232,7 +233,7 @@ def load_model(species, from_):
models
=
[]
for
i
in
species
:
filename
=
os
.
path
.
join
(
from_
,
'ANN-{}.nnf'
.
format
(
i
))
models
.
append
(
(
i
,
load_atomic_network
(
filename
))
)
models
.
append
(
load_atomic_network
(
filename
))
return
ANIModel
(
models
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment