Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
7c597a39
Unverified
Commit
7c597a39
authored
Sep 09, 2018
by
Gao, Xiang
Committed by
GitHub
Sep 09, 2018
Browse files
Make AEVCacheLoader a Dataset (#93)
parent
83107add
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
4 deletions
+3
-4
torchani/data/__init__.py
torchani/data/__init__.py
+3
-4
No files found.
torchani/data/__init__.py
View file @
7c597a39
...
@@ -219,7 +219,7 @@ class BatchedANIDataset(Dataset):
...
@@ -219,7 +219,7 @@ class BatchedANIDataset(Dataset):
return
len
(
self
.
batches
)
return
len
(
self
.
batches
)
class
AEVCacheLoader
:
class
AEVCacheLoader
(
Dataset
)
:
"""Build a factory for AEV.
"""Build a factory for AEV.
The computation of AEV is the most time consuming part during training.
The computation of AEV is the most time consuming part during training.
...
@@ -233,6 +233,7 @@ class AEVCacheLoader:
...
@@ -233,6 +233,7 @@ class AEVCacheLoader:
"""
"""
def
__init__
(
self
,
disk_cache
=
None
):
def
__init__
(
self
,
disk_cache
=
None
):
super
(
AEVCacheLoader
,
self
).
__init__
()
self
.
disk_cache
=
disk_cache
self
.
disk_cache
=
disk_cache
# load dataset from disk cache
# load dataset from disk cache
...
@@ -241,12 +242,10 @@ class AEVCacheLoader:
...
@@ -241,12 +242,10 @@ class AEVCacheLoader:
self
.
dataset
=
pickle
.
load
(
f
)
self
.
dataset
=
pickle
.
load
(
f
)
def
__getitem__
(
self
,
index
):
def
__getitem__
(
self
,
index
):
if
index
>=
self
.
__len__
():
_
,
output
=
self
.
dataset
.
batches
[
index
]
raise
IndexError
()
aev_path
=
os
.
path
.
join
(
self
.
disk_cache
,
str
(
index
))
aev_path
=
os
.
path
.
join
(
self
.
disk_cache
,
str
(
index
))
with
open
(
aev_path
,
'rb'
)
as
f
:
with
open
(
aev_path
,
'rb'
)
as
f
:
species_aevs
=
pickle
.
load
(
f
)
species_aevs
=
pickle
.
load
(
f
)
_
,
output
=
self
.
dataset
.
batches
[
index
]
return
species_aevs
,
output
return
species_aevs
,
output
def
__len__
(
self
):
def
__len__
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment