Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
d3ae0788
Unverified
Commit
d3ae0788
authored
Aug 02, 2018
by
Gao, Xiang
Committed by
GitHub
Aug 02, 2018
Browse files
remove explicit device and dtype (#44)
parent
8c493a6e
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
78 additions
and
103 deletions
+78
-103
examples/energy_force.py
examples/energy_force.py
+2
-4
examples/model.py
examples/model.py
+3
-3
examples/nnp_training.py
examples/nnp_training.py
+4
-2
examples/training-benchmark.py
examples/training-benchmark.py
+4
-2
tests/test_aev.py
tests/test_aev.py
+3
-4
tests/test_batch.py
tests/test_batch.py
+3
-6
tests/test_benchmark.py
tests/test_benchmark.py
+8
-14
tests/test_energies.py
tests/test_energies.py
+3
-6
tests/test_ensemble.py
tests/test_ensemble.py
+2
-2
tests/test_forces.py
tests/test_forces.py
+3
-6
tests/test_ignite.py
tests/test_ignite.py
+3
-7
torchani/__init__.py
torchani/__init__.py
+2
-3
torchani/aev.py
torchani/aev.py
+36
-38
torchani/data.py
torchani/data.py
+2
-2
torchani/env.py
torchani/env.py
+0
-4
No files found.
examples/energy_force.py
View file @
d3ae0788
...
@@ -8,8 +8,8 @@ const_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/rHCNO-5.
...
@@ -8,8 +8,8 @@ const_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/rHCNO-5.
sae_file
=
os
.
path
.
join
(
path
,
'../torchani/resources/ani-1x_dft_x8ens/sae_linfit.dat'
)
# noqa: E501
sae_file
=
os
.
path
.
join
(
path
,
'../torchani/resources/ani-1x_dft_x8ens/sae_linfit.dat'
)
# noqa: E501
network_dir
=
os
.
path
.
join
(
path
,
'../torchani/resources/ani-1x_dft_x8ens/train'
)
# noqa: E501
network_dir
=
os
.
path
.
join
(
path
,
'../torchani/resources/ani-1x_dft_x8ens/train'
)
# noqa: E501
aev_computer
=
torchani
.
SortedAEV
(
const_file
=
const_file
,
device
=
device
)
aev_computer
=
torchani
.
SortedAEV
(
const_file
=
const_file
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
aev_computer
.
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
nn
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
,
from_
=
network_dir
,
nn
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
,
from_
=
network_dir
,
ensemble
=
8
)
ensemble
=
8
)
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nn
)
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nn
)
...
@@ -20,8 +20,6 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
...
@@ -20,8 +20,6 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[
-
0.66518241
,
-
0.84461308
,
0.20759389
],
[
-
0.66518241
,
-
0.84461308
,
0.20759389
],
[
0.45554739
,
0.54289633
,
0.81170881
],
[
0.45554739
,
0.54289633
,
0.81170881
],
[
0.66091919
,
-
0.16799635
,
-
0.91037834
]]],
[
0.66091919
,
-
0.16799635
,
-
0.91037834
]]],
dtype
=
aev_computer
.
dtype
,
device
=
aev_computer
.
device
,
requires_grad
=
True
)
requires_grad
=
True
)
species
=
[
'C'
,
'H'
,
'H'
,
'H'
,
'H'
]
species
=
[
'C'
,
'H'
,
'H'
,
'H'
,
'H'
]
...
...
examples/model.py
View file @
d3ae0788
...
@@ -18,9 +18,9 @@ def atomic():
...
@@ -18,9 +18,9 @@ def atomic():
def
get_or_create_model
(
filename
,
benchmark
=
False
,
def
get_or_create_model
(
filename
,
benchmark
=
False
,
device
=
torch
ani
.
default_device
):
device
=
torch
.
device
(
'cpu'
)
):
aev_computer
=
torchani
.
SortedAEV
(
benchmark
=
benchmark
,
device
=
device
)
aev_computer
=
torchani
.
SortedAEV
(
benchmark
=
benchmark
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
aev_computer
.
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
model
=
torchani
.
models
.
CustomModel
(
model
=
torchani
.
models
.
CustomModel
(
reducer
=
torch
.
sum
,
reducer
=
torch
.
sum
,
benchmark
=
benchmark
,
benchmark
=
benchmark
,
...
...
examples/nnp_training.py
View file @
d3ae0788
...
@@ -8,6 +8,8 @@ import timeit
...
@@ -8,6 +8,8 @@ import timeit
import
tensorboardX
import
tensorboardX
import
math
import
math
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
chunk_size
=
256
chunk_size
=
256
batch_chunks
=
4
batch_chunks
=
4
dataset_path
=
sys
.
argv
[
1
]
dataset_path
=
sys
.
argv
[
1
]
...
@@ -20,11 +22,11 @@ start = timeit.default_timer()
...
@@ -20,11 +22,11 @@ start = timeit.default_timer()
shift_energy
=
torchani
.
EnergyShifter
()
shift_energy
=
torchani
.
EnergyShifter
()
training
,
validation
,
testing
=
torchani
.
data
.
load_or_create
(
training
,
validation
,
testing
=
torchani
.
data
.
load_or_create
(
dataset_checkpoint
,
dataset_path
,
chunk_size
,
dataset_checkpoint
,
dataset_path
,
chunk_size
,
device
=
device
,
transform
=
[
shift_energy
.
dataset_subtract_sae
])
transform
=
[
shift_energy
.
dataset_subtract_sae
])
training
=
torchani
.
data
.
dataloader
(
training
,
batch_chunks
)
training
=
torchani
.
data
.
dataloader
(
training
,
batch_chunks
)
validation
=
torchani
.
data
.
dataloader
(
validation
,
batch_chunks
)
validation
=
torchani
.
data
.
dataloader
(
validation
,
batch_chunks
)
nnp
=
model
.
get_or_create_model
(
model_checkpoint
)
nnp
=
model
.
get_or_create_model
(
model_checkpoint
,
device
=
device
)
batch_nnp
=
torchani
.
models
.
BatchModel
(
nnp
)
batch_nnp
=
torchani
.
models
.
BatchModel
(
nnp
)
container
=
torchani
.
ignite
.
Container
({
'energies'
:
batch_nnp
})
container
=
torchani
.
ignite
.
Container
({
'energies'
:
batch_nnp
})
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
...
...
examples/training-benchmark.py
View file @
d3ae0788
...
@@ -6,15 +6,17 @@ import timeit
...
@@ -6,15 +6,17 @@ import timeit
import
model
import
model
import
tqdm
import
tqdm
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
chunk_size
=
256
chunk_size
=
256
batch_chunks
=
4
batch_chunks
=
4
dataset_path
=
sys
.
argv
[
1
]
dataset_path
=
sys
.
argv
[
1
]
shift_energy
=
torchani
.
EnergyShifter
()
shift_energy
=
torchani
.
EnergyShifter
()
dataset
=
torchani
.
data
.
ANIDataset
(
dataset
=
torchani
.
data
.
ANIDataset
(
dataset_path
,
chunk_size
,
dataset_path
,
chunk_size
,
device
=
device
,
transform
=
[
shift_energy
.
dataset_subtract_sae
])
transform
=
[
shift_energy
.
dataset_subtract_sae
])
dataloader
=
torchani
.
data
.
dataloader
(
dataset
,
batch_chunks
)
dataloader
=
torchani
.
data
.
dataloader
(
dataset
,
batch_chunks
)
nnp
=
model
.
get_or_create_model
(
'/tmp/model.pt'
,
True
)
nnp
=
model
.
get_or_create_model
(
'/tmp/model.pt'
,
True
,
device
=
device
)
batch_nnp
=
torchani
.
models
.
BatchModel
(
nnp
)
batch_nnp
=
torchani
.
models
.
BatchModel
(
nnp
)
container
=
torchani
.
ignite
.
Container
({
'energies'
:
batch_nnp
})
container
=
torchani
.
ignite
.
Container
({
'energies'
:
batch_nnp
})
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
...
...
tests/test_aev.py
View file @
d3ae0788
...
@@ -10,12 +10,11 @@ N = 97
...
@@ -10,12 +10,11 @@ N = 97
class
TestAEV
(
unittest
.
TestCase
):
class
TestAEV
(
unittest
.
TestCase
):
def
setUp
(
self
,
dtype
=
torchani
.
default_dtype
):
def
setUp
(
self
):
aev_computer
=
torchani
.
SortedAEV
(
dtype
=
dtype
,
aev_computer
=
torchani
.
SortedAEV
()
device
=
torch
.
device
(
'cpu'
))
self
.
radial_length
=
aev_computer
.
radial_length
self
.
radial_length
=
aev_computer
.
radial_length
self
.
aev
=
torch
.
nn
.
Sequential
(
self
.
aev
=
torch
.
nn
.
Sequential
(
torchani
.
PrepareInput
(
aev_computer
.
species
,
aev_computer
.
device
),
torchani
.
PrepareInput
(
aev_computer
.
species
),
aev_computer
aev_computer
)
)
self
.
tolerance
=
1e-5
self
.
tolerance
=
1e-5
...
...
tests/test_batch.py
View file @
d3ae0788
...
@@ -12,17 +12,14 @@ if sys.version_info.major >= 3:
...
@@ -12,17 +12,14 @@ if sys.version_info.major >= 3:
path
=
os
.
path
.
join
(
path
,
'../dataset'
)
path
=
os
.
path
.
join
(
path
,
'../dataset'
)
chunksize
=
32
chunksize
=
32
batch_chunks
=
32
batch_chunks
=
32
dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestBatch
(
unittest
.
TestCase
):
class
TestBatch
(
unittest
.
TestCase
):
def
testBatchLoadAndInference
(
self
):
def
testBatchLoadAndInference
(
self
):
ds
=
torchani
.
data
.
ANIDataset
(
path
,
chunksize
,
device
=
device
)
ds
=
torchani
.
data
.
ANIDataset
(
path
,
chunksize
)
loader
=
torchani
.
data
.
dataloader
(
ds
,
batch_chunks
)
loader
=
torchani
.
data
.
dataloader
(
ds
,
batch_chunks
)
aev_computer
=
torchani
.
SortedAEV
(
dtype
=
dtype
,
device
=
device
)
aev_computer
=
torchani
.
SortedAEV
()
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
aev_computer
.
device
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nnp
)
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nnp
)
batch_nnp
=
torchani
.
models
.
BatchModel
(
model
)
batch_nnp
=
torchani
.
models
.
BatchModel
(
model
)
...
...
tests/test_benchmark.py
View file @
d3ae0788
...
@@ -6,14 +6,10 @@ import copy
...
@@ -6,14 +6,10 @@ import copy
class
TestBenchmark
(
unittest
.
TestCase
):
class
TestBenchmark
(
unittest
.
TestCase
):
def
setUp
(
self
,
dtype
=
torchani
.
default_dtype
,
def
setUp
(
self
):
device
=
torchani
.
default_device
):
self
.
dtype
=
dtype
self
.
device
=
device
self
.
conformations
=
100
self
.
conformations
=
100
self
.
species
=
list
(
'HHCCNNOO'
)
self
.
species
=
list
(
'HHCCNNOO'
)
self
.
coordinates
=
torch
.
randn
(
self
.
coordinates
=
torch
.
randn
(
self
.
conformations
,
8
,
3
)
self
.
conformations
,
8
,
3
,
dtype
=
dtype
,
device
=
device
)
self
.
count
=
100
self
.
count
=
100
def
_testModule
(
self
,
run_module
,
result_module
,
asserts
):
def
_testModule
(
self
,
run_module
,
result_module
,
asserts
):
...
@@ -82,9 +78,8 @@ class TestBenchmark(unittest.TestCase):
...
@@ -82,9 +78,8 @@ class TestBenchmark(unittest.TestCase):
self
.
assertEqual
(
result_module
.
timers
[
i
],
0
)
self
.
assertEqual
(
result_module
.
timers
[
i
],
0
)
def
testAEV
(
self
):
def
testAEV
(
self
):
aev_computer
=
torchani
.
SortedAEV
(
aev_computer
=
torchani
.
SortedAEV
(
benchmark
=
True
)
benchmark
=
True
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
self
.
device
)
run_module
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
)
run_module
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
)
self
.
_testModule
(
run_module
,
aev_computer
,
[
self
.
_testModule
(
run_module
,
aev_computer
,
[
'terms and indices>radial terms'
,
'terms and indices>radial terms'
,
...
@@ -95,11 +90,10 @@ class TestBenchmark(unittest.TestCase):
...
@@ -95,11 +90,10 @@ class TestBenchmark(unittest.TestCase):
])
])
def
testANIModel
(
self
):
def
testANIModel
(
self
):
aev_computer
=
torchani
.
SortedAEV
(
aev_computer
=
torchani
.
SortedAEV
()
dtype
=
self
.
dtype
,
device
=
self
.
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
self
.
device
)
model
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
,
model
=
torchani
.
models
.
NeuroChemNNP
(
benchmark
=
True
)
aev_computer
.
species
,
benchmark
=
True
).
to
(
self
.
device
)
run_module
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
model
)
run_module
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
model
)
self
.
_testModule
(
run_module
,
model
,
[
'forward'
])
self
.
_testModule
(
run_module
,
model
,
[
'forward'
])
...
...
tests/test_energies.py
View file @
d3ae0788
...
@@ -11,13 +11,10 @@ N = 97
...
@@ -11,13 +11,10 @@ N = 97
class
TestEnergies
(
unittest
.
TestCase
):
class
TestEnergies
(
unittest
.
TestCase
):
def
setUp
(
self
,
dtype
=
torchani
.
default_dtype
,
def
setUp
(
self
):
device
=
torchani
.
default_device
):
self
.
tolerance
=
5e-5
self
.
tolerance
=
5e-5
aev_computer
=
torchani
.
SortedAEV
(
aev_computer
=
torchani
.
SortedAEV
()
dtype
=
dtype
,
device
=
torch
.
device
(
'cpu'
))
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
aev_computer
.
device
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
self
.
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nnp
)
self
.
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nnp
)
...
...
tests/test_ensemble.py
View file @
d3ae0788
...
@@ -18,8 +18,8 @@ class TestEnsemble(unittest.TestCase):
...
@@ -18,8 +18,8 @@ class TestEnsemble(unittest.TestCase):
coordinates
=
torch
.
tensor
(
coordinates
,
requires_grad
=
True
)
coordinates
=
torch
.
tensor
(
coordinates
,
requires_grad
=
True
)
n
=
torchani
.
buildin_ensemble
n
=
torchani
.
buildin_ensemble
prefix
=
torchani
.
buildin_model_prefix
prefix
=
torchani
.
buildin_model_prefix
aev
=
torchani
.
SortedAEV
(
device
=
torch
.
device
(
'cpu'
)
)
aev
=
torchani
.
SortedAEV
()
prepare
=
torchani
.
PrepareInput
(
aev
.
species
,
aev
.
device
)
prepare
=
torchani
.
PrepareInput
(
aev
.
species
)
ensemble
=
torchani
.
models
.
NeuroChemNNP
(
aev
.
species
,
ensemble
=
True
)
ensemble
=
torchani
.
models
.
NeuroChemNNP
(
aev
.
species
,
ensemble
=
True
)
ensemble
=
torch
.
nn
.
Sequential
(
prepare
,
aev
,
ensemble
)
ensemble
=
torch
.
nn
.
Sequential
(
prepare
,
aev
,
ensemble
)
models
=
[
torchani
.
models
.
models
=
[
torchani
.
models
.
...
...
tests/test_forces.py
View file @
d3ae0788
...
@@ -10,13 +10,10 @@ N = 97
...
@@ -10,13 +10,10 @@ N = 97
class
TestForce
(
unittest
.
TestCase
):
class
TestForce
(
unittest
.
TestCase
):
def
setUp
(
self
,
dtype
=
torchani
.
default_dtype
,
def
setUp
(
self
):
device
=
torchani
.
default_device
):
self
.
tolerance
=
1e-5
self
.
tolerance
=
1e-5
aev_computer
=
torchani
.
SortedAEV
(
aev_computer
=
torchani
.
SortedAEV
()
dtype
=
dtype
,
device
=
torch
.
device
(
'cpu'
))
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
aev_computer
.
device
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
self
.
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nnp
)
self
.
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nnp
)
...
...
tests/test_ignite.py
View file @
d3ae0788
...
@@ -13,21 +13,17 @@ if sys.version_info.major >= 3:
...
@@ -13,21 +13,17 @@ if sys.version_info.major >= 3:
path
=
os
.
path
.
join
(
path
,
'../dataset/ani_gdb_s01.h5'
)
path
=
os
.
path
.
join
(
path
,
'../dataset/ani_gdb_s01.h5'
)
chunksize
=
4
chunksize
=
4
threshold
=
1e-5
threshold
=
1e-5
dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestIgnite
(
unittest
.
TestCase
):
class
TestIgnite
(
unittest
.
TestCase
):
def
testIgnite
(
self
):
def
testIgnite
(
self
):
shift_energy
=
torchani
.
EnergyShifter
()
shift_energy
=
torchani
.
EnergyShifter
()
ds
=
torchani
.
data
.
ANIDataset
(
ds
=
torchani
.
data
.
ANIDataset
(
path
,
chunksize
,
device
=
device
,
path
,
chunksize
,
transform
=
[
shift_energy
.
dataset_subtract_sae
])
transform
=
[
shift_energy
.
dataset_subtract_sae
])
ds
=
torch
.
utils
.
data
.
Subset
(
ds
,
[
0
])
ds
=
torch
.
utils
.
data
.
Subset
(
ds
,
[
0
])
loader
=
torchani
.
data
.
dataloader
(
ds
,
1
)
loader
=
torchani
.
data
.
dataloader
(
ds
,
1
)
aev_computer
=
torchani
.
SortedAEV
(
dtype
=
dtype
,
device
=
device
)
aev_computer
=
torchani
.
SortedAEV
()
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
aev_computer
.
device
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
class
Flatten
(
torch
.
nn
.
Module
):
class
Flatten
(
torch
.
nn
.
Module
):
...
...
torchani/__init__.py
View file @
d3ae0788
...
@@ -4,10 +4,9 @@ from . import data
...
@@ -4,10 +4,9 @@ from . import data
from
.
import
ignite
from
.
import
ignite
from
.aev
import
SortedAEV
,
PrepareInput
from
.aev
import
SortedAEV
,
PrepareInput
from
.env
import
buildin_const_file
,
buildin_sae_file
,
buildin_network_dir
,
\
from
.env
import
buildin_const_file
,
buildin_sae_file
,
buildin_network_dir
,
\
buildin_model_prefix
,
buildin_ensemble
,
default_dtype
,
default_device
buildin_model_prefix
,
buildin_ensemble
__all__
=
[
'PrepareInput'
,
'SortedAEV'
,
'EnergyShifter'
,
__all__
=
[
'PrepareInput'
,
'SortedAEV'
,
'EnergyShifter'
,
'models'
,
'data'
,
'ignite'
,
'models'
,
'data'
,
'ignite'
,
'buildin_const_file'
,
'buildin_sae_file'
,
'buildin_network_dir'
,
'buildin_const_file'
,
'buildin_sae_file'
,
'buildin_network_dir'
,
'buildin_model_prefix'
,
'buildin_ensemble'
,
'buildin_model_prefix'
,
'buildin_ensemble'
]
'default_dtype'
,
'default_device'
]
torchani/aev.py
View file @
d3ae0788
import
torch
import
torch
import
itertools
import
itertools
import
math
import
math
from
.env
import
buildin_const_file
,
default_dtype
,
default_device
from
.env
import
buildin_const_file
from
.benchmarked
import
BenchmarkedModule
from
.benchmarked
import
BenchmarkedModule
class
AEVComputer
(
BenchmarkedModule
):
class
AEVComputer
(
BenchmarkedModule
):
__constants__
=
[
'Rcr'
,
'Rca'
,
'dtype'
,
'device'
,
'radial_sublength'
,
__constants__
=
[
'Rcr'
,
'Rca'
,
'radial_sublength'
,
'radial_length'
,
'radial_length'
,
'angular_sublength'
,
'angular_length'
,
'angular_sublength'
,
'angular_length'
,
'aev_length'
]
'aev_length'
]
"""Base class of various implementations of AEV computer
"""Base class of various implementations of AEV computer
...
@@ -16,11 +15,6 @@ class AEVComputer(BenchmarkedModule):
...
@@ -16,11 +15,6 @@ class AEVComputer(BenchmarkedModule):
----------
----------
benchmark : boolean
benchmark : boolean
Whether to enable benchmark
Whether to enable benchmark
dtype : torch.dtype
Data type of pytorch tensors for all the computations. This is
also used to specify whether to use CPU or GPU.
device : torch.Device
The device where tensors should be.
const_file : str
const_file : str
The name of the original file that stores constant.
The name of the original file that stores constant.
Rcr, Rca : float
Rcr, Rca : float
...
@@ -39,15 +33,12 @@ class AEVComputer(BenchmarkedModule):
...
@@ -39,15 +33,12 @@ class AEVComputer(BenchmarkedModule):
The length of full aev
The length of full aev
"""
"""
def
__init__
(
self
,
benchmark
=
False
,
dtype
=
default_dtype
,
def
__init__
(
self
,
benchmark
=
False
,
const_file
=
buildin_const_file
):
device
=
default_device
,
const_file
=
buildin_const_file
):
super
(
AEVComputer
,
self
).
__init__
(
benchmark
)
super
(
AEVComputer
,
self
).
__init__
(
benchmark
)
self
.
dtype
=
dtype
self
.
const_file
=
const_file
self
.
const_file
=
const_file
self
.
device
=
device
# load constants from const file
# load constants from const file
const
=
{}
with
open
(
const_file
)
as
f
:
with
open
(
const_file
)
as
f
:
for
i
in
f
:
for
i
in
f
:
try
:
try
:
...
@@ -60,8 +51,8 @@ class AEVComputer(BenchmarkedModule):
...
@@ -60,8 +51,8 @@ class AEVComputer(BenchmarkedModule):
'ShfZ'
,
'EtaA'
,
'ShfA'
]:
'ShfZ'
,
'EtaA'
,
'ShfA'
]:
value
=
[
float
(
x
.
strip
())
for
x
in
value
.
replace
(
value
=
[
float
(
x
.
strip
())
for
x
in
value
.
replace
(
'['
,
''
).
replace
(
']'
,
''
).
split
(
','
)]
'['
,
''
).
replace
(
']'
,
''
).
split
(
','
)]
value
=
torch
.
tensor
(
value
,
dtype
=
dtype
,
device
=
device
)
value
=
torch
.
tensor
(
value
)
setattr
(
self
,
name
,
value
)
const
[
name
]
=
value
elif
name
==
'Atyp'
:
elif
name
==
'Atyp'
:
value
=
[
x
.
strip
()
for
x
in
value
.
replace
(
value
=
[
x
.
strip
()
for
x
in
value
.
replace
(
'['
,
''
).
replace
(
']'
,
''
).
split
(
','
)]
'['
,
''
).
replace
(
']'
,
''
).
split
(
','
)]
...
@@ -70,10 +61,11 @@ class AEVComputer(BenchmarkedModule):
...
@@ -70,10 +61,11 @@ class AEVComputer(BenchmarkedModule):
raise
ValueError
(
'unable to parse const file'
)
raise
ValueError
(
'unable to parse const file'
)
# Compute lengths
# Compute lengths
self
.
radial_sublength
=
self
.
EtaR
.
shape
[
0
]
*
self
.
ShfR
.
shape
[
0
]
self
.
radial_sublength
=
const
[
'
EtaR
'
]
.
shape
[
0
]
*
const
[
'
ShfR
'
]
.
shape
[
0
]
self
.
radial_length
=
len
(
self
.
species
)
*
self
.
radial_sublength
self
.
radial_length
=
len
(
self
.
species
)
*
self
.
radial_sublength
self
.
angular_sublength
=
self
.
EtaA
.
shape
[
0
]
*
\
self
.
angular_sublength
=
const
[
'EtaA'
].
shape
[
0
]
*
\
self
.
Zeta
.
shape
[
0
]
*
self
.
ShfA
.
shape
[
0
]
*
self
.
ShfZ
.
shape
[
0
]
const
[
'Zeta'
].
shape
[
0
]
*
const
[
'ShfA'
].
shape
[
0
]
*
\
const
[
'ShfZ'
].
shape
[
0
]
species
=
len
(
self
.
species
)
species
=
len
(
self
.
species
)
self
.
angular_length
=
int
(
self
.
angular_length
=
int
(
(
species
*
(
species
+
1
))
/
2
)
*
self
.
angular_sublength
(
species
*
(
species
+
1
))
/
2
)
*
self
.
angular_sublength
...
@@ -81,13 +73,17 @@ class AEVComputer(BenchmarkedModule):
...
@@ -81,13 +73,17 @@ class AEVComputer(BenchmarkedModule):
# convert constant tensors to a ready-to-broadcast shape
# convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR)
# shape convension (..., EtaR, ShfR)
self
.
EtaR
=
self
.
EtaR
.
view
(
-
1
,
1
)
const
[
'
EtaR
'
]
=
const
[
'
EtaR
'
]
.
view
(
-
1
,
1
)
self
.
ShfR
=
self
.
ShfR
.
view
(
1
,
-
1
)
const
[
'
ShfR
'
]
=
const
[
'
ShfR
'
]
.
view
(
1
,
-
1
)
# shape convension (..., EtaA, Zeta, ShfA, ShfZ)
# shape convension (..., EtaA, Zeta, ShfA, ShfZ)
self
.
EtaA
=
self
.
EtaA
.
view
(
-
1
,
1
,
1
,
1
)
const
[
'EtaA'
]
=
const
[
'EtaA'
].
view
(
-
1
,
1
,
1
,
1
)
self
.
Zeta
=
self
.
Zeta
.
view
(
1
,
-
1
,
1
,
1
)
const
[
'Zeta'
]
=
const
[
'Zeta'
].
view
(
1
,
-
1
,
1
,
1
)
self
.
ShfA
=
self
.
ShfA
.
view
(
1
,
1
,
-
1
,
1
)
const
[
'ShfA'
]
=
const
[
'ShfA'
].
view
(
1
,
1
,
-
1
,
1
)
self
.
ShfZ
=
self
.
ShfZ
.
view
(
1
,
1
,
1
,
-
1
)
const
[
'ShfZ'
]
=
const
[
'ShfZ'
].
view
(
1
,
1
,
1
,
-
1
)
# register buffers
for
i
in
const
:
self
.
register_buffer
(
i
,
const
[
i
])
def
forward
(
self
,
coordinates_species
):
def
forward
(
self
,
coordinates_species
):
"""Compute AEV from coordinates and species
"""Compute AEV from coordinates and species
...
@@ -112,18 +108,19 @@ class AEVComputer(BenchmarkedModule):
...
@@ -112,18 +108,19 @@ class AEVComputer(BenchmarkedModule):
class
PrepareInput
(
torch
.
nn
.
Module
):
class
PrepareInput
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
species
,
device
):
def
__init__
(
self
,
species
):
super
(
PrepareInput
,
self
).
__init__
()
super
(
PrepareInput
,
self
).
__init__
()
self
.
species
=
species
self
.
species
=
species
self
.
device
=
device
def
species_to_tensor
(
self
,
species
):
def
species_to_tensor
(
self
,
species
,
device
):
"""Convert species list into a long tensor.
"""Convert species list into a long tensor.
Parameters
Parameters
----------
----------
species : list
species : list
List of string for the species of each atoms.
List of string for the species of each atoms.
device : torch.device
The device to store tensor
Returns
Returns
-------
-------
...
@@ -133,7 +130,7 @@ class PrepareInput(torch.nn.Module):
...
@@ -133,7 +130,7 @@ class PrepareInput(torch.nn.Module):
"""
"""
indices
=
{
self
.
species
[
i
]:
i
for
i
in
range
(
len
(
self
.
species
))}
indices
=
{
self
.
species
[
i
]:
i
for
i
in
range
(
len
(
self
.
species
))}
values
=
[
indices
[
i
]
for
i
in
species
]
values
=
[
indices
[
i
]
for
i
in
species
]
return
torch
.
tensor
(
values
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
return
torch
.
tensor
(
values
,
dtype
=
torch
.
long
,
device
=
device
)
def
sort_by_species
(
self
,
species
,
*
tensors
):
def
sort_by_species
(
self
,
species
,
*
tensors
):
"""Sort the data by its species according to the order in `self.species`
"""Sort the data by its species according to the order in `self.species`
...
@@ -158,7 +155,7 @@ class PrepareInput(torch.nn.Module):
...
@@ -158,7 +155,7 @@ class PrepareInput(torch.nn.Module):
def
forward
(
self
,
species_coordinates
):
def
forward
(
self
,
species_coordinates
):
species
,
coordinates
=
species_coordinates
species
,
coordinates
=
species_coordinates
species
=
self
.
species_to_tensor
(
species
)
species
=
self
.
species_to_tensor
(
species
,
coordinates
.
device
)
return
self
.
sort_by_species
(
species
,
coordinates
)
return
self
.
sort_by_species
(
species
,
coordinates
)
...
@@ -203,9 +200,8 @@ class SortedAEV(AEVComputer):
...
@@ -203,9 +200,8 @@ class SortedAEV(AEVComputer):
total : total time for computing everything.
total : total time for computing everything.
"""
"""
def
__init__
(
self
,
benchmark
=
False
,
device
=
default_device
,
def
__init__
(
self
,
benchmark
=
False
,
const_file
=
buildin_const_file
):
dtype
=
default_dtype
,
const_file
=
buildin_const_file
):
super
(
SortedAEV
,
self
).
__init__
(
benchmark
,
const_file
)
super
(
SortedAEV
,
self
).
__init__
(
benchmark
,
dtype
,
device
,
const_file
)
if
benchmark
:
if
benchmark
:
self
.
radial_subaev_terms
=
self
.
_enable_benchmark
(
self
.
radial_subaev_terms
=
self
.
_enable_benchmark
(
self
.
radial_subaev_terms
,
'radial terms'
)
self
.
radial_subaev_terms
,
'radial terms'
)
...
@@ -385,7 +381,7 @@ class SortedAEV(AEVComputer):
...
@@ -385,7 +381,7 @@ class SortedAEV(AEVComputer):
storing the mask for each species.
storing the mask for each species.
"""
"""
mask_r
=
(
species_r
.
unsqueeze
(
-
1
)
==
mask_r
=
(
species_r
.
unsqueeze
(
-
1
)
==
torch
.
arange
(
len
(
self
.
species
),
device
=
self
.
device
))
torch
.
arange
(
len
(
self
.
species
),
device
=
self
.
EtaR
.
device
))
return
mask_r
return
mask_r
def
compute_mask_a
(
self
,
species_a
,
present_species
):
def
compute_mask_a
(
self
,
species_a
,
present_species
):
...
@@ -451,8 +447,10 @@ class SortedAEV(AEVComputer):
...
@@ -451,8 +447,10 @@ class SortedAEV(AEVComputer):
atoms
=
radial_terms
.
shape
[
1
]
atoms
=
radial_terms
.
shape
[
1
]
# assemble radial subaev
# assemble radial subaev
present_radial_aevs
=
(
radial_terms
.
unsqueeze
(
-
2
)
present_radial_aevs
=
(
*
mask_r
.
unsqueeze
(
-
1
).
type
(
self
.
dtype
)).
sum
(
-
3
)
radial_terms
.
unsqueeze
(
-
2
)
*
mask_r
.
unsqueeze
(
-
1
).
type
(
radial_terms
.
dtype
)
).
sum
(
-
3
)
"""shape (conformations, atoms, present species, radial_length)"""
"""shape (conformations, atoms, present species, radial_length)"""
radial_aevs
=
present_radial_aevs
.
flatten
(
start_dim
=
2
)
radial_aevs
=
present_radial_aevs
.
flatten
(
start_dim
=
2
)
...
@@ -466,13 +464,13 @@ class SortedAEV(AEVComputer):
...
@@ -466,13 +464,13 @@ class SortedAEV(AEVComputer):
zero_angular_subaev
=
torch
.
zeros
(
zero_angular_subaev
=
torch
.
zeros
(
# TODO: can we make stack and cat broadcast?
# TODO: can we make stack and cat broadcast?
conformations
,
atoms
,
self
.
angular_sublength
,
conformations
,
atoms
,
self
.
angular_sublength
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
dtype
=
self
.
EtaR
.
dtype
,
device
=
self
.
EtaR
.
device
)
for
s1
,
s2
in
itertools
.
combinations_with_replacement
(
for
s1
,
s2
in
itertools
.
combinations_with_replacement
(
range
(
len
(
self
.
species
)),
2
):
range
(
len
(
self
.
species
)),
2
):
if
s1
in
rev_indices
and
s2
in
rev_indices
:
if
s1
in
rev_indices
and
s2
in
rev_indices
:
i1
=
rev_indices
[
s1
]
i1
=
rev_indices
[
s1
]
i2
=
rev_indices
[
s2
]
i2
=
rev_indices
[
s2
]
mask
=
mask_a
[...,
i1
,
i2
].
unsqueeze
(
-
1
).
type
(
self
.
dtype
)
mask
=
mask_a
[...,
i1
,
i2
].
unsqueeze
(
-
1
).
type
(
self
.
EtaR
.
dtype
)
subaev
=
(
angular_terms
*
mask
).
sum
(
-
2
)
subaev
=
(
angular_terms
*
mask
).
sum
(
-
2
)
else
:
else
:
subaev
=
zero_angular_subaev
subaev
=
zero_angular_subaev
...
...
torchani/data.py
View file @
d3ae0788
...
@@ -2,7 +2,6 @@ from torch.utils.data import Dataset, DataLoader
...
@@ -2,7 +2,6 @@ from torch.utils.data import Dataset, DataLoader
from
os.path
import
join
,
isfile
,
isdir
from
os.path
import
join
,
isfile
,
isdir
import
os
import
os
from
.pyanitools
import
anidataloader
from
.pyanitools
import
anidataloader
from
.env
import
default_dtype
,
default_device
import
torch
import
torch
import
torch.utils.data
as
data
import
torch.utils.data
as
data
import
pickle
import
pickle
...
@@ -11,7 +10,8 @@ import pickle
...
@@ -11,7 +10,8 @@ import pickle
class
ANIDataset
(
Dataset
):
class
ANIDataset
(
Dataset
):
def
__init__
(
self
,
path
,
chunk_size
,
shuffle
=
True
,
properties
=
[
'energies'
],
def
__init__
(
self
,
path
,
chunk_size
,
shuffle
=
True
,
properties
=
[
'energies'
],
transform
=
(),
dtype
=
default_dtype
,
device
=
default_device
):
transform
=
(),
dtype
=
torch
.
get_default_dtype
(),
device
=
torch
.
device
(
'cpu'
)):
super
(
ANIDataset
,
self
).
__init__
()
super
(
ANIDataset
,
self
).
__init__
()
self
.
path
=
path
self
.
path
=
path
self
.
chunks_size
=
chunk_size
self
.
chunks_size
=
chunk_size
...
...
torchani/env.py
View file @
d3ae0788
import
pkg_resources
import
pkg_resources
import
torch
buildin_const_file
=
pkg_resources
.
resource_filename
(
buildin_const_file
=
pkg_resources
.
resource_filename
(
...
@@ -15,6 +14,3 @@ buildin_model_prefix = pkg_resources.resource_filename(
...
@@ -15,6 +14,3 @@ buildin_model_prefix = pkg_resources.resource_filename(
__name__
,
'resources/ani-1x_dft_x8ens/train'
)
__name__
,
'resources/ani-1x_dft_x8ens/train'
)
buildin_ensemble
=
8
buildin_ensemble
=
8
default_dtype
=
torch
.
float32
default_device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment