Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
5bb66915
Unverified
Commit
5bb66915
authored
Aug 07, 2018
by
Gao, Xiang
Committed by
GitHub
Aug 07, 2018
Browse files
fix sort_by_species (#57)
parent
e5439f3d
Changes
102
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
18 deletions
+14
-18
tools/generate-unit-test-expect.py
tools/generate-unit-test-expect.py
+13
-17
torchani/aev.py
torchani/aev.py
+1
-1
No files found.
tools/generate-unit-test-expect.py
View file @
5bb66915
...
...
@@ -7,31 +7,29 @@ import numpy
import
torchani
import
pickle
from
torchani
import
buildin_const_file
,
buildin_sae_file
,
\
buildin_network_dir
,
default_dtype
,
default_device
buildin_network_dir
import
torchani.pyanitools
path
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
conv_au_ev
=
27.21138505
class
NeuroChem
(
torchani
.
aev
_base
.
AEVComputer
):
class
NeuroChem
(
torchani
.
aev
.
AEVComputer
):
def
__init__
(
self
,
dtype
=
default_dtype
,
device
=
default_devic
e
,
const_file
=
buildin_const_file
,
sae_file
=
buildin_sae_file
,
def
__init__
(
self
,
const_file
=
buildin_const_fil
e
,
sae_file
=
buildin_sae_file
,
network_dir
=
buildin_network_dir
):
super
(
NeuroChem
,
self
).
__init__
(
False
,
dtype
,
device
,
const_file
)
super
(
NeuroChem
,
self
).
__init__
(
False
,
const_file
)
self
.
sae_file
=
sae_file
self
.
network_dir
=
network_dir
self
.
nc
=
pyNeuroChem
.
molecule
(
self
.
const_file
,
self
.
sae_file
,
self
.
network_dir
,
0
)
def
_get_radial_part
(
self
,
fullaev
):
radial_size
=
self
.
radial_length
return
fullaev
[:,
:,
:
radial_size
]
return
fullaev
[:,
:,
:
self
.
radial_length
]
def
_get_angular_part
(
self
,
fullaev
):
radial_size
=
self
.
radial_length
return
fullaev
[:,
:,
radial_size
:]
return
fullaev
[:,
:,
self
.
radial_length
:]
def
_per_conformation
(
self
,
coordinates
,
species
):
atoms
=
coordinates
.
shape
[
0
]
...
...
@@ -50,29 +48,27 @@ class NeuroChem (torchani.aev_base.AEVComputer):
coordinates
[
i
],
species
)
for
i
in
range
(
conformations
)]
aevs
,
energies
,
forces
=
zip
(
*
results
)
aevs
=
torch
.
from_numpy
(
numpy
.
stack
(
aevs
)).
type
(
self
.
dtype
).
to
(
self
.
device
)
self
.
EtaR
.
dtype
).
to
(
self
.
EtaR
.
device
)
energies
=
torch
.
from_numpy
(
numpy
.
stack
(
energies
)).
type
(
self
.
dtype
).
to
(
self
.
device
)
self
.
EtaR
.
dtype
).
to
(
self
.
EtaR
.
device
)
forces
=
torch
.
from_numpy
(
numpy
.
stack
(
forces
)).
type
(
self
.
dtype
).
to
(
self
.
device
)
self
.
EtaR
.
dtype
).
to
(
self
.
EtaR
.
device
)
return
self
.
_get_radial_part
(
aevs
),
\
self
.
_get_angular_part
(
aevs
),
\
energies
,
forces
aev
=
torchani
.
SortedAEV
(
device
=
torch
.
device
(
'cpu'
))
ncaev
=
NeuroChem
(
device
=
torch
.
device
(
'cpu'
))
ncaev
=
NeuroChem
().
to
(
torch
.
device
(
'cpu'
))
mol_count
=
0
for
i
in
[
1
,
2
,
3
,
4
]:
data_file
=
os
.
path
.
join
(
path
,
'../
tests/
dataset/ani_gdb_s0{}.h5'
.
format
(
i
))
path
,
'../dataset/ani_gdb_s0{}.h5'
.
format
(
i
))
adl
=
torchani
.
pyanitools
.
anidataloader
(
data_file
)
for
data
in
adl
:
coordinates
=
data
[
'coordinates'
][:
10
,
:]
coordinates
=
torch
.
from_numpy
(
coordinates
).
type
(
aev
.
dtype
)
coordinates
=
torch
.
from_numpy
(
coordinates
).
type
(
nc
aev
.
EtaR
.
dtype
)
species
=
data
[
'species'
]
coordinates
,
species
=
aev
.
sort_by_species
(
coordinates
,
species
)
smiles
=
''
.
join
(
data
[
'smiles'
])
radial
,
angular
,
energies
,
forces
=
ncaev
(
coordinates
,
species
)
pickleobj
=
(
coordinates
,
species
,
radial
,
angular
,
energies
,
forces
)
...
...
torchani/aev.py
View file @
5bb66915
...
...
@@ -151,7 +151,7 @@ class PrepareInput(torch.nn.Module):
new_tensors
=
[]
for
t
in
tensors
:
new_tensors
.
append
(
t
.
index_select
(
1
,
reverse
))
return
(
species
,
*
tensors
)
return
(
species
,
*
new_
tensors
)
def
forward
(
self
,
species_coordinates
):
species
,
coordinates
=
species_coordinates
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment