Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
d3ae0788
Unverified
Commit
d3ae0788
authored
Aug 02, 2018
by
Gao, Xiang
Committed by
GitHub
Aug 02, 2018
Browse files
remove explicit device and dtype (#44)
parent
8c493a6e
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
78 additions
and
103 deletions
+78
-103
examples/energy_force.py
examples/energy_force.py
+2
-4
examples/model.py
examples/model.py
+3
-3
examples/nnp_training.py
examples/nnp_training.py
+4
-2
examples/training-benchmark.py
examples/training-benchmark.py
+4
-2
tests/test_aev.py
tests/test_aev.py
+3
-4
tests/test_batch.py
tests/test_batch.py
+3
-6
tests/test_benchmark.py
tests/test_benchmark.py
+8
-14
tests/test_energies.py
tests/test_energies.py
+3
-6
tests/test_ensemble.py
tests/test_ensemble.py
+2
-2
tests/test_forces.py
tests/test_forces.py
+3
-6
tests/test_ignite.py
tests/test_ignite.py
+3
-7
torchani/__init__.py
torchani/__init__.py
+2
-3
torchani/aev.py
torchani/aev.py
+36
-38
torchani/data.py
torchani/data.py
+2
-2
torchani/env.py
torchani/env.py
+0
-4
No files found.
examples/energy_force.py
View file @
d3ae0788
...
...
@@ -8,8 +8,8 @@ const_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/rHCNO-5.
sae_file
=
os
.
path
.
join
(
path
,
'../torchani/resources/ani-1x_dft_x8ens/sae_linfit.dat'
)
# noqa: E501
network_dir
=
os
.
path
.
join
(
path
,
'../torchani/resources/ani-1x_dft_x8ens/train'
)
# noqa: E501
aev_computer
=
torchani
.
SortedAEV
(
const_file
=
const_file
,
device
=
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
aev_computer
.
device
)
aev_computer
=
torchani
.
SortedAEV
(
const_file
=
const_file
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
nn
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
,
from_
=
network_dir
,
ensemble
=
8
)
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nn
)
...
...
@@ -20,8 +20,6 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[
-
0.66518241
,
-
0.84461308
,
0.20759389
],
[
0.45554739
,
0.54289633
,
0.81170881
],
[
0.66091919
,
-
0.16799635
,
-
0.91037834
]]],
dtype
=
aev_computer
.
dtype
,
device
=
aev_computer
.
device
,
requires_grad
=
True
)
species
=
[
'C'
,
'H'
,
'H'
,
'H'
,
'H'
]
...
...
examples/model.py
View file @
d3ae0788
...
...
@@ -18,9 +18,9 @@ def atomic():
def
get_or_create_model
(
filename
,
benchmark
=
False
,
device
=
torch
ani
.
default_device
):
aev_computer
=
torchani
.
SortedAEV
(
benchmark
=
benchmark
,
device
=
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
aev_computer
.
device
)
device
=
torch
.
device
(
'cpu'
)
):
aev_computer
=
torchani
.
SortedAEV
(
benchmark
=
benchmark
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
model
=
torchani
.
models
.
CustomModel
(
reducer
=
torch
.
sum
,
benchmark
=
benchmark
,
...
...
examples/nnp_training.py
View file @
d3ae0788
...
...
@@ -8,6 +8,8 @@ import timeit
import
tensorboardX
import
math
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
chunk_size
=
256
batch_chunks
=
4
dataset_path
=
sys
.
argv
[
1
]
...
...
@@ -20,11 +22,11 @@ start = timeit.default_timer()
shift_energy
=
torchani
.
EnergyShifter
()
training
,
validation
,
testing
=
torchani
.
data
.
load_or_create
(
dataset_checkpoint
,
dataset_path
,
chunk_size
,
dataset_checkpoint
,
dataset_path
,
chunk_size
,
device
=
device
,
transform
=
[
shift_energy
.
dataset_subtract_sae
])
training
=
torchani
.
data
.
dataloader
(
training
,
batch_chunks
)
validation
=
torchani
.
data
.
dataloader
(
validation
,
batch_chunks
)
nnp
=
model
.
get_or_create_model
(
model_checkpoint
)
nnp
=
model
.
get_or_create_model
(
model_checkpoint
,
device
=
device
)
batch_nnp
=
torchani
.
models
.
BatchModel
(
nnp
)
container
=
torchani
.
ignite
.
Container
({
'energies'
:
batch_nnp
})
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
...
...
examples/training-benchmark.py
View file @
d3ae0788
...
...
@@ -6,15 +6,17 @@ import timeit
import
model
import
tqdm
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
chunk_size
=
256
batch_chunks
=
4
dataset_path
=
sys
.
argv
[
1
]
shift_energy
=
torchani
.
EnergyShifter
()
dataset
=
torchani
.
data
.
ANIDataset
(
dataset_path
,
chunk_size
,
dataset_path
,
chunk_size
,
device
=
device
,
transform
=
[
shift_energy
.
dataset_subtract_sae
])
dataloader
=
torchani
.
data
.
dataloader
(
dataset
,
batch_chunks
)
nnp
=
model
.
get_or_create_model
(
'/tmp/model.pt'
,
True
)
nnp
=
model
.
get_or_create_model
(
'/tmp/model.pt'
,
True
,
device
=
device
)
batch_nnp
=
torchani
.
models
.
BatchModel
(
nnp
)
container
=
torchani
.
ignite
.
Container
({
'energies'
:
batch_nnp
})
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
...
...
tests/test_aev.py
View file @
d3ae0788
...
...
@@ -10,12 +10,11 @@ N = 97
class
TestAEV
(
unittest
.
TestCase
):
def
setUp
(
self
,
dtype
=
torchani
.
default_dtype
):
aev_computer
=
torchani
.
SortedAEV
(
dtype
=
dtype
,
device
=
torch
.
device
(
'cpu'
))
def
setUp
(
self
):
aev_computer
=
torchani
.
SortedAEV
()
self
.
radial_length
=
aev_computer
.
radial_length
self
.
aev
=
torch
.
nn
.
Sequential
(
torchani
.
PrepareInput
(
aev_computer
.
species
,
aev_computer
.
device
),
torchani
.
PrepareInput
(
aev_computer
.
species
),
aev_computer
)
self
.
tolerance
=
1e-5
...
...
tests/test_batch.py
View file @
d3ae0788
...
...
@@ -12,17 +12,14 @@ if sys.version_info.major >= 3:
path
=
os
.
path
.
join
(
path
,
'../dataset'
)
chunksize
=
32
batch_chunks
=
32
dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestBatch
(
unittest
.
TestCase
):
def
testBatchLoadAndInference
(
self
):
ds
=
torchani
.
data
.
ANIDataset
(
path
,
chunksize
,
device
=
device
)
ds
=
torchani
.
data
.
ANIDataset
(
path
,
chunksize
)
loader
=
torchani
.
data
.
dataloader
(
ds
,
batch_chunks
)
aev_computer
=
torchani
.
SortedAEV
(
dtype
=
dtype
,
device
=
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
aev_computer
.
device
)
aev_computer
=
torchani
.
SortedAEV
()
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nnp
)
batch_nnp
=
torchani
.
models
.
BatchModel
(
model
)
...
...
tests/test_benchmark.py
View file @
d3ae0788
...
...
@@ -6,14 +6,10 @@ import copy
class
TestBenchmark
(
unittest
.
TestCase
):
def
setUp
(
self
,
dtype
=
torchani
.
default_dtype
,
device
=
torchani
.
default_device
):
self
.
dtype
=
dtype
self
.
device
=
device
def
setUp
(
self
):
self
.
conformations
=
100
self
.
species
=
list
(
'HHCCNNOO'
)
self
.
coordinates
=
torch
.
randn
(
self
.
conformations
,
8
,
3
,
dtype
=
dtype
,
device
=
device
)
self
.
coordinates
=
torch
.
randn
(
self
.
conformations
,
8
,
3
)
self
.
count
=
100
def
_testModule
(
self
,
run_module
,
result_module
,
asserts
):
...
...
@@ -82,9 +78,8 @@ class TestBenchmark(unittest.TestCase):
self
.
assertEqual
(
result_module
.
timers
[
i
],
0
)
def
testAEV
(
self
):
aev_computer
=
torchani
.
SortedAEV
(
benchmark
=
True
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
self
.
device
)
aev_computer
=
torchani
.
SortedAEV
(
benchmark
=
True
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
run_module
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
)
self
.
_testModule
(
run_module
,
aev_computer
,
[
'terms and indices>radial terms'
,
...
...
@@ -95,11 +90,10 @@ class TestBenchmark(unittest.TestCase):
])
def
testANIModel
(
self
):
aev_computer
=
torchani
.
SortedAEV
(
dtype
=
self
.
dtype
,
device
=
self
.
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
self
.
device
)
model
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
,
benchmark
=
True
).
to
(
self
.
device
)
aev_computer
=
torchani
.
SortedAEV
()
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
model
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
,
benchmark
=
True
)
run_module
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
model
)
self
.
_testModule
(
run_module
,
model
,
[
'forward'
])
...
...
tests/test_energies.py
View file @
d3ae0788
...
...
@@ -11,13 +11,10 @@ N = 97
class
TestEnergies
(
unittest
.
TestCase
):
def
setUp
(
self
,
dtype
=
torchani
.
default_dtype
,
device
=
torchani
.
default_device
):
def
setUp
(
self
):
self
.
tolerance
=
5e-5
aev_computer
=
torchani
.
SortedAEV
(
dtype
=
dtype
,
device
=
torch
.
device
(
'cpu'
))
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
aev_computer
.
device
)
aev_computer
=
torchani
.
SortedAEV
()
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
self
.
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nnp
)
...
...
tests/test_ensemble.py
View file @
d3ae0788
...
...
@@ -18,8 +18,8 @@ class TestEnsemble(unittest.TestCase):
coordinates
=
torch
.
tensor
(
coordinates
,
requires_grad
=
True
)
n
=
torchani
.
buildin_ensemble
prefix
=
torchani
.
buildin_model_prefix
aev
=
torchani
.
SortedAEV
(
device
=
torch
.
device
(
'cpu'
)
)
prepare
=
torchani
.
PrepareInput
(
aev
.
species
,
aev
.
device
)
aev
=
torchani
.
SortedAEV
()
prepare
=
torchani
.
PrepareInput
(
aev
.
species
)
ensemble
=
torchani
.
models
.
NeuroChemNNP
(
aev
.
species
,
ensemble
=
True
)
ensemble
=
torch
.
nn
.
Sequential
(
prepare
,
aev
,
ensemble
)
models
=
[
torchani
.
models
.
...
...
tests/test_forces.py
View file @
d3ae0788
...
...
@@ -10,13 +10,10 @@ N = 97
class
TestForce
(
unittest
.
TestCase
):
def
setUp
(
self
,
dtype
=
torchani
.
default_dtype
,
device
=
torchani
.
default_device
):
def
setUp
(
self
):
self
.
tolerance
=
1e-5
aev_computer
=
torchani
.
SortedAEV
(
dtype
=
dtype
,
device
=
torch
.
device
(
'cpu'
))
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
aev_computer
.
device
)
aev_computer
=
torchani
.
SortedAEV
()
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
self
.
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nnp
)
...
...
tests/test_ignite.py
View file @
d3ae0788
...
...
@@ -13,21 +13,17 @@ if sys.version_info.major >= 3:
path
=
os
.
path
.
join
(
path
,
'../dataset/ani_gdb_s01.h5'
)
chunksize
=
4
threshold
=
1e-5
dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestIgnite
(
unittest
.
TestCase
):
def
testIgnite
(
self
):
shift_energy
=
torchani
.
EnergyShifter
()
ds
=
torchani
.
data
.
ANIDataset
(
path
,
chunksize
,
device
=
device
,
transform
=
[
shift_energy
.
dataset_subtract_sae
])
path
,
chunksize
,
transform
=
[
shift_energy
.
dataset_subtract_sae
])
ds
=
torch
.
utils
.
data
.
Subset
(
ds
,
[
0
])
loader
=
torchani
.
data
.
dataloader
(
ds
,
1
)
aev_computer
=
torchani
.
SortedAEV
(
dtype
=
dtype
,
device
=
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
aev_computer
.
device
)
aev_computer
=
torchani
.
SortedAEV
()
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
class
Flatten
(
torch
.
nn
.
Module
):
...
...
torchani/__init__.py
View file @
d3ae0788
...
...
@@ -4,10 +4,9 @@ from . import data
from
.
import
ignite
from
.aev
import
SortedAEV
,
PrepareInput
from
.env
import
buildin_const_file
,
buildin_sae_file
,
buildin_network_dir
,
\
buildin_model_prefix
,
buildin_ensemble
,
default_dtype
,
default_device
buildin_model_prefix
,
buildin_ensemble
__all__
=
[
'PrepareInput'
,
'SortedAEV'
,
'EnergyShifter'
,
'models'
,
'data'
,
'ignite'
,
'buildin_const_file'
,
'buildin_sae_file'
,
'buildin_network_dir'
,
'buildin_model_prefix'
,
'buildin_ensemble'
,
'default_dtype'
,
'default_device'
]
'buildin_model_prefix'
,
'buildin_ensemble'
]
torchani/aev.py
View file @
d3ae0788
import
torch
import
itertools
import
math
from
.env
import
buildin_const_file
,
default_dtype
,
default_device
from
.env
import
buildin_const_file
from
.benchmarked
import
BenchmarkedModule
class
AEVComputer
(
BenchmarkedModule
):
__constants__
=
[
'Rcr'
,
'Rca'
,
'dtype'
,
'device'
,
'radial_sublength'
,
'radial_length'
,
'angular_sublength'
,
'angular_length'
,
'aev_length'
]
__constants__
=
[
'Rcr'
,
'Rca'
,
'radial_sublength'
,
'radial_length'
,
'angular_sublength'
,
'angular_length'
,
'aev_length'
]
"""Base class of various implementations of AEV computer
...
...
@@ -16,11 +15,6 @@ class AEVComputer(BenchmarkedModule):
----------
benchmark : boolean
Whether to enable benchmark
dtype : torch.dtype
Data type of pytorch tensors for all the computations. This is
also used to specify whether to use CPU or GPU.
device : torch.Device
The device where tensors should be.
const_file : str
The name of the original file that stores constant.
Rcr, Rca : float
...
...
@@ -39,15 +33,12 @@ class AEVComputer(BenchmarkedModule):
The length of full aev
"""
def
__init__
(
self
,
benchmark
=
False
,
dtype
=
default_dtype
,
device
=
default_device
,
const_file
=
buildin_const_file
):
def
__init__
(
self
,
benchmark
=
False
,
const_file
=
buildin_const_file
):
super
(
AEVComputer
,
self
).
__init__
(
benchmark
)
self
.
dtype
=
dtype
self
.
const_file
=
const_file
self
.
device
=
device
# load constants from const file
const
=
{}
with
open
(
const_file
)
as
f
:
for
i
in
f
:
try
:
...
...
@@ -60,8 +51,8 @@ class AEVComputer(BenchmarkedModule):
'ShfZ'
,
'EtaA'
,
'ShfA'
]:
value
=
[
float
(
x
.
strip
())
for
x
in
value
.
replace
(
'['
,
''
).
replace
(
']'
,
''
).
split
(
','
)]
value
=
torch
.
tensor
(
value
,
dtype
=
dtype
,
device
=
device
)
setattr
(
self
,
name
,
value
)
value
=
torch
.
tensor
(
value
)
const
[
name
]
=
value
elif
name
==
'Atyp'
:
value
=
[
x
.
strip
()
for
x
in
value
.
replace
(
'['
,
''
).
replace
(
']'
,
''
).
split
(
','
)]
...
...
@@ -70,10 +61,11 @@ class AEVComputer(BenchmarkedModule):
raise
ValueError
(
'unable to parse const file'
)
# Compute lengths
self
.
radial_sublength
=
self
.
EtaR
.
shape
[
0
]
*
self
.
ShfR
.
shape
[
0
]
self
.
radial_sublength
=
const
[
'
EtaR
'
]
.
shape
[
0
]
*
const
[
'
ShfR
'
]
.
shape
[
0
]
self
.
radial_length
=
len
(
self
.
species
)
*
self
.
radial_sublength
self
.
angular_sublength
=
self
.
EtaA
.
shape
[
0
]
*
\
self
.
Zeta
.
shape
[
0
]
*
self
.
ShfA
.
shape
[
0
]
*
self
.
ShfZ
.
shape
[
0
]
self
.
angular_sublength
=
const
[
'EtaA'
].
shape
[
0
]
*
\
const
[
'Zeta'
].
shape
[
0
]
*
const
[
'ShfA'
].
shape
[
0
]
*
\
const
[
'ShfZ'
].
shape
[
0
]
species
=
len
(
self
.
species
)
self
.
angular_length
=
int
(
(
species
*
(
species
+
1
))
/
2
)
*
self
.
angular_sublength
...
...
@@ -81,13 +73,17 @@ class AEVComputer(BenchmarkedModule):
# convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR)
self
.
EtaR
=
self
.
EtaR
.
view
(
-
1
,
1
)
self
.
ShfR
=
self
.
ShfR
.
view
(
1
,
-
1
)
const
[
'
EtaR
'
]
=
const
[
'
EtaR
'
]
.
view
(
-
1
,
1
)
const
[
'
ShfR
'
]
=
const
[
'
ShfR
'
]
.
view
(
1
,
-
1
)
# shape convension (..., EtaA, Zeta, ShfA, ShfZ)
self
.
EtaA
=
self
.
EtaA
.
view
(
-
1
,
1
,
1
,
1
)
self
.
Zeta
=
self
.
Zeta
.
view
(
1
,
-
1
,
1
,
1
)
self
.
ShfA
=
self
.
ShfA
.
view
(
1
,
1
,
-
1
,
1
)
self
.
ShfZ
=
self
.
ShfZ
.
view
(
1
,
1
,
1
,
-
1
)
const
[
'EtaA'
]
=
const
[
'EtaA'
].
view
(
-
1
,
1
,
1
,
1
)
const
[
'Zeta'
]
=
const
[
'Zeta'
].
view
(
1
,
-
1
,
1
,
1
)
const
[
'ShfA'
]
=
const
[
'ShfA'
].
view
(
1
,
1
,
-
1
,
1
)
const
[
'ShfZ'
]
=
const
[
'ShfZ'
].
view
(
1
,
1
,
1
,
-
1
)
# register buffers
for
i
in
const
:
self
.
register_buffer
(
i
,
const
[
i
])
def
forward
(
self
,
coordinates_species
):
"""Compute AEV from coordinates and species
...
...
@@ -112,18 +108,19 @@ class AEVComputer(BenchmarkedModule):
class
PrepareInput
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
species
,
device
):
def
__init__
(
self
,
species
):
super
(
PrepareInput
,
self
).
__init__
()
self
.
species
=
species
self
.
device
=
device
def
species_to_tensor
(
self
,
species
):
def
species_to_tensor
(
self
,
species
,
device
):
"""Convert species list into a long tensor.
Parameters
----------
species : list
List of string for the species of each atoms.
device : torch.device
The device to store tensor
Returns
-------
...
...
@@ -133,7 +130,7 @@ class PrepareInput(torch.nn.Module):
"""
indices
=
{
self
.
species
[
i
]:
i
for
i
in
range
(
len
(
self
.
species
))}
values
=
[
indices
[
i
]
for
i
in
species
]
return
torch
.
tensor
(
values
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
return
torch
.
tensor
(
values
,
dtype
=
torch
.
long
,
device
=
device
)
def
sort_by_species
(
self
,
species
,
*
tensors
):
"""Sort the data by its species according to the order in `self.species`
...
...
@@ -158,7 +155,7 @@ class PrepareInput(torch.nn.Module):
def
forward
(
self
,
species_coordinates
):
species
,
coordinates
=
species_coordinates
species
=
self
.
species_to_tensor
(
species
)
species
=
self
.
species_to_tensor
(
species
,
coordinates
.
device
)
return
self
.
sort_by_species
(
species
,
coordinates
)
...
...
@@ -203,9 +200,8 @@ class SortedAEV(AEVComputer):
total : total time for computing everything.
"""
def
__init__
(
self
,
benchmark
=
False
,
device
=
default_device
,
dtype
=
default_dtype
,
const_file
=
buildin_const_file
):
super
(
SortedAEV
,
self
).
__init__
(
benchmark
,
dtype
,
device
,
const_file
)
def
__init__
(
self
,
benchmark
=
False
,
const_file
=
buildin_const_file
):
super
(
SortedAEV
,
self
).
__init__
(
benchmark
,
const_file
)
if
benchmark
:
self
.
radial_subaev_terms
=
self
.
_enable_benchmark
(
self
.
radial_subaev_terms
,
'radial terms'
)
...
...
@@ -385,7 +381,7 @@ class SortedAEV(AEVComputer):
storing the mask for each species.
"""
mask_r
=
(
species_r
.
unsqueeze
(
-
1
)
==
torch
.
arange
(
len
(
self
.
species
),
device
=
self
.
device
))
torch
.
arange
(
len
(
self
.
species
),
device
=
self
.
EtaR
.
device
))
return
mask_r
def
compute_mask_a
(
self
,
species_a
,
present_species
):
...
...
@@ -451,8 +447,10 @@ class SortedAEV(AEVComputer):
atoms
=
radial_terms
.
shape
[
1
]
# assemble radial subaev
present_radial_aevs
=
(
radial_terms
.
unsqueeze
(
-
2
)
*
mask_r
.
unsqueeze
(
-
1
).
type
(
self
.
dtype
)).
sum
(
-
3
)
present_radial_aevs
=
(
radial_terms
.
unsqueeze
(
-
2
)
*
mask_r
.
unsqueeze
(
-
1
).
type
(
radial_terms
.
dtype
)
).
sum
(
-
3
)
"""shape (conformations, atoms, present species, radial_length)"""
radial_aevs
=
present_radial_aevs
.
flatten
(
start_dim
=
2
)
...
...
@@ -466,13 +464,13 @@ class SortedAEV(AEVComputer):
zero_angular_subaev
=
torch
.
zeros
(
# TODO: can we make stack and cat broadcast?
conformations
,
atoms
,
self
.
angular_sublength
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
dtype
=
self
.
EtaR
.
dtype
,
device
=
self
.
EtaR
.
device
)
for
s1
,
s2
in
itertools
.
combinations_with_replacement
(
range
(
len
(
self
.
species
)),
2
):
if
s1
in
rev_indices
and
s2
in
rev_indices
:
i1
=
rev_indices
[
s1
]
i2
=
rev_indices
[
s2
]
mask
=
mask_a
[...,
i1
,
i2
].
unsqueeze
(
-
1
).
type
(
self
.
dtype
)
mask
=
mask_a
[...,
i1
,
i2
].
unsqueeze
(
-
1
).
type
(
self
.
EtaR
.
dtype
)
subaev
=
(
angular_terms
*
mask
).
sum
(
-
2
)
else
:
subaev
=
zero_angular_subaev
...
...
torchani/data.py
View file @
d3ae0788
...
...
@@ -2,7 +2,6 @@ from torch.utils.data import Dataset, DataLoader
from
os.path
import
join
,
isfile
,
isdir
import
os
from
.pyanitools
import
anidataloader
from
.env
import
default_dtype
,
default_device
import
torch
import
torch.utils.data
as
data
import
pickle
...
...
@@ -11,7 +10,8 @@ import pickle
class
ANIDataset
(
Dataset
):
def
__init__
(
self
,
path
,
chunk_size
,
shuffle
=
True
,
properties
=
[
'energies'
],
transform
=
(),
dtype
=
default_dtype
,
device
=
default_device
):
transform
=
(),
dtype
=
torch
.
get_default_dtype
(),
device
=
torch
.
device
(
'cpu'
)):
super
(
ANIDataset
,
self
).
__init__
()
self
.
path
=
path
self
.
chunks_size
=
chunk_size
...
...
torchani/env.py
View file @
d3ae0788
import
pkg_resources
import
torch
buildin_const_file
=
pkg_resources
.
resource_filename
(
...
...
@@ -15,6 +14,3 @@ buildin_model_prefix = pkg_resources.resource_filename(
__name__
,
'resources/ani-1x_dft_x8ens/train'
)
buildin_ensemble
=
8
default_dtype
=
torch
.
float32
default_device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment