Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
eb090700
Unverified
Commit
eb090700
authored
Aug 22, 2018
by
Gao, Xiang
Committed by
GitHub
Aug 22, 2018
Browse files
simplify handling of species to avoid unnecessary arguments (#74)
parent
ce21d224
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
27 additions
and
32 deletions
+27
-32
examples/model.py
examples/model.py
+5
-7
examples/nnp_training.py
examples/nnp_training.py
+1
-1
examples/training-benchmark.py
examples/training-benchmark.py
+2
-1
torchani/aev.py
torchani/aev.py
+10
-9
torchani/models.py
torchani/models.py
+6
-12
torchani/neurochem.py
torchani/neurochem.py
+3
-2
No files found.
examples/model.py
View file @
eb090700
...
@@ -3,6 +3,10 @@ import torchani
...
@@ -3,6 +3,10 @@ import torchani
import
os
import
os
consts
=
torchani
.
buildins
.
consts
aev_computer
=
torchani
.
buildins
.
aev_computer
def
atomic
():
def
atomic
():
model
=
torch
.
nn
.
Sequential
(
model
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Linear
(
384
,
128
),
torch
.
nn
.
Linear
(
384
,
128
),
...
@@ -17,13 +21,7 @@ def atomic():
...
@@ -17,13 +21,7 @@ def atomic():
def
get_or_create_model
(
filename
,
device
=
torch
.
device
(
'cpu'
)):
def
get_or_create_model
(
filename
,
device
=
torch
.
device
(
'cpu'
)):
aev_computer
=
torchani
.
buildins
.
aev_computer
model
=
torchani
.
ANIModel
([
atomic
()
for
_
in
range
(
4
)])
model
=
torchani
.
ANIModel
([
(
'C'
,
atomic
()),
(
'H'
,
atomic
()),
(
'N'
,
atomic
()),
(
'O'
,
atomic
()),
])
class
Flatten
(
torch
.
nn
.
Module
):
class
Flatten
(
torch
.
nn
.
Module
):
...
...
examples/nnp_training.py
View file @
eb090700
...
@@ -54,7 +54,7 @@ start = timeit.default_timer()
...
@@ -54,7 +54,7 @@ start = timeit.default_timer()
nnp
=
model
.
get_or_create_model
(
parser
.
model_checkpoint
,
device
=
device
)
nnp
=
model
.
get_or_create_model
(
parser
.
model_checkpoint
,
device
=
device
)
shift_energy
=
torchani
.
buildins
.
energy_shifter
shift_energy
=
torchani
.
buildins
.
energy_shifter
training
,
validation
,
testing
=
torchani
.
training
.
load_or_create
(
training
,
validation
,
testing
=
torchani
.
training
.
load_or_create
(
parser
.
dataset_checkpoint
,
parser
.
batch_size
,
nnp
[
0
]
.
species
,
parser
.
dataset_checkpoint
,
parser
.
batch_size
,
model
.
consts
.
species
,
parser
.
dataset_path
,
device
=
device
,
parser
.
dataset_path
,
device
=
device
,
transform
=
[
shift_energy
.
subtract_from_dataset
])
transform
=
[
shift_energy
.
subtract_from_dataset
])
container
=
torchani
.
training
.
Container
({
'energies'
:
nnp
})
container
=
torchani
.
training
.
Container
({
'energies'
:
nnp
})
...
...
examples/training-benchmark.py
View file @
eb090700
...
@@ -24,7 +24,8 @@ device = torch.device(parser.device)
...
@@ -24,7 +24,8 @@ device = torch.device(parser.device)
nnp
=
model
.
get_or_create_model
(
'/tmp/model.pt'
,
device
=
device
)
nnp
=
model
.
get_or_create_model
(
'/tmp/model.pt'
,
device
=
device
)
shift_energy
=
torchani
.
buildins
.
energy_shifter
shift_energy
=
torchani
.
buildins
.
energy_shifter
dataset
=
torchani
.
training
.
BatchedANIDataset
(
dataset
=
torchani
.
training
.
BatchedANIDataset
(
parser
.
dataset_path
,
nnp
[
0
].
species
,
parser
.
batch_size
,
device
=
device
,
parser
.
dataset_path
,
model
.
consts
.
species
,
parser
.
batch_size
,
device
=
device
,
transform
=
[
shift_energy
.
subtract_from_dataset
])
transform
=
[
shift_energy
.
subtract_from_dataset
])
container
=
torchani
.
training
.
Container
({
'energies'
:
nnp
})
container
=
torchani
.
training
.
Container
({
'energies'
:
nnp
})
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
...
...
torchani/aev.py
View file @
eb090700
...
@@ -42,11 +42,12 @@ class AEVComputer(torch.nn.Module):
...
@@ -42,11 +42,12 @@ class AEVComputer(torch.nn.Module):
The name of the file that stores constant.
The name of the file that stores constant.
Rcr, Rca, EtaR, ShfR, Zeta, ShfZ, EtaA, ShfA : torch.Tensor
Rcr, Rca, EtaR, ShfR, Zeta, ShfZ, EtaA, ShfA : torch.Tensor
Tensor storing constants.
Tensor storing constants.
species :
list(str)
num_
species :
int
Chemical symbols
of supported atom types
Number
of supported atom types
"""
"""
def
__init__
(
self
,
Rcr
,
Rca
,
EtaR
,
ShfR
,
EtaA
,
Zeta
,
ShfA
,
ShfZ
,
species
):
def
__init__
(
self
,
Rcr
,
Rca
,
EtaR
,
ShfR
,
EtaA
,
Zeta
,
ShfA
,
ShfZ
,
num_species
):
super
(
AEVComputer
,
self
).
__init__
()
super
(
AEVComputer
,
self
).
__init__
()
self
.
register_buffer
(
'Rcr'
,
Rcr
)
self
.
register_buffer
(
'Rcr'
,
Rcr
)
self
.
register_buffer
(
'Rca'
,
Rca
)
self
.
register_buffer
(
'Rca'
,
Rca
)
...
@@ -60,7 +61,7 @@ class AEVComputer(torch.nn.Module):
...
@@ -60,7 +61,7 @@ class AEVComputer(torch.nn.Module):
self
.
register_buffer
(
'ShfA'
,
ShfA
.
view
(
1
,
1
,
-
1
,
1
))
self
.
register_buffer
(
'ShfA'
,
ShfA
.
view
(
1
,
1
,
-
1
,
1
))
self
.
register_buffer
(
'ShfZ'
,
ShfZ
.
view
(
1
,
1
,
1
,
-
1
))
self
.
register_buffer
(
'ShfZ'
,
ShfZ
.
view
(
1
,
1
,
1
,
-
1
))
self
.
species
=
species
self
.
num_
species
=
num_
species
def
radial_sublength
(
self
):
def
radial_sublength
(
self
):
"""Returns the length of radial subaev of a single species"""
"""Returns the length of radial subaev of a single species"""
...
@@ -68,7 +69,7 @@ class AEVComputer(torch.nn.Module):
...
@@ -68,7 +69,7 @@ class AEVComputer(torch.nn.Module):
def
radial_length
(
self
):
def
radial_length
(
self
):
"""Returns the length of full radial aev"""
"""Returns the length of full radial aev"""
return
len
(
self
.
species
)
*
self
.
radial_sublength
()
return
self
.
num_
species
*
self
.
radial_sublength
()
def
angular_sublength
(
self
):
def
angular_sublength
(
self
):
"""Returns the length of angular subaev of a single species"""
"""Returns the length of angular subaev of a single species"""
...
@@ -77,8 +78,8 @@ class AEVComputer(torch.nn.Module):
...
@@ -77,8 +78,8 @@ class AEVComputer(torch.nn.Module):
def
angular_length
(
self
):
def
angular_length
(
self
):
"""Returns the length of full angular aev"""
"""Returns the length of full angular aev"""
s
pecies
=
len
(
self
.
species
)
s
=
self
.
num_
species
return
int
((
species
*
(
specie
s
+
1
))
/
2
)
*
self
.
angular_sublength
()
return
(
s
*
(
s
+
1
))
/
/
2
*
self
.
angular_sublength
()
def
aev_length
(
self
):
def
aev_length
(
self
):
"""Returns the length of full aev"""
"""Returns the length of full aev"""
...
@@ -266,7 +267,7 @@ class AEVComputer(torch.nn.Module):
...
@@ -266,7 +267,7 @@ class AEVComputer(torch.nn.Module):
"""Tensor of shape (conformations, atoms, neighbors) storing species
"""Tensor of shape (conformations, atoms, neighbors) storing species
of neighbors."""
of neighbors."""
mask_r
=
(
species_r
.
unsqueeze
(
-
1
)
==
mask_r
=
(
species_r
.
unsqueeze
(
-
1
)
==
torch
.
arange
(
len
(
self
.
species
)
,
device
=
self
.
EtaR
.
device
))
torch
.
arange
(
self
.
num_
species
,
device
=
self
.
EtaR
.
device
))
return
mask_r
return
mask_r
def
compute_mask_a
(
self
,
species
,
indices_a
,
present_species
):
def
compute_mask_a
(
self
,
species
,
indices_a
,
present_species
):
...
@@ -348,7 +349,7 @@ class AEVComputer(torch.nn.Module):
...
@@ -348,7 +349,7 @@ class AEVComputer(torch.nn.Module):
conformations
,
atoms
,
self
.
angular_sublength
(),
conformations
,
atoms
,
self
.
angular_sublength
(),
dtype
=
self
.
EtaR
.
dtype
,
device
=
self
.
EtaR
.
device
)
dtype
=
self
.
EtaR
.
dtype
,
device
=
self
.
EtaR
.
device
)
for
s1
,
s2
in
itertools
.
combinations_with_replacement
(
for
s1
,
s2
in
itertools
.
combinations_with_replacement
(
range
(
len
(
self
.
species
)
)
,
2
):
range
(
self
.
num_
species
),
2
):
if
s1
in
rev_indices
and
s2
in
rev_indices
:
if
s1
in
rev_indices
and
s2
in
rev_indices
:
i1
=
rev_indices
[
s1
]
i1
=
rev_indices
[
s1
]
i2
=
rev_indices
[
s2
]
i2
=
rev_indices
[
s2
]
...
...
torchani/models.py
View file @
eb090700
...
@@ -2,15 +2,14 @@ import torch
...
@@ -2,15 +2,14 @@ import torch
from
.
import
utils
from
.
import
utils
class
ANIModel
(
torch
.
nn
.
Module
):
class
ANIModel
(
torch
.
nn
.
Module
List
):
def
__init__
(
self
,
mode
l
s
,
reducer
=
torch
.
sum
,
padding_fill
=
0
):
def
__init__
(
self
,
mod
ul
es
,
reducer
=
torch
.
sum
,
padding_fill
=
0
):
"""
"""
Parameters
Parameters
----------
----------
models : (str, torch.nn.Module)
modules : seq(torch.nn.Module)
Models for all species. This must be a mapping where the key is
Modules for all species.
atomic symbol and the value is a module.
reducer : function
reducer : function
Function of (input, dim)->output that reduce the input tensor along
Function of (input, dim)->output that reduce the input tensor along
the given dimension to get an output tensor. This function will be
the given dimension to get an output tensor. This function will be
...
@@ -20,12 +19,9 @@ class ANIModel(torch.nn.Module):
...
@@ -20,12 +19,9 @@ class ANIModel(torch.nn.Module):
padding_fill : float
padding_fill : float
Default value used to fill padding atoms
Default value used to fill padding atoms
"""
"""
super
(
ANIModel
,
self
).
__init__
()
super
(
ANIModel
,
self
).
__init__
(
modules
)
self
.
species
=
[
s
for
s
,
_
in
models
]
self
.
reducer
=
reducer
self
.
reducer
=
reducer
self
.
padding_fill
=
padding_fill
self
.
padding_fill
=
padding_fill
for
s
,
m
in
models
:
setattr
(
self
,
'model_'
+
s
,
m
)
def
forward
(
self
,
species_aev
):
def
forward
(
self
,
species_aev
):
"""Compute output from aev
"""Compute output from aev
...
@@ -56,11 +52,9 @@ class ANIModel(torch.nn.Module):
...
@@ -56,11 +52,9 @@ class ANIModel(torch.nn.Module):
output
=
torch
.
full_like
(
species_
,
self
.
padding_fill
,
output
=
torch
.
full_like
(
species_
,
self
.
padding_fill
,
dtype
=
aev
.
dtype
)
dtype
=
aev
.
dtype
)
for
i
in
present_species
:
for
i
in
present_species
:
s
=
self
.
species
[
i
]
model_X
=
getattr
(
self
,
'model_'
+
s
)
mask
=
(
species_
==
i
)
mask
=
(
species_
==
i
)
input
=
aev
.
index_select
(
0
,
mask
.
nonzero
().
squeeze
())
input
=
aev
.
index_select
(
0
,
mask
.
nonzero
().
squeeze
())
output
[
mask
]
=
model_X
(
input
).
squeeze
()
output
[
mask
]
=
self
[
i
]
(
input
).
squeeze
()
output
=
output
.
view_as
(
species
)
output
=
output
.
view_as
(
species
)
return
species
,
self
.
reducer
(
output
,
dim
=
1
)
return
species
,
self
.
reducer
(
output
,
dim
=
1
)
...
...
torchani/neurochem.py
View file @
eb090700
...
@@ -33,6 +33,7 @@ class Constants(Mapping):
...
@@ -33,6 +33,7 @@ class Constants(Mapping):
self
.
species
=
value
self
.
species
=
value
except
Exception
:
except
Exception
:
raise
ValueError
(
'unable to parse const file'
)
raise
ValueError
(
'unable to parse const file'
)
self
.
num_species
=
len
(
self
.
species
)
self
.
rev_species
=
{}
self
.
rev_species
=
{}
for
i
in
range
(
len
(
self
.
species
)):
for
i
in
range
(
len
(
self
.
species
)):
s
=
self
.
species
[
i
]
s
=
self
.
species
[
i
]
...
@@ -47,7 +48,7 @@ class Constants(Mapping):
...
@@ -47,7 +48,7 @@ class Constants(Mapping):
yield
'Zeta'
yield
'Zeta'
yield
'ShfA'
yield
'ShfA'
yield
'ShfZ'
yield
'ShfZ'
yield
'species'
yield
'
num_
species'
def
__len__
(
self
):
def
__len__
(
self
):
return
8
return
8
...
@@ -232,7 +233,7 @@ def load_model(species, from_):
...
@@ -232,7 +233,7 @@ def load_model(species, from_):
models
=
[]
models
=
[]
for
i
in
species
:
for
i
in
species
:
filename
=
os
.
path
.
join
(
from_
,
'ANN-{}.nnf'
.
format
(
i
))
filename
=
os
.
path
.
join
(
from_
,
'ANN-{}.nnf'
.
format
(
i
))
models
.
append
(
(
i
,
load_atomic_network
(
filename
))
)
models
.
append
(
load_atomic_network
(
filename
))
return
ANIModel
(
models
)
return
ANIModel
(
models
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment