OpenDAS / torchani / Commits / 9639d716

Commit 9639d716, authored Aug 21, 2019 by Richard Xue, committed by Gao, Xiang on Aug 21, 2019

Add split to new dataset API (#299)

* split
* clean
* docs
* docs
* Update new.py
parent b9e2c259

Changes: showing 3 changed files with 78 additions and 16 deletions (+78 -16)

docs/api.rst             +1   -0
tests/test_data_new.py   +13  -5
torchani/data/new.py     +64  -11
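In short, the commit adds two ways to hold out validation data: CachedDataset gains a split(validation_split) method (plus an explicit load() method), and ShuffledDataset gains a validation_split keyword that, when nonzero, makes it return a (train, val) pair of DataLoaders.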
docs/api.rst
...
@@ -27,6 +27,7 @@ Datasets
 .. autofunction:: torchani.data.find_threshold
 .. autofunction:: torchani.data.ShuffledDataset
 .. autoclass:: torchani.data.CachedDataset
+    :members:
 .. autofunction:: torchani.data.load_ani_dataset
 .. autofunction:: torchani.data.create_aev_cache
 .. autoclass:: torchani.data.BatchedANIDataset
...
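The single added line enables the Sphinx autodoc ``:members:`` option on ``CachedDataset``, so the generated class page now documents its public methods, including the new ``split`` and ``load``.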
tests/test_data_new.py
...
@@ -54,6 +54,12 @@ class TestShuffledData(unittest.TestCase):
         for i, _ in enumerate(self.ds):
             pbar.update(i)
 
+    def testSplitDataset(self):
+        print('=> test splitting dataset')
+        train_ds, val_ds = torchani.data.ShuffledDataset(dspath, batch_size=batch_size,
+                                                         chunk_threshold=chunk_threshold,
+                                                         num_workers=2, validation_split=0.1)
+        frac = len(val_ds) / (len(val_ds) + len(train_ds))
+        self.assertLess(abs(frac - 0.1), 0.05)
+
     def testNoUnnecessaryPadding(self):
         print('=> checking No Unnecessary Padding')
         for i, chunk in enumerate(self.chunks):
...
@@ -91,11 +97,13 @@ class TestCachedData(unittest.TestCase):
     def testLoadDataset(self):
         print('=> test loading all dataset')
         pbar = pkbar.Pbar('loading and processing dataset into cpu memory, total '
                           + 'batches: {}, batch_size: {}'.format(len(self.ds), batch_size),
                           len(self.ds))
         for i, _ in enumerate(self.ds):
             pbar.update(i)
+        self.ds.load()
+
+    def testSplitDataset(self):
+        print('=> test splitting dataset')
+        train_dataset, val_dataset = self.ds.split(0.1)
+        frac = len(val_dataset) / len(self.ds)
+        self.assertLess(abs(frac - 0.1), 0.05)
 
     def testNoUnnecessaryPadding(self):
         print('=> checking No Unnecessary Padding')
...
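For anyone wanting to exercise just the new tests, a minimal sketch (assuming the repository root is the working directory, the tests package is importable, and the HDF5 fixture referenced by dspath in the test module is present):

    import unittest
    from tests import test_data_new

    # Collect only the two new split tests; class and method names are taken
    # from the diff above.
    loader = unittest.defaultTestLoader
    suite = unittest.TestSuite([
        loader.loadTestsFromName('TestShuffledData.testSplitDataset', test_data_new),
        loader.loadTestsFromName('TestCachedData.testSplitDataset', test_data_new),
    ])
    unittest.TextTestRunner(verbosity=2).run(suite)

Note the different denominators in the two tests: the ShuffledDataset test compares the validation loader against the combined length of both loaders, while the CachedDataset test compares against len(self.ds), since split returns slices of the same cached dataset.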
torchani/data/new.py
...
@@ -76,8 +76,8 @@ class CachedDataset(torch.utils.data.Dataset):
         anidata = anidataloader(file_path)
         anidata_size = anidata.group_size()
-        enable_pkbar = anidata_size > 5 and PKBAR_INSTALLED
-        if enable_pkbar:
+        self.enable_pkbar = anidata_size > 5 and PKBAR_INSTALLED
+        if self.enable_pkbar:
             pbar = pkbar.Pbar('=> loading h5 dataset into cpu memory, total molecules: {}'.format(anidata_size),
                               anidata_size)
         for i, molecule in enumerate(anidata):
...
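Note the rename here: the local enable_pkbar becomes the instance attribute self.enable_pkbar, so the progress-bar flag computed in __init__ can be reused by the new split and load methods added further down.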
@@ -92,7 +92,7 @@ class CachedDataset(torch.utils.data.Dataset):
                 self_energies = np.array(sum([self_energies_dict[x] for x in molecule['species']]))
                 self.data_self_energies += list(np.tile(self_energies, (num_conformations, 1)))
-            if enable_pkbar:
+            if self.enable_pkbar:
                 pbar.update(i)
         if subtract_self_energies:
...
@@ -172,6 +172,43 @@ class CachedDataset(torch.utils.data.Dataset):
     def __len__(self):
         return self.length
 
+    def split(self, validation_split):
+        """Split dataset into training and validation.
+
+        Arguments:
+            validation_split (float): Float between 0 and 1. Fraction of the dataset
+                to be used as validation data.
+        """
+        val_size = int(validation_split * len(self))
+        train_size = len(self) - val_size
+        ds = []
+        if self.enable_pkbar:
+            message = ('=> processing, splitting and caching dataset into cpu memory:\n'
+                       + 'total batches: {}, train batches: {}, val batches: {}, batch_size: {}')
+            pbar = pkbar.Pbar(message.format(len(self), train_size, val_size, self.batch_size),
+                              len(self))
+        for i, _ in enumerate(self):
+            ds.append(self[i])
+            if self.enable_pkbar:
+                pbar.update(i)
+        train_dataset = ds[:train_size]
+        val_dataset = ds[train_size:]
+        return train_dataset, val_dataset
+
+    def load(self):
+        """Cache the dataset into CPU memory. If not called, the dataset will be
+        cached during the first epoch.
+        """
+        if self.enable_pkbar:
+            pbar = pkbar.Pbar('=> processing and caching dataset into cpu memory:\ntotal '
+                              + 'batches: {}, batch_size: {}'.format(len(self), self.batch_size),
+                              len(self))
+        for i, _ in enumerate(self):
+            if self.enable_pkbar:
+                pbar.update(i)
+
     @staticmethod
     def sort_list_with_index(inputs, index):
         return [inputs[i] for i in index]
...
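A minimal usage sketch of the two new methods ('dataset.h5' and the batch size are placeholder values, and the constructor keywords are abbreviated from the full signature in this file):

    import torchani

    # Sketch only: path and batch_size are hypothetical.
    ds = torchani.data.CachedDataset('dataset.h5', batch_size=1000)
    ds.load()                                   # optional: cache all batches up front
    train_batches, val_batches = ds.split(0.1)  # plain lists of batches, ~90%/10%

Note that split slices the cached batch list head/tail (ds[:train_size] and ds[train_size:]) rather than sampling randomly, so the validation set is always the last val_size batches.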
@@ -229,6 +266,7 @@ class CachedDataset(torch.utils.data.Dataset):
 def ShuffledDataset(file_path,
                     batch_size=1000,
                     num_workers=0,
                     shuffle=True,
                     chunk_threshold=20,
+                    validation_split=0.0,
                     species_order=['H', 'C', 'N', 'O'],
                     subtract_self_energies=False,
                     self_energies=[-0.600953, -38.08316, -54.707756, -75.194466]):
...
@@ -242,6 +280,8 @@ def ShuffledDataset(file_path,
         shuffle (bool): whether to shuffle.
         chunk_threshold (int): threshold for splitting a batch into chunks. Set to
             ``None`` to disable chunk splitting.
+        validation_split (float): Float between 0 and 1. Fraction of the dataset to be used
+            as validation data.
         species_order (list): a list specifying how species are transformed to int,
             for example: ``['H', 'C', 'N', 'O']`` means ``{'H': 0, 'C': 1, 'N': 2, 'O': 3}``.
         subtract_self_energies (bool): whether to subtract self energies from ``energies``.
...
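A hedged sketch of the new keyword in use (file path and batch size are placeholders):

    import torchani

    # With validation_split > 0 the function returns a (train, val) pair of
    # DataLoaders; with the default 0.0 it returns a single training loader,
    # as the body below shows.
    train_loader, val_loader = torchani.data.ShuffledDataset('dataset.h5',
                                                             batch_size=2560,
                                                             validation_split=0.1)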
@@ -273,14 +313,27 @@ def ShuffledDataset(file_path,
     def my_collate_fn(data, chunk_threshold=chunk_threshold):
         return collate_fn(data, chunk_threshold)
 
-    data_loader = torch.utils.data.DataLoader(dataset=dataset,
-                                              batch_size=batch_size,
-                                              shuffle=shuffle,
-                                              num_workers=num_workers,
-                                              pin_memory=False,
-                                              collate_fn=my_collate_fn)
-    return data_loader
+    val_size = int(validation_split * len(dataset))
+    train_size = len(dataset) - val_size
+    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
+
+    train_data_loader = torch.utils.data.DataLoader(dataset=train_dataset,
+                                                    batch_size=batch_size,
+                                                    shuffle=shuffle,
+                                                    num_workers=num_workers,
+                                                    pin_memory=False,
+                                                    collate_fn=my_collate_fn)
+    if val_size == 0:
+        return train_data_loader
+    val_data_loader = torch.utils.data.DataLoader(dataset=val_dataset,
+                                                  batch_size=batch_size,
+                                                  shuffle=False,
+                                                  num_workers=num_workers,
+                                                  pin_memory=False,
+                                                  collate_fn=my_collate_fn)
+    return train_data_loader, val_data_loader
 
 class TorchData(torch.utils.data.Dataset):
...
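One behavioral difference worth noting: ShuffledDataset delegates the split to torch.utils.data.random_split, so the train/validation partition is random, whereas CachedDataset.split slices deterministically. A small sketch of pinning the partition down (the dataset path is a placeholder; the generator= keyword of random_split only exists in later PyTorch releases, so the global seed is used here):

    import torch
    import torchani

    torch.manual_seed(0)  # makes the random_split partition reproducible
    train_loader, val_loader = torchani.data.ShuffledDataset('dataset.h5',
                                                             validation_split=0.1)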