Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
d3ae0788
"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d2e5cb3c1072ad324d1c9c4bf19be98bc4280282"
Unverified
Commit
d3ae0788
authored
Aug 02, 2018
by
Gao, Xiang
Committed by
GitHub
Aug 02, 2018
Browse files
remove explicit device and dtype (#44)
parent
8c493a6e
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
78 additions
and
103 deletions
+78
-103
examples/energy_force.py
examples/energy_force.py
+2
-4
examples/model.py
examples/model.py
+3
-3
examples/nnp_training.py
examples/nnp_training.py
+4
-2
examples/training-benchmark.py
examples/training-benchmark.py
+4
-2
tests/test_aev.py
tests/test_aev.py
+3
-4
tests/test_batch.py
tests/test_batch.py
+3
-6
tests/test_benchmark.py
tests/test_benchmark.py
+8
-14
tests/test_energies.py
tests/test_energies.py
+3
-6
tests/test_ensemble.py
tests/test_ensemble.py
+2
-2
tests/test_forces.py
tests/test_forces.py
+3
-6
tests/test_ignite.py
tests/test_ignite.py
+3
-7
torchani/__init__.py
torchani/__init__.py
+2
-3
torchani/aev.py
torchani/aev.py
+36
-38
torchani/data.py
torchani/data.py
+2
-2
torchani/env.py
torchani/env.py
+0
-4
No files found.
examples/energy_force.py
View file @
d3ae0788
...
@@ -8,8 +8,8 @@ const_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/rHCNO-5.
...
@@ -8,8 +8,8 @@ const_file = os.path.join(path, '../torchani/resources/ani-1x_dft_x8ens/rHCNO-5.
sae_file
=
os
.
path
.
join
(
path
,
'../torchani/resources/ani-1x_dft_x8ens/sae_linfit.dat'
)
# noqa: E501
sae_file
=
os
.
path
.
join
(
path
,
'../torchani/resources/ani-1x_dft_x8ens/sae_linfit.dat'
)
# noqa: E501
network_dir
=
os
.
path
.
join
(
path
,
'../torchani/resources/ani-1x_dft_x8ens/train'
)
# noqa: E501
network_dir
=
os
.
path
.
join
(
path
,
'../torchani/resources/ani-1x_dft_x8ens/train'
)
# noqa: E501
aev_computer
=
torchani
.
SortedAEV
(
const_file
=
const_file
,
device
=
device
)
aev_computer
=
torchani
.
SortedAEV
(
const_file
=
const_file
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
aev_computer
.
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
nn
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
,
from_
=
network_dir
,
nn
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
,
from_
=
network_dir
,
ensemble
=
8
)
ensemble
=
8
)
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nn
)
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nn
)
...
@@ -20,8 +20,6 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
...
@@ -20,8 +20,6 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[
-
0.66518241
,
-
0.84461308
,
0.20759389
],
[
-
0.66518241
,
-
0.84461308
,
0.20759389
],
[
0.45554739
,
0.54289633
,
0.81170881
],
[
0.45554739
,
0.54289633
,
0.81170881
],
[
0.66091919
,
-
0.16799635
,
-
0.91037834
]]],
[
0.66091919
,
-
0.16799635
,
-
0.91037834
]]],
dtype
=
aev_computer
.
dtype
,
device
=
aev_computer
.
device
,
requires_grad
=
True
)
requires_grad
=
True
)
species
=
[
'C'
,
'H'
,
'H'
,
'H'
,
'H'
]
species
=
[
'C'
,
'H'
,
'H'
,
'H'
,
'H'
]
...
...
examples/model.py
View file @
d3ae0788
...
@@ -18,9 +18,9 @@ def atomic():
...
@@ -18,9 +18,9 @@ def atomic():
def
get_or_create_model
(
filename
,
benchmark
=
False
,
def
get_or_create_model
(
filename
,
benchmark
=
False
,
device
=
torch
ani
.
default_device
):
device
=
torch
.
device
(
'cpu'
)
):
aev_computer
=
torchani
.
SortedAEV
(
benchmark
=
benchmark
,
device
=
device
)
aev_computer
=
torchani
.
SortedAEV
(
benchmark
=
benchmark
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
aev_computer
.
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
model
=
torchani
.
models
.
CustomModel
(
model
=
torchani
.
models
.
CustomModel
(
reducer
=
torch
.
sum
,
reducer
=
torch
.
sum
,
benchmark
=
benchmark
,
benchmark
=
benchmark
,
...
...
examples/nnp_training.py
View file @
d3ae0788
...
@@ -8,6 +8,8 @@ import timeit
...
@@ -8,6 +8,8 @@ import timeit
import
tensorboardX
import
tensorboardX
import
math
import
math
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
chunk_size
=
256
chunk_size
=
256
batch_chunks
=
4
batch_chunks
=
4
dataset_path
=
sys
.
argv
[
1
]
dataset_path
=
sys
.
argv
[
1
]
...
@@ -20,11 +22,11 @@ start = timeit.default_timer()
...
@@ -20,11 +22,11 @@ start = timeit.default_timer()
shift_energy
=
torchani
.
EnergyShifter
()
shift_energy
=
torchani
.
EnergyShifter
()
training
,
validation
,
testing
=
torchani
.
data
.
load_or_create
(
training
,
validation
,
testing
=
torchani
.
data
.
load_or_create
(
dataset_checkpoint
,
dataset_path
,
chunk_size
,
dataset_checkpoint
,
dataset_path
,
chunk_size
,
device
=
device
,
transform
=
[
shift_energy
.
dataset_subtract_sae
])
transform
=
[
shift_energy
.
dataset_subtract_sae
])
training
=
torchani
.
data
.
dataloader
(
training
,
batch_chunks
)
training
=
torchani
.
data
.
dataloader
(
training
,
batch_chunks
)
validation
=
torchani
.
data
.
dataloader
(
validation
,
batch_chunks
)
validation
=
torchani
.
data
.
dataloader
(
validation
,
batch_chunks
)
nnp
=
model
.
get_or_create_model
(
model_checkpoint
)
nnp
=
model
.
get_or_create_model
(
model_checkpoint
,
device
=
device
)
batch_nnp
=
torchani
.
models
.
BatchModel
(
nnp
)
batch_nnp
=
torchani
.
models
.
BatchModel
(
nnp
)
container
=
torchani
.
ignite
.
Container
({
'energies'
:
batch_nnp
})
container
=
torchani
.
ignite
.
Container
({
'energies'
:
batch_nnp
})
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
...
...
examples/training-benchmark.py
View file @
d3ae0788
...
@@ -6,15 +6,17 @@ import timeit
...
@@ -6,15 +6,17 @@ import timeit
import
model
import
model
import
tqdm
import
tqdm
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
chunk_size
=
256
chunk_size
=
256
batch_chunks
=
4
batch_chunks
=
4
dataset_path
=
sys
.
argv
[
1
]
dataset_path
=
sys
.
argv
[
1
]
shift_energy
=
torchani
.
EnergyShifter
()
shift_energy
=
torchani
.
EnergyShifter
()
dataset
=
torchani
.
data
.
ANIDataset
(
dataset
=
torchani
.
data
.
ANIDataset
(
dataset_path
,
chunk_size
,
dataset_path
,
chunk_size
,
device
=
device
,
transform
=
[
shift_energy
.
dataset_subtract_sae
])
transform
=
[
shift_energy
.
dataset_subtract_sae
])
dataloader
=
torchani
.
data
.
dataloader
(
dataset
,
batch_chunks
)
dataloader
=
torchani
.
data
.
dataloader
(
dataset
,
batch_chunks
)
nnp
=
model
.
get_or_create_model
(
'/tmp/model.pt'
,
True
)
nnp
=
model
.
get_or_create_model
(
'/tmp/model.pt'
,
True
,
device
=
device
)
batch_nnp
=
torchani
.
models
.
BatchModel
(
nnp
)
batch_nnp
=
torchani
.
models
.
BatchModel
(
nnp
)
container
=
torchani
.
ignite
.
Container
({
'energies'
:
batch_nnp
})
container
=
torchani
.
ignite
.
Container
({
'energies'
:
batch_nnp
})
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
...
...
tests/test_aev.py
View file @
d3ae0788
...
@@ -10,12 +10,11 @@ N = 97
...
@@ -10,12 +10,11 @@ N = 97
class
TestAEV
(
unittest
.
TestCase
):
class
TestAEV
(
unittest
.
TestCase
):
def
setUp
(
self
,
dtype
=
torchani
.
default_dtype
):
def
setUp
(
self
):
aev_computer
=
torchani
.
SortedAEV
(
dtype
=
dtype
,
aev_computer
=
torchani
.
SortedAEV
()
device
=
torch
.
device
(
'cpu'
))
self
.
radial_length
=
aev_computer
.
radial_length
self
.
radial_length
=
aev_computer
.
radial_length
self
.
aev
=
torch
.
nn
.
Sequential
(
self
.
aev
=
torch
.
nn
.
Sequential
(
torchani
.
PrepareInput
(
aev_computer
.
species
,
aev_computer
.
device
),
torchani
.
PrepareInput
(
aev_computer
.
species
),
aev_computer
aev_computer
)
)
self
.
tolerance
=
1e-5
self
.
tolerance
=
1e-5
...
...
tests/test_batch.py
View file @
d3ae0788
...
@@ -12,17 +12,14 @@ if sys.version_info.major >= 3:
...
@@ -12,17 +12,14 @@ if sys.version_info.major >= 3:
path
=
os
.
path
.
join
(
path
,
'../dataset'
)
path
=
os
.
path
.
join
(
path
,
'../dataset'
)
chunksize
=
32
chunksize
=
32
batch_chunks
=
32
batch_chunks
=
32
dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestBatch
(
unittest
.
TestCase
):
class
TestBatch
(
unittest
.
TestCase
):
def
testBatchLoadAndInference
(
self
):
def
testBatchLoadAndInference
(
self
):
ds
=
torchani
.
data
.
ANIDataset
(
path
,
chunksize
,
device
=
device
)
ds
=
torchani
.
data
.
ANIDataset
(
path
,
chunksize
)
loader
=
torchani
.
data
.
dataloader
(
ds
,
batch_chunks
)
loader
=
torchani
.
data
.
dataloader
(
ds
,
batch_chunks
)
aev_computer
=
torchani
.
SortedAEV
(
dtype
=
dtype
,
device
=
device
)
aev_computer
=
torchani
.
SortedAEV
()
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
aev_computer
.
device
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nnp
)
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nnp
)
batch_nnp
=
torchani
.
models
.
BatchModel
(
model
)
batch_nnp
=
torchani
.
models
.
BatchModel
(
model
)
...
...
tests/test_benchmark.py
View file @
d3ae0788
...
@@ -6,14 +6,10 @@ import copy
...
@@ -6,14 +6,10 @@ import copy
class
TestBenchmark
(
unittest
.
TestCase
):
class
TestBenchmark
(
unittest
.
TestCase
):
def
setUp
(
self
,
dtype
=
torchani
.
default_dtype
,
def
setUp
(
self
):
device
=
torchani
.
default_device
):
self
.
dtype
=
dtype
self
.
device
=
device
self
.
conformations
=
100
self
.
conformations
=
100
self
.
species
=
list
(
'HHCCNNOO'
)
self
.
species
=
list
(
'HHCCNNOO'
)
self
.
coordinates
=
torch
.
randn
(
self
.
coordinates
=
torch
.
randn
(
self
.
conformations
,
8
,
3
)
self
.
conformations
,
8
,
3
,
dtype
=
dtype
,
device
=
device
)
self
.
count
=
100
self
.
count
=
100
def
_testModule
(
self
,
run_module
,
result_module
,
asserts
):
def
_testModule
(
self
,
run_module
,
result_module
,
asserts
):
...
@@ -82,9 +78,8 @@ class TestBenchmark(unittest.TestCase):
...
@@ -82,9 +78,8 @@ class TestBenchmark(unittest.TestCase):
self
.
assertEqual
(
result_module
.
timers
[
i
],
0
)
self
.
assertEqual
(
result_module
.
timers
[
i
],
0
)
def
testAEV
(
self
):
def
testAEV
(
self
):
aev_computer
=
torchani
.
SortedAEV
(
aev_computer
=
torchani
.
SortedAEV
(
benchmark
=
True
)
benchmark
=
True
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
self
.
device
)
run_module
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
)
run_module
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
)
self
.
_testModule
(
run_module
,
aev_computer
,
[
self
.
_testModule
(
run_module
,
aev_computer
,
[
'terms and indices>radial terms'
,
'terms and indices>radial terms'
,
...
@@ -95,11 +90,10 @@ class TestBenchmark(unittest.TestCase):
...
@@ -95,11 +90,10 @@ class TestBenchmark(unittest.TestCase):
])
])
def
testANIModel
(
self
):
def
testANIModel
(
self
):
aev_computer
=
torchani
.
SortedAEV
(
aev_computer
=
torchani
.
SortedAEV
()
dtype
=
self
.
dtype
,
device
=
self
.
device
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
self
.
device
)
model
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
,
model
=
torchani
.
models
.
NeuroChemNNP
(
benchmark
=
True
)
aev_computer
.
species
,
benchmark
=
True
).
to
(
self
.
device
)
run_module
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
model
)
run_module
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
model
)
self
.
_testModule
(
run_module
,
model
,
[
'forward'
])
self
.
_testModule
(
run_module
,
model
,
[
'forward'
])
...
...
tests/test_energies.py
View file @
d3ae0788
...
@@ -11,13 +11,10 @@ N = 97
...
@@ -11,13 +11,10 @@ N = 97
class
TestEnergies
(
unittest
.
TestCase
):
class
TestEnergies
(
unittest
.
TestCase
):
def
setUp
(
self
,
dtype
=
torchani
.
default_dtype
,
def
setUp
(
self
):
device
=
torchani
.
default_device
):
self
.
tolerance
=
5e-5
self
.
tolerance
=
5e-5
aev_computer
=
torchani
.
SortedAEV
(
aev_computer
=
torchani
.
SortedAEV
()
dtype
=
dtype
,
device
=
torch
.
device
(
'cpu'
))
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
aev_computer
.
device
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
self
.
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nnp
)
self
.
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nnp
)
...
...
tests/test_ensemble.py
View file @
d3ae0788
...
@@ -18,8 +18,8 @@ class TestEnsemble(unittest.TestCase):
...
@@ -18,8 +18,8 @@ class TestEnsemble(unittest.TestCase):
coordinates
=
torch
.
tensor
(
coordinates
,
requires_grad
=
True
)
coordinates
=
torch
.
tensor
(
coordinates
,
requires_grad
=
True
)
n
=
torchani
.
buildin_ensemble
n
=
torchani
.
buildin_ensemble
prefix
=
torchani
.
buildin_model_prefix
prefix
=
torchani
.
buildin_model_prefix
aev
=
torchani
.
SortedAEV
(
device
=
torch
.
device
(
'cpu'
)
)
aev
=
torchani
.
SortedAEV
()
prepare
=
torchani
.
PrepareInput
(
aev
.
species
,
aev
.
device
)
prepare
=
torchani
.
PrepareInput
(
aev
.
species
)
ensemble
=
torchani
.
models
.
NeuroChemNNP
(
aev
.
species
,
ensemble
=
True
)
ensemble
=
torchani
.
models
.
NeuroChemNNP
(
aev
.
species
,
ensemble
=
True
)
ensemble
=
torch
.
nn
.
Sequential
(
prepare
,
aev
,
ensemble
)
ensemble
=
torch
.
nn
.
Sequential
(
prepare
,
aev
,
ensemble
)
models
=
[
torchani
.
models
.
models
=
[
torchani
.
models
.
...
...
tests/test_forces.py
View file @
d3ae0788
...
@@ -10,13 +10,10 @@ N = 97
...
@@ -10,13 +10,10 @@ N = 97
class
TestForce
(
unittest
.
TestCase
):
class
TestForce
(
unittest
.
TestCase
):
def
setUp
(
self
,
dtype
=
torchani
.
default_dtype
,
def
setUp
(
self
):
device
=
torchani
.
default_device
):
self
.
tolerance
=
1e-5
self
.
tolerance
=
1e-5
aev_computer
=
torchani
.
SortedAEV
(
aev_computer
=
torchani
.
SortedAEV
()
dtype
=
dtype
,
device
=
torch
.
device
(
'cpu'
))
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
aev_computer
.
device
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
self
.
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nnp
)
self
.
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nnp
)
...
...
tests/test_ignite.py
View file @
d3ae0788
...
@@ -13,21 +13,17 @@ if sys.version_info.major >= 3:
...
@@ -13,21 +13,17 @@ if sys.version_info.major >= 3:
path
=
os
.
path
.
join
(
path
,
'../dataset/ani_gdb_s01.h5'
)
path
=
os
.
path
.
join
(
path
,
'../dataset/ani_gdb_s01.h5'
)
chunksize
=
4
chunksize
=
4
threshold
=
1e-5
threshold
=
1e-5
dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestIgnite
(
unittest
.
TestCase
):
class
TestIgnite
(
unittest
.
TestCase
):
def
testIgnite
(
self
):
def
testIgnite
(
self
):
shift_energy
=
torchani
.
EnergyShifter
()
shift_energy
=
torchani
.
EnergyShifter
()
ds
=
torchani
.
data
.
ANIDataset
(
ds
=
torchani
.
data
.
ANIDataset
(
path
,
chunksize
,
device
=
device
,
path
,
chunksize
,
transform
=
[
shift_energy
.
dataset_subtract_sae
])
transform
=
[
shift_energy
.
dataset_subtract_sae
])
ds
=
torch
.
utils
.
data
.
Subset
(
ds
,
[
0
])
ds
=
torch
.
utils
.
data
.
Subset
(
ds
,
[
0
])
loader
=
torchani
.
data
.
dataloader
(
ds
,
1
)
loader
=
torchani
.
data
.
dataloader
(
ds
,
1
)
aev_computer
=
torchani
.
SortedAEV
(
dtype
=
dtype
,
device
=
device
)
aev_computer
=
torchani
.
SortedAEV
()
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
,
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
aev_computer
.
device
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
class
Flatten
(
torch
.
nn
.
Module
):
class
Flatten
(
torch
.
nn
.
Module
):
...
...
torchani/__init__.py
View file @
d3ae0788
...
@@ -4,10 +4,9 @@ from . import data
...
@@ -4,10 +4,9 @@ from . import data
from
.
import
ignite
from
.
import
ignite
from
.aev
import
SortedAEV
,
PrepareInput
from
.aev
import
SortedAEV
,
PrepareInput
from
.env
import
buildin_const_file
,
buildin_sae_file
,
buildin_network_dir
,
\
from
.env
import
buildin_const_file
,
buildin_sae_file
,
buildin_network_dir
,
\
buildin_model_prefix
,
buildin_ensemble
,
default_dtype
,
default_device
buildin_model_prefix
,
buildin_ensemble
__all__
=
[
'PrepareInput'
,
'SortedAEV'
,
'EnergyShifter'
,
__all__
=
[
'PrepareInput'
,
'SortedAEV'
,
'EnergyShifter'
,
'models'
,
'data'
,
'ignite'
,
'models'
,
'data'
,
'ignite'
,
'buildin_const_file'
,
'buildin_sae_file'
,
'buildin_network_dir'
,
'buildin_const_file'
,
'buildin_sae_file'
,
'buildin_network_dir'
,
'buildin_model_prefix'
,
'buildin_ensemble'
,
'buildin_model_prefix'
,
'buildin_ensemble'
]
'default_dtype'
,
'default_device'
]
torchani/aev.py
View file @
d3ae0788
import
torch
import
torch
import
itertools
import
itertools
import
math
import
math
from
.env
import
buildin_const_file
,
default_dtype
,
default_device
from
.env
import
buildin_const_file
from
.benchmarked
import
BenchmarkedModule
from
.benchmarked
import
BenchmarkedModule
class
AEVComputer
(
BenchmarkedModule
):
class
AEVComputer
(
BenchmarkedModule
):
__constants__
=
[
'Rcr'
,
'Rca'
,
'dtype'
,
'device'
,
'radial_sublength'
,
__constants__
=
[
'Rcr'
,
'Rca'
,
'radial_sublength'
,
'radial_length'
,
'radial_length'
,
'angular_sublength'
,
'angular_length'
,
'angular_sublength'
,
'angular_length'
,
'aev_length'
]
'aev_length'
]
"""Base class of various implementations of AEV computer
"""Base class of various implementations of AEV computer
...
@@ -16,11 +15,6 @@ class AEVComputer(BenchmarkedModule):
...
@@ -16,11 +15,6 @@ class AEVComputer(BenchmarkedModule):
----------
----------
benchmark : boolean
benchmark : boolean
Whether to enable benchmark
Whether to enable benchmark
dtype : torch.dtype
Data type of pytorch tensors for all the computations. This is
also used to specify whether to use CPU or GPU.
device : torch.Device
The device where tensors should be.
const_file : str
const_file : str
The name of the original file that stores constant.
The name of the original file that stores constant.
Rcr, Rca : float
Rcr, Rca : float
...
@@ -39,15 +33,12 @@ class AEVComputer(BenchmarkedModule):
...
@@ -39,15 +33,12 @@ class AEVComputer(BenchmarkedModule):
The length of full aev
The length of full aev
"""
"""
def
__init__
(
self
,
benchmark
=
False
,
dtype
=
default_dtype
,
def
__init__
(
self
,
benchmark
=
False
,
const_file
=
buildin_const_file
):
device
=
default_device
,
const_file
=
buildin_const_file
):
super
(
AEVComputer
,
self
).
__init__
(
benchmark
)
super
(
AEVComputer
,
self
).
__init__
(
benchmark
)
self
.
dtype
=
dtype
self
.
const_file
=
const_file
self
.
const_file
=
const_file
self
.
device
=
device
# load constants from const file
# load constants from const file
const
=
{}
with
open
(
const_file
)
as
f
:
with
open
(
const_file
)
as
f
:
for
i
in
f
:
for
i
in
f
:
try
:
try
:
...
@@ -60,8 +51,8 @@ class AEVComputer(BenchmarkedModule):
...
@@ -60,8 +51,8 @@ class AEVComputer(BenchmarkedModule):
'ShfZ'
,
'EtaA'
,
'ShfA'
]:
'ShfZ'
,
'EtaA'
,
'ShfA'
]:
value
=
[
float
(
x
.
strip
())
for
x
in
value
.
replace
(
value
=
[
float
(
x
.
strip
())
for
x
in
value
.
replace
(
'['
,
''
).
replace
(
']'
,
''
).
split
(
','
)]
'['
,
''
).
replace
(
']'
,
''
).
split
(
','
)]
value
=
torch
.
tensor
(
value
,
dtype
=
dtype
,
device
=
device
)
value
=
torch
.
tensor
(
value
)
setattr
(
self
,
name
,
value
)
const
[
name
]
=
value
elif
name
==
'Atyp'
:
elif
name
==
'Atyp'
:
value
=
[
x
.
strip
()
for
x
in
value
.
replace
(
value
=
[
x
.
strip
()
for
x
in
value
.
replace
(
'['
,
''
).
replace
(
']'
,
''
).
split
(
','
)]
'['
,
''
).
replace
(
']'
,
''
).
split
(
','
)]
...
@@ -70,10 +61,11 @@ class AEVComputer(BenchmarkedModule):
...
@@ -70,10 +61,11 @@ class AEVComputer(BenchmarkedModule):
raise
ValueError
(
'unable to parse const file'
)
raise
ValueError
(
'unable to parse const file'
)
# Compute lengths
# Compute lengths
self
.
radial_sublength
=
self
.
EtaR
.
shape
[
0
]
*
self
.
ShfR
.
shape
[
0
]
self
.
radial_sublength
=
const
[
'
EtaR
'
]
.
shape
[
0
]
*
const
[
'
ShfR
'
]
.
shape
[
0
]
self
.
radial_length
=
len
(
self
.
species
)
*
self
.
radial_sublength
self
.
radial_length
=
len
(
self
.
species
)
*
self
.
radial_sublength
self
.
angular_sublength
=
self
.
EtaA
.
shape
[
0
]
*
\
self
.
angular_sublength
=
const
[
'EtaA'
].
shape
[
0
]
*
\
self
.
Zeta
.
shape
[
0
]
*
self
.
ShfA
.
shape
[
0
]
*
self
.
ShfZ
.
shape
[
0
]
const
[
'Zeta'
].
shape
[
0
]
*
const
[
'ShfA'
].
shape
[
0
]
*
\
const
[
'ShfZ'
].
shape
[
0
]
species
=
len
(
self
.
species
)
species
=
len
(
self
.
species
)
self
.
angular_length
=
int
(
self
.
angular_length
=
int
(
(
species
*
(
species
+
1
))
/
2
)
*
self
.
angular_sublength
(
species
*
(
species
+
1
))
/
2
)
*
self
.
angular_sublength
...
@@ -81,13 +73,17 @@ class AEVComputer(BenchmarkedModule):
...
@@ -81,13 +73,17 @@ class AEVComputer(BenchmarkedModule):
# convert constant tensors to a ready-to-broadcast shape
# convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR)
# shape convension (..., EtaR, ShfR)
self
.
EtaR
=
self
.
EtaR
.
view
(
-
1
,
1
)
const
[
'
EtaR
'
]
=
const
[
'
EtaR
'
]
.
view
(
-
1
,
1
)
self
.
ShfR
=
self
.
ShfR
.
view
(
1
,
-
1
)
const
[
'
ShfR
'
]
=
const
[
'
ShfR
'
]
.
view
(
1
,
-
1
)
# shape convension (..., EtaA, Zeta, ShfA, ShfZ)
# shape convension (..., EtaA, Zeta, ShfA, ShfZ)
self
.
EtaA
=
self
.
EtaA
.
view
(
-
1
,
1
,
1
,
1
)
const
[
'EtaA'
]
=
const
[
'EtaA'
].
view
(
-
1
,
1
,
1
,
1
)
self
.
Zeta
=
self
.
Zeta
.
view
(
1
,
-
1
,
1
,
1
)
const
[
'Zeta'
]
=
const
[
'Zeta'
].
view
(
1
,
-
1
,
1
,
1
)
self
.
ShfA
=
self
.
ShfA
.
view
(
1
,
1
,
-
1
,
1
)
const
[
'ShfA'
]
=
const
[
'ShfA'
].
view
(
1
,
1
,
-
1
,
1
)
self
.
ShfZ
=
self
.
ShfZ
.
view
(
1
,
1
,
1
,
-
1
)
const
[
'ShfZ'
]
=
const
[
'ShfZ'
].
view
(
1
,
1
,
1
,
-
1
)
# register buffers
for
i
in
const
:
self
.
register_buffer
(
i
,
const
[
i
])
def
forward
(
self
,
coordinates_species
):
def
forward
(
self
,
coordinates_species
):
"""Compute AEV from coordinates and species
"""Compute AEV from coordinates and species
...
@@ -112,18 +108,19 @@ class AEVComputer(BenchmarkedModule):
...
@@ -112,18 +108,19 @@ class AEVComputer(BenchmarkedModule):
class
PrepareInput
(
torch
.
nn
.
Module
):
class
PrepareInput
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
species
,
device
):
def
__init__
(
self
,
species
):
super
(
PrepareInput
,
self
).
__init__
()
super
(
PrepareInput
,
self
).
__init__
()
self
.
species
=
species
self
.
species
=
species
self
.
device
=
device
def
species_to_tensor
(
self
,
species
):
def
species_to_tensor
(
self
,
species
,
device
):
"""Convert species list into a long tensor.
"""Convert species list into a long tensor.
Parameters
Parameters
----------
----------
species : list
species : list
List of string for the species of each atoms.
List of string for the species of each atoms.
device : torch.device
The device to store tensor
Returns
Returns
-------
-------
...
@@ -133,7 +130,7 @@ class PrepareInput(torch.nn.Module):
...
@@ -133,7 +130,7 @@ class PrepareInput(torch.nn.Module):
"""
"""
indices
=
{
self
.
species
[
i
]:
i
for
i
in
range
(
len
(
self
.
species
))}
indices
=
{
self
.
species
[
i
]:
i
for
i
in
range
(
len
(
self
.
species
))}
values
=
[
indices
[
i
]
for
i
in
species
]
values
=
[
indices
[
i
]
for
i
in
species
]
return
torch
.
tensor
(
values
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
return
torch
.
tensor
(
values
,
dtype
=
torch
.
long
,
device
=
device
)
def
sort_by_species
(
self
,
species
,
*
tensors
):
def
sort_by_species
(
self
,
species
,
*
tensors
):
"""Sort the data by its species according to the order in `self.species`
"""Sort the data by its species according to the order in `self.species`
...
@@ -158,7 +155,7 @@ class PrepareInput(torch.nn.Module):
...
@@ -158,7 +155,7 @@ class PrepareInput(torch.nn.Module):
def
forward
(
self
,
species_coordinates
):
def
forward
(
self
,
species_coordinates
):
species
,
coordinates
=
species_coordinates
species
,
coordinates
=
species_coordinates
species
=
self
.
species_to_tensor
(
species
)
species
=
self
.
species_to_tensor
(
species
,
coordinates
.
device
)
return
self
.
sort_by_species
(
species
,
coordinates
)
return
self
.
sort_by_species
(
species
,
coordinates
)
...
@@ -203,9 +200,8 @@ class SortedAEV(AEVComputer):
...
@@ -203,9 +200,8 @@ class SortedAEV(AEVComputer):
total : total time for computing everything.
total : total time for computing everything.
"""
"""
def
__init__
(
self
,
benchmark
=
False
,
device
=
default_device
,
def
__init__
(
self
,
benchmark
=
False
,
const_file
=
buildin_const_file
):
dtype
=
default_dtype
,
const_file
=
buildin_const_file
):
super
(
SortedAEV
,
self
).
__init__
(
benchmark
,
const_file
)
super
(
SortedAEV
,
self
).
__init__
(
benchmark
,
dtype
,
device
,
const_file
)
if
benchmark
:
if
benchmark
:
self
.
radial_subaev_terms
=
self
.
_enable_benchmark
(
self
.
radial_subaev_terms
=
self
.
_enable_benchmark
(
self
.
radial_subaev_terms
,
'radial terms'
)
self
.
radial_subaev_terms
,
'radial terms'
)
...
@@ -385,7 +381,7 @@ class SortedAEV(AEVComputer):
...
@@ -385,7 +381,7 @@ class SortedAEV(AEVComputer):
storing the mask for each species.
storing the mask for each species.
"""
"""
mask_r
=
(
species_r
.
unsqueeze
(
-
1
)
==
mask_r
=
(
species_r
.
unsqueeze
(
-
1
)
==
torch
.
arange
(
len
(
self
.
species
),
device
=
self
.
device
))
torch
.
arange
(
len
(
self
.
species
),
device
=
self
.
EtaR
.
device
))
return
mask_r
return
mask_r
def
compute_mask_a
(
self
,
species_a
,
present_species
):
def
compute_mask_a
(
self
,
species_a
,
present_species
):
...
@@ -451,8 +447,10 @@ class SortedAEV(AEVComputer):
...
@@ -451,8 +447,10 @@ class SortedAEV(AEVComputer):
atoms
=
radial_terms
.
shape
[
1
]
atoms
=
radial_terms
.
shape
[
1
]
# assemble radial subaev
# assemble radial subaev
present_radial_aevs
=
(
radial_terms
.
unsqueeze
(
-
2
)
present_radial_aevs
=
(
*
mask_r
.
unsqueeze
(
-
1
).
type
(
self
.
dtype
)).
sum
(
-
3
)
radial_terms
.
unsqueeze
(
-
2
)
*
mask_r
.
unsqueeze
(
-
1
).
type
(
radial_terms
.
dtype
)
).
sum
(
-
3
)
"""shape (conformations, atoms, present species, radial_length)"""
"""shape (conformations, atoms, present species, radial_length)"""
radial_aevs
=
present_radial_aevs
.
flatten
(
start_dim
=
2
)
radial_aevs
=
present_radial_aevs
.
flatten
(
start_dim
=
2
)
...
@@ -466,13 +464,13 @@ class SortedAEV(AEVComputer):
...
@@ -466,13 +464,13 @@ class SortedAEV(AEVComputer):
zero_angular_subaev
=
torch
.
zeros
(
zero_angular_subaev
=
torch
.
zeros
(
# TODO: can we make stack and cat broadcast?
# TODO: can we make stack and cat broadcast?
conformations
,
atoms
,
self
.
angular_sublength
,
conformations
,
atoms
,
self
.
angular_sublength
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
dtype
=
self
.
EtaR
.
dtype
,
device
=
self
.
EtaR
.
device
)
for
s1
,
s2
in
itertools
.
combinations_with_replacement
(
for
s1
,
s2
in
itertools
.
combinations_with_replacement
(
range
(
len
(
self
.
species
)),
2
):
range
(
len
(
self
.
species
)),
2
):
if
s1
in
rev_indices
and
s2
in
rev_indices
:
if
s1
in
rev_indices
and
s2
in
rev_indices
:
i1
=
rev_indices
[
s1
]
i1
=
rev_indices
[
s1
]
i2
=
rev_indices
[
s2
]
i2
=
rev_indices
[
s2
]
mask
=
mask_a
[...,
i1
,
i2
].
unsqueeze
(
-
1
).
type
(
self
.
dtype
)
mask
=
mask_a
[...,
i1
,
i2
].
unsqueeze
(
-
1
).
type
(
self
.
EtaR
.
dtype
)
subaev
=
(
angular_terms
*
mask
).
sum
(
-
2
)
subaev
=
(
angular_terms
*
mask
).
sum
(
-
2
)
else
:
else
:
subaev
=
zero_angular_subaev
subaev
=
zero_angular_subaev
...
...
torchani/data.py
View file @
d3ae0788
...
@@ -2,7 +2,6 @@ from torch.utils.data import Dataset, DataLoader
...
@@ -2,7 +2,6 @@ from torch.utils.data import Dataset, DataLoader
from
os.path
import
join
,
isfile
,
isdir
from
os.path
import
join
,
isfile
,
isdir
import
os
import
os
from
.pyanitools
import
anidataloader
from
.pyanitools
import
anidataloader
from
.env
import
default_dtype
,
default_device
import
torch
import
torch
import
torch.utils.data
as
data
import
torch.utils.data
as
data
import
pickle
import
pickle
...
@@ -11,7 +10,8 @@ import pickle
...
@@ -11,7 +10,8 @@ import pickle
class
ANIDataset
(
Dataset
):
class
ANIDataset
(
Dataset
):
def
__init__
(
self
,
path
,
chunk_size
,
shuffle
=
True
,
properties
=
[
'energies'
],
def
__init__
(
self
,
path
,
chunk_size
,
shuffle
=
True
,
properties
=
[
'energies'
],
transform
=
(),
dtype
=
default_dtype
,
device
=
default_device
):
transform
=
(),
dtype
=
torch
.
get_default_dtype
(),
device
=
torch
.
device
(
'cpu'
)):
super
(
ANIDataset
,
self
).
__init__
()
super
(
ANIDataset
,
self
).
__init__
()
self
.
path
=
path
self
.
path
=
path
self
.
chunks_size
=
chunk_size
self
.
chunks_size
=
chunk_size
...
...
torchani/env.py
View file @
d3ae0788
import
pkg_resources
import
pkg_resources
import
torch
buildin_const_file
=
pkg_resources
.
resource_filename
(
buildin_const_file
=
pkg_resources
.
resource_filename
(
...
@@ -15,6 +14,3 @@ buildin_model_prefix = pkg_resources.resource_filename(
...
@@ -15,6 +14,3 @@ buildin_model_prefix = pkg_resources.resource_filename(
__name__
,
'resources/ani-1x_dft_x8ens/train'
)
__name__
,
'resources/ani-1x_dft_x8ens/train'
)
buildin_ensemble
=
8
buildin_ensemble
=
8
default_dtype
=
torch
.
float32
default_device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment