OpenDAS / torchani · Commits

Commit d9c0130f (Unverified)
Refactor cache_aev and cache_sparse_aev (#232)
Authored May 25, 2019 by Gao, Xiang; committed by GitHub on May 25, 2019
Parent: 2ec2fb6d

Showing 1 changed file with 46 additions and 72 deletions:
torchani/data/__init__.py (+46 -72)
@@ -276,13 +276,27 @@ class AEVCacheLoader(Dataset):
         aev_path = os.path.join(self.disk_cache, str(index))
         with open(aev_path, 'rb') as f:
             species_aevs = pickle.load(f)
+        for i, sa in enumerate(species_aevs):
+            species, aevs = self.decode_aev(*sa)
+            species_aevs[i] = (species.to(self.dataset.device),
+                               aevs.to(self.dataset.device))
         return species_aevs, output

     def __len__(self):
         return len(self.dataset)

+    @staticmethod
+    def decode_aev(encoded_species, encoded_aev):
+        return encoded_species, encoded_aev
+
+    @staticmethod
+    def encode_aev(species, aev):
+        return species, aev
+
+
-class SparseAEVCacheLoader(Dataset):
+class SparseAEVCacheLoader(AEVCacheLoader):
     """Build a factory for AEV.

     The computation of AEV is the most time-consuming part of the training.
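For reference, the encode/decode hooks introduced above can be exercised in isolation. A minimal sketch of the contract, assuming torchani is installed at this revision so that AEVCacheLoader is importable from torchani.data; the tensor shapes are illustrative only:

    import torch
    from torchani.data import AEVCacheLoader

    # Illustrative shapes; a real AEV has the width produced by AEVComputer
    # for the chosen constants file.
    species = torch.tensor([[0, 1, 1]])
    aev = torch.zeros(1, 3, 16)

    encoded = AEVCacheLoader.encode_aev(species, aev)   # identity pass-through
    decoded = AEVCacheLoader.decode_aev(*encoded)       # identity pass-through
    assert decoded[0] is species and decoded[1] is aev

Because the dense loader's hooks are the identity, the cache round-trips tensors unchanged; the sparse subclass overrides both hooks instead of duplicating the loading logic.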
@@ -294,61 +308,26 @@ class SparseAEVCacheLoader(Dataset):
     Arguments:
         disk_cache (str): Directory storing disk caches.
-        device (:class:`torch.dtype`): device to put tensors when iterating.
     """

-    def __init__(self, disk_cache=None, device=torch.device('cpu')):
-        super(SparseAEVCacheLoader, self).__init__()
-        self.disk_cache = disk_cache
-        self.device = device
-
-        # load dataset from disk cache
-        dataset_path = os.path.join(disk_cache, 'dataset')
-        with open(dataset_path, 'rb') as f:
-            self.dataset = pickle.load(f)
+    @staticmethod
+    def decode_aev(encoded_species, encoded_aev):
+        species = torch.from_numpy(encoded_species.todense())
+        aevs_np = np.stack([np.array(i.todense()) for i in encoded_aev],
+                           axis=0)
+        aevs = torch.from_numpy(aevs_np)
+        return species, aevs

-    def __getitem__(self, index):
-        _, output = self.dataset.batches[index]
-        aev_path = os.path.join(self.disk_cache, str(index))
-        with open(aev_path, 'rb') as f:
-            species_aevs = pickle.load(f)
-        batch_X = []
-        for species_, aev_ in species_aevs:
-            species_np = np.array(species_.todense())
-            species = torch.from_numpy(species_np).to(self.device)
-            aevs_np = np.stack([np.array(i.todense()) for i in aev_], axis=0)
-            aevs = torch.from_numpy(aevs_np).to(self.device)
-            batch_X.append((species, aevs))
-        return batch_X, output
-
-    def __len__(self):
-        return len(self.dataset)
+    @staticmethod
+    def encode_aev(species, aev):
+        encoded_species = bsr_matrix(species.cpu().numpy())
+        encoded_aev = [bsr_matrix(i.cpu().numpy()) for i in aev]
+        return encoded_species, encoded_aev

 builtin = neurochem.Builtins()

-def cache_aev(output, dataset_path, batchsize, device=default_device,
-              constfile=builtin.const_file, subtract_sae=False,
-              sae_file=builtin.sae_file, enable_tqdm=True, **kwargs):
-    # if output directory does not exist, then create it
-    if not os.path.exists(output):
-        os.makedirs(output)
-
-    device = torch.device(device)
-    consts = neurochem.Constants(constfile)
-    aev_computer = aev.AEVComputer(**consts).to(device)
-
-    if subtract_sae:
-        energy_shifter = neurochem.load_sae(sae_file)
-        transform = (energy_shifter.subtract_from_dataset,)
-    else:
-        transform = ()
-
-    dataset = BatchedANIDataset(
-        dataset_path, consts.species_to_tensor, batchsize,
-        device=device, transform=transform, **kwargs)
-
+def create_aev_cache(dataset, aev_computer, output, enable_tqdm=True,
+                     encoder=lambda x: x):
     # dump out the dataset
     filename = os.path.join(output, 'dataset')
     with open(filename, 'wb') as f:
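The two sparse hooks above simply invert each other: encode_aev stores each batch as scipy.sparse.bsr_matrix blocks, and decode_aev densifies them back into tensors. A standalone sketch of that round trip, assuming only numpy, scipy, and torch; the shapes and padding value are made up:

    import numpy as np
    import torch
    from scipy.sparse import bsr_matrix

    species = torch.tensor([[0, 1, 1, -1]])   # illustrative padded species row
    aev = torch.zeros(2, 4, 8)                # mostly-zero AEVs compress well

    # encode: one sparse matrix for species, one per 2-D slice of the AEV tensor
    encoded_species = bsr_matrix(species.cpu().numpy())
    encoded_aev = [bsr_matrix(i.cpu().numpy()) for i in aev]

    # decode: densify and rebuild tensors, mirroring decode_aev above
    # (np.array() converts the np.matrix returned by .todense())
    species2 = torch.from_numpy(np.array(encoded_species.todense()))
    aevs = torch.from_numpy(
        np.stack([np.array(i.todense()) for i in encoded_aev], axis=0))
    assert torch.equal(species2, species) and torch.equal(aevs, aev)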
@@ -361,15 +340,14 @@ def cache_aev(output, dataset_path, batchsize, device=default_device,
         indices = range(len(dataset))
     for i in indices:
         input_, _ = dataset[i]
-        aevs = [aev_computer(j) for j in input_]
+        aevs = [encoder(*aev_computer(j)) for j in input_]
         filename = os.path.join(output, '{}'.format(i))
         with open(filename, 'wb') as f:
             pickle.dump(aevs, f)


-def cache_sparse_aev(output, dataset_path, batchsize, device=default_device,
-                     constfile=builtin.const_file, subtract_sae=False,
-                     sae_file=builtin.sae_file, enable_tqdm=True, **kwargs):
+def _cache_aev(output, dataset_path, batchsize, device, constfile,
+               subtract_sae, sae_file, enable_tqdm, encoder, **kwargs):
     # if output directory does not exist, then create it
     if not os.path.exists(output):
         os.makedirs(output)
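The write loop above now differs between dense and sparse caching only through the injected encoder, since aev_computer(j) yields a (species, aev) pair that encoder(*...) consumes. A toy sketch of that injection; fake_aev_computer, the batch contents, and all shapes are illustrative stand-ins, not torchani API (only the two encode_aev hooks are real):

    import torch
    from torchani.data import AEVCacheLoader, SparseAEVCacheLoader

    def fake_aev_computer(batch):
        # Stand-in returning a (species, aev) pair, like aev.AEVComputer does.
        species, coordinates = batch
        return species, torch.zeros(species.shape[0], species.shape[1], 8)

    batch = (torch.tensor([[0, 1, 1]]), torch.zeros(1, 3, 3))

    # The same loop body serves both cache flavours; only the encoder changes.
    for encoder in (AEVCacheLoader.encode_aev, SparseAEVCacheLoader.encode_aev):
        stored = encoder(*fake_aev_computer(batch))   # what gets pickled
        print(type(stored[0]))   # torch.Tensor vs. scipy bsr_matrix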
@@ -389,27 +367,23 @@ def cache_sparse_aev(output, dataset_path, batchsize, device=default_device,
         device=device, transform=transform, **kwargs)

-    # dump out the dataset
-    filename = os.path.join(output, 'dataset')
-    with open(filename, 'wb') as f:
-        pickle.dump(dataset, f)
-
-    if enable_tqdm:
-        import tqdm
-        indices = tqdm.trange(len(dataset))
-    else:
-        indices = range(len(dataset))
-    for i in indices:
-        input_, _ = dataset[i]
-        aevs = []
-        for j in input_:
-            species_, aev_ = aev_computer(j)
-            species_ = bsr_matrix(species_.cpu().numpy())
-            aev_ = [bsr_matrix(i.cpu().numpy()) for i in aev_]
-            aevs.append((species_, aev_))
-        filename = os.path.join(output, '{}'.format(i))
-        with open(filename, 'wb') as f:
-            pickle.dump(aevs, f)
+    create_aev_cache(dataset, aev_computer, output, enable_tqdm, encoder)
+
+
+def cache_aev(output, dataset_path, batchsize, device=default_device,
+              constfile=builtin.const_file, subtract_sae=False,
+              sae_file=builtin.sae_file, enable_tqdm=True, **kwargs):
+    _cache_aev(output, dataset_path, batchsize, device, constfile,
+               subtract_sae, sae_file, enable_tqdm,
+               AEVCacheLoader.encode_aev, **kwargs)
+
+
+def cache_sparse_aev(output, dataset_path, batchsize, device=default_device,
+                     constfile=builtin.const_file, subtract_sae=False,
+                     sae_file=builtin.sae_file, enable_tqdm=True, **kwargs):
+    _cache_aev(output, dataset_path, batchsize, device, constfile,
+               subtract_sae, sae_file, enable_tqdm,
+               SparseAEVCacheLoader.encode_aev, **kwargs)

 __all__ = ['BatchedANIDataset', 'AEVCacheLoader', 'SparseAEVCacheLoader',
            'cache_aev', 'cache_sparse_aev']