OpenDAS / torchani

Commit 9639d716, authored Aug 21, 2019 by Richard Xue, committed by Gao, Xiang, Aug 21, 2019

Add split to new dataset API (#299)

* split
* clean
* docs
* docs
* Update new.py
Parent: b9e2c259

Showing 3 changed files, with 78 additions and 16 deletions (+78 -16):
    docs/api.rst            +1  -0
    tests/test_data_new.py  +13 -5
    torchani/data/new.py    +64 -11
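Taken together, the three files add one feature: both dataset front ends can now hand back a training/validation pair. A minimal usage sketch, inferred from the tests and code below (the file path and batch size are placeholder values, not taken from this commit):

    import torchani

    # ShuffledDataset: pass validation_split and get two DataLoaders back.
    train_loader, val_loader = torchani.data.ShuffledDataset(
        'data.h5', batch_size=1000, validation_split=0.1)

    # CachedDataset: call the new split() method instead.
    ds = torchani.data.CachedDataset('data.h5')
    train_ds, val_ds = ds.split(0.1)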
docs/api.rst

@@ -27,6 +27,7 @@ Datasets
 .. autofunction:: torchani.data.find_threshold
 .. autofunction:: torchani.data.ShuffledDataset
 .. autoclass:: torchani.data.CachedDataset
+    :members:
 .. autofunction:: torchani.data.load_ani_dataset
 .. autofunction:: torchani.data.create_aev_cache
 .. autoclass:: torchani.data.BatchedANIDataset
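The one-line change adds the ``:members:`` option to the ``autoclass`` directive, which tells Sphinx autodoc to document the class's methods as well, so the new CachedDataset.split() and CachedDataset.load() appear in the rendered API reference.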
tests/test_data_new.py

@@ -54,6 +54,12 @@ class TestShuffledData(unittest.TestCase):
         for i, _ in enumerate(self.ds):
             pbar.update(i)
 
+    def testSplitDataset(self):
+        print('=> test splitting dataset')
+        train_ds, val_ds = torchani.data.ShuffledDataset(dspath, batch_size=batch_size, chunk_threshold=chunk_threshold,
+                                                         num_workers=2, validation_split=0.1)
+        frac = len(val_ds) / (len(val_ds) + len(train_ds))
+        self.assertLess(abs(frac - 0.1), 0.05)
+
     def testNoUnnecessaryPadding(self):
         print('=> checking No Unnecessary Padding')
         for i, chunk in enumerate(self.chunks):

@@ -91,11 +97,13 @@ class TestCachedData(unittest.TestCase):
     def testLoadDataset(self):
         print('=> test loading all dataset')
-        pbar = pkbar.Pbar('loading and processing dataset into cpu memory, total '
-                          + 'batches: {}, batch_size: {}'.format(len(self.ds), batch_size),
-                          len(self.ds))
-        for i, _ in enumerate(self.ds):
-            pbar.update(i)
+        self.ds.load()
+
+    def testSplitDataset(self):
+        print('=> test splitting dataset')
+        train_dataset, val_dataset = self.ds.split(0.1)
+        frac = len(val_dataset) / len(self.ds)
+        self.assertLess(abs(frac - 0.1), 0.05)
 
     def testNoUnnecessaryPadding(self):
         print('=> checking No Unnecessary Padding')
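Both new tests allow the realized validation fraction to drift up to 0.05 from the requested 0.1, since the split size is truncated to a whole number of batches. A quick worked check (standalone, not code from this commit):

    # With 95 batches, int(0.1 * 95) == 9 validation batches,
    # so the realized fraction is 9 / 95 ≈ 0.0947, inside the tolerance.
    val_size = int(0.1 * 95)
    assert abs(val_size / 95 - 0.1) < 0.05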
torchani/data/new.py

@@ -76,8 +76,8 @@ class CachedDataset(torch.utils.data.Dataset):
         anidata = anidataloader(file_path)
         anidata_size = anidata.group_size()
-        enable_pkbar = anidata_size > 5 and PKBAR_INSTALLED
-        if enable_pkbar:
+        self.enable_pkbar = anidata_size > 5 and PKBAR_INSTALLED
+        if self.enable_pkbar:
             pbar = pkbar.Pbar('=> loading h5 dataset into cpu memory, total molecules: {}'.format(anidata_size),
                               anidata_size)
         for i, molecule in enumerate(anidata):
@@ -92,7 +92,7 @@ class CachedDataset(torch.utils.data.Dataset):
                 self_energies = np.array(sum([self_energies_dict[x] for x in molecule['species']]))
                 self.data_self_energies += list(np.tile(self_energies, (num_conformations, 1)))
-            if enable_pkbar:
+            if self.enable_pkbar:
                 pbar.update(i)
 
         if subtract_self_energies:
@@ -172,6 +172,43 @@ class CachedDataset(torch.utils.data.Dataset):
     def __len__(self):
         return self.length
 
+    def split(self, validation_split):
+        """Split dataset into training and validation.
+
+        Arguments:
+            validation_split (float): Float between 0 and 1. Fraction of the dataset to be used
+                as validation data.
+        """
+        val_size = int(validation_split * len(self))
+        train_size = len(self) - val_size
+
+        ds = []
+        if self.enable_pkbar:
+            message = ('=> processing, splitting and caching dataset into cpu memory:\n'
+                       + 'total batches: {}, train batches: {}, val batches: {}, batch_size: {}')
+            pbar = pkbar.Pbar(message.format(len(self), train_size, val_size, self.batch_size),
+                              len(self))
+        for i, _ in enumerate(self):
+            ds.append(self[i])
+            if self.enable_pkbar:
+                pbar.update(i)
+
+        train_dataset = ds[:train_size]
+        val_dataset = ds[train_size:]
+        return train_dataset, val_dataset
+
+    def load(self):
+        """Cache dataset into CPU memory. If not called, dataset will be cached during the first epoch.
+        """
+        if self.enable_pkbar:
+            pbar = pkbar.Pbar('=> processing and caching dataset into cpu memory:\ntotal '
+                              + 'batches: {}, batch_size: {}'.format(len(self), self.batch_size),
+                              len(self))
+        for i, _ in enumerate(self):
+            if self.enable_pkbar:
+                pbar.update(i)
+
     @staticmethod
     def sort_list_with_index(inputs, index):
         return [inputs[i] for i in index]
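The new split() walks the dataset once (caching every batch as a side effect) and returns two plain lists of batches. Note the cut is sequential: the first train_size batches become the training set, unlike ShuffledDataset below, which uses torch.utils.data.random_split. A usage sketch (constructor arguments are placeholders, not from this commit):

    ds = torchani.data.CachedDataset('data.h5')  # placeholder path
    ds.load()                          # optional: cache all batches up front
    train_ds, val_ds = ds.split(0.1)   # lists of batches, roughly 90/10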
@@ -229,6 +266,7 @@ class CachedDataset(torch.utils.data.Dataset):
 def ShuffledDataset(file_path,
                     batch_size=1000, num_workers=0, shuffle=True, chunk_threshold=20,
+                    validation_split=0.0,
                     species_order=['H', 'C', 'N', 'O'],
                     subtract_self_energies=False,
                     self_energies=[-0.600953, -38.08316, -54.707756, -75.194466]):
@@ -242,6 +280,8 @@ def ShuffledDataset(file_path,
         shuffle (bool): whether to shuffle.
         chunk_threshold (int): threshold at which to split a batch into chunks. Set to ``None``
             to disable chunk splitting.
+        validation_split (float): Float between 0 and 1. Fraction of the dataset to be used
+            as validation data.
         species_order (list): a list that specifies how species are transformed to int.
             For example: ``['H', 'C', 'N', 'O']`` means ``{'H': 0, 'C': 1, 'N': 2, 'O': 3}``.
         subtract_self_energies (bool): whether to subtract self energies from ``energies``.
@@ -273,14 +313,27 @@ def ShuffledDataset(file_path,
     def my_collate_fn(data, chunk_threshold=chunk_threshold):
         return collate_fn(data, chunk_threshold)
 
-    data_loader = torch.utils.data.DataLoader(dataset=dataset,
-                                              batch_size=batch_size,
-                                              shuffle=shuffle,
-                                              num_workers=num_workers,
-                                              pin_memory=False,
-                                              collate_fn=my_collate_fn)
-
-    return data_loader
+    val_size = int(validation_split * len(dataset))
+    train_size = len(dataset) - val_size
+    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
+
+    train_data_loader = torch.utils.data.DataLoader(dataset=train_dataset,
+                                                    batch_size=batch_size,
+                                                    shuffle=shuffle,
+                                                    num_workers=num_workers,
+                                                    pin_memory=False,
+                                                    collate_fn=my_collate_fn)
+    if val_size == 0:
+        return train_data_loader
+
+    val_data_loader = torch.utils.data.DataLoader(dataset=val_dataset,
+                                                  batch_size=batch_size,
+                                                  shuffle=False,
+                                                  num_workers=num_workers,
+                                                  pin_memory=False,
+                                                  collate_fn=my_collate_fn)
+
+    return train_data_loader, val_data_loader
 
 
 class TorchData(torch.utils.data.Dataset):
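With this change ShuffledDataset has two return shapes: the old single DataLoader when validation_split stays at its 0.0 default, or a (train, val) pair otherwise; the validation loader is never shuffled. A sketch of both call styles (the path is a placeholder):

    loader = torchani.data.ShuffledDataset('data.h5')  # old behavior, one loader

    train_loader, val_loader = torchani.data.ShuffledDataset(
        'data.h5', validation_split=0.1)               # new split behavior
    for batch in train_loader:     # batches are collated by my_collate_fn
        ...                        # training step goes here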