OpenDAS / torchani / Commits / f50cc0b4

remove output_length (#54)

Commit f50cc0b4 (unverified), authored Aug 03, 2018 by Gao, Xiang; committed by GitHub, Aug 03, 2018
Parent: 73e447f0

Showing 7 changed files with 10 additions and 41 deletions (+10 -41)
examples/model.py                               +0  -1
torchani/aev.py                                 +3  -7
torchani/ignite/loss_metrics.py                 +1  -1
torchani/models/ani_model.py                    +3  -5
torchani/models/custom.py                       +2  -16
torchani/models/neurochem_atomic_network.py     +0  -3
torchani/models/neurochem_nnp.py                +1  -8
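The net effect across these files: atomic networks no longer carry an `output_length` attribute, model constructors no longer take an `output_length` argument, and the per-atom output width is instead inferred at reshape time. A minimal sketch of the pattern in plain PyTorch (layer sizes are illustrative, not copied from examples/model.py):

import torch

# After this commit an atomic network is just a plain module; the extra
# `model.output_length = 1` attribute (and the matching constructor
# argument elsewhere) is gone.
def atomic():
    model = torch.nn.Sequential(
        torch.nn.Linear(384, 64),
        torch.nn.CELU(0.1),
        torch.nn.Linear(64, 1),
    )
    # Before this commit:  model.output_length = 1
    return model

net = atomic()
out = net(torch.randn(10, 384))
print(out.shape)  # torch.Size([10, 1])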
examples/model.py
@@ -13,7 +13,6 @@ def atomic():
         torch.nn.CELU(0.1),
         torch.nn.Linear(64, 1)
     )
-    model.output_length = 1
     return model
torchani/aev.py
@@ -401,13 +401,9 @@ class SortedAEV(AEVComputer):
             Tensor of shape (conformations, atoms, pairs, present species,
             present species) storing the mask for each pair.
         """
-        species_a = self.combinations(species_a, -1)
-        species_a1, species_a2 = species_a
-        mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1)
-        mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species)
+        species_a1, species_a2 = self.combinations(species_a, -1)
+        mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1)
+        mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species)
         mask = mask_a1 * mask_a2
         mask_rev = mask.permute(0, 1, 2, 4, 3)
         mask_a = (mask + mask_rev) > 0
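For readers following the mask arithmetic above, the change only tidies the tuple unpacking; the broadcasting itself is unchanged. A small self-contained sketch of that broadcasting (toy shapes and species indices, not torchani's actual tensors; the boolean operators `&`/`|` stand in for the `*` and `(mask + mask_rev) > 0` used on 2018-era byte masks):

import torch

# Toy shapes: 2 conformations, 4 atoms, 3 neighbor pairs, 3 present species.
conformations, atoms, pairs = 2, 4, 3
present_species = torch.tensor([0, 1, 6])

# Stand-ins for species_a1 / species_a2: species index of each atom in a pair.
species_a1 = torch.randint(0, 7, (conformations, atoms, pairs))
species_a2 = torch.randint(0, 7, (conformations, atoms, pairs))

# Compare against present_species on a new trailing axis, then add one more
# axis so the two masks broadcast into a (..., species, species) grid.
mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1)
mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species)
mask = mask_a1 & mask_a2                  # original code: mask_a1 * mask_a2
mask_rev = mask.permute(0, 1, 2, 4, 3)
mask_a = mask | mask_rev                  # original: (mask + mask_rev) > 0
print(mask_a.shape)  # torch.Size([2, 4, 3, 3, 3])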
torchani/ignite/loss_metrics.py
@@ -52,7 +52,7 @@ class DictMetric(Metric):
 def MSELoss(key, per_atom=True):
     if per_atom:
-        return _PerAtomDictLoss(key, torch.nn.MSELoss(reduce=False))
+        return _PerAtomDictLoss(key, torch.nn.MSELoss(reduction='none'))
     else:
         return DictLoss(key, torch.nn.MSELoss())
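The `reduce=False` to `reduction='none'` switch follows PyTorch's loss API migration: both spellings ask for an unreduced, element-wise loss, but `reduce` is deprecated. A quick illustration (numbers are arbitrary):

import torch

pred = torch.tensor([1.0, 2.0, 4.0])
true = torch.tensor([1.0, 0.0, 1.0])

# Modern spelling: keep per-element squared errors instead of averaging them.
loss_fn = torch.nn.MSELoss(reduction='none')
print(loss_fn(pred, true))  # tensor([0., 4., 9.])

# The older spelling this commit replaces was:
#   torch.nn.MSELoss(reduce=False)
# which produced the same unreduced tensor (and raises a deprecation warning
# on newer PyTorch releases).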
torchani/models/ani_model.py
@@ -10,8 +10,6 @@ class ANIModel(BenchmarkedModule):
     ----------
     species : list
         Chemical symbol of supported atom species.
-    output_length : int
-        The length of output vector.
     suffixes : sequence
         Different suffixes denote different models in an ensemble.
     model_<X><suffix> : nn.Module
@@ -30,13 +28,12 @@ class ANIModel(BenchmarkedModule):
         forward : total time for the forward pass
     """

-    def __init__(self, species, suffixes, reducer, output_length, models,
+    def __init__(self, species, suffixes, reducer, models,
                  benchmark=False):
         super(ANIModel, self).__init__(benchmark)
         self.species = species
         self.suffixes = suffixes
         self.reducer = reducer
-        self.output_length = output_length
         for i in models:
             setattr(self, i, models[i])
@@ -72,6 +69,7 @@ class ANIModel(BenchmarkedModule):
         for s in species_dedup:
             begin = species.index(s)
             end = atoms - rev_species.index(s)
+            part_atoms = end - begin
             y = aev[:, begin:end, :].flatten(0, 1)

             def apply_model(suffix):
@@ -80,7 +78,7 @@ class ANIModel(BenchmarkedModule):
                 return model_X(y)

             ys = [apply_model(suffix) for suffix in self.suffixes]
             y = sum(ys) / len(ys)
-            y = y.view(conformations, -1, self.output_length)
+            y = y.view(conformations, part_atoms, -1)
             per_species_outputs.append(y)
         per_species_outputs = torch.cat(per_species_outputs, dim=1)
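The reshape change above is the heart of the commit: the number of atoms in the current species block is simply `end - begin`, so the output width can be left to `view` as `-1` instead of being tracked in `self.output_length`. A hedged sketch of that index arithmetic on a toy species-sorted molecule (plain PyTorch, not torchani code):

import torch

species = ['H', 'H', 'C', 'C', 'C', 'O']       # atoms sorted by species
rev_species = list(reversed(species))
atoms = len(species)
conformations = 2
aev = torch.randn(conformations, atoms, 384)

for s in ('H', 'C', 'O'):
    begin = species.index(s)                   # first atom of this species
    end = atoms - rev_species.index(s)         # one past the last atom
    part_atoms = end - begin                   # the quantity added by this commit
    y = aev[:, begin:end, :].flatten(0, 1)     # (conformations * part_atoms, 384)
    # ... run the per-species network(s) on y, then reshape:
    y = y.view(conformations, part_atoms, -1)  # output width inferred by -1
    print(s, begin, end, y.shape)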
torchani/models/custom.py
@@ -17,22 +17,8 @@ class CustomModel(ANIModel):
             The desired `reducer` attribute.
         """
         suffixes = ['']
-        output_length = None
         models = {}
         for i in per_species:
-            model_X = per_species[i]
-            if not hasattr(model_X, 'output_length'):
-                raise ValueError('''atomic neural network must explicitly specify
-                    output length''')
-            elif output_length is None:
-                output_length = model_X.output_length
-            elif output_length != model_X.output_length:
-                raise ValueError('''output length of each atomic neural network must
-                    match''')
             models['model_' + i] = per_species[i]
         super(CustomModel, self).__init__(list(per_species.keys()), suffixes,
-                                          reducer, output_length, models,
-                                          benchmark)
-        for i in per_species:
-            setattr(self, 'model_' + i, per_species[i])
+                                          reducer, models, benchmark)
torchani/models/neurochem_atomic_network.py
@@ -14,8 +14,6 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
     ----------
     layers : int
         Number of layers.
-    output_length : int
-        The length of output vector
     layerN : torch.nn.Linear
         Linear model for each layer.
     activation : function
@@ -202,7 +200,6 @@ class NeuroChemAtomicNetwork(torch.nn.Module):
                 raise ValueError('bad parameter shape')
             wfn = os.path.join(dirname, wfn)
             bfn = os.path.join(dirname, bfn)
-            self.output_length = out_size
             self._load_param_file(linear, in_size, out_size, wfn, bfn)

     def _load_param_file(self, linear, in_size, out_size, wfn, bfn):
torchani/models/neurochem_nnp.py
@@ -43,18 +43,11 @@ class NeuroChemNNP(ANIModel):
         reducer = torch.sum
         models = {}
-        output_length = None
         for network_dir, suffix in zip(network_dirs, suffixes):
             for i in species:
                 filename = os.path.join(network_dir, 'ANN-{}.nnf'.format(i))
                 model_X = NeuroChemAtomicNetwork(filename)
-                if output_length is None:
-                    output_length = model_X.output_length
-                elif output_length != model_X.output_length:
-                    raise ValueError('''output length of each atomic neural networt
-                        must match''')
                 models['model_' + i + suffix] = model_X
         super(NeuroChemNNP, self).__init__(species, suffixes, reducer,
-                                           output_length, models, benchmark)
+                                           models, benchmark)