Unverified Commit 1aa77d8d authored by Ignacio Pickering's avatar Ignacio Pickering Committed by GitHub
Browse files

Add possibility to use custom padding (address issue #458) (#489)

* add possibility to use custom padding

* Add padding_dict to actual function

* Simplify arguments and fix docs

* typo
parent c2517397
......@@ -35,7 +35,11 @@ Available transformations are listed below:
- `shuffle`
- `cache` cache the result of previous transformations.
- `collate` pad the dataset, convert it to tensor, and stack them
together to get a batch.
together to get a batch. `collate` uses a default padding dictionary
``{'species': -1, 'coordinates': 0.0, 'forces': 0.0, 'energies': 0.0}`` for
padding, but a custom padding dictionary can be passed as an optional
parameter, which overrides this default padding.
- `pin_memory` copy the tensor to pinned memory so that later transfer
to cuda could be faster.
......@@ -94,8 +98,8 @@ if PKBAR_INSTALLED:
verbose = True
PROPERTIES = ('energies',)
PADDING = {
'species': -1,
'coordinates': 0.0,
......@@ -104,8 +108,11 @@ PADDING = {
}
def collate_fn(samples):
return utils.stack_with_padding(samples, PADDING)
def collate_fn(samples, padding=None):
if padding is None:
padding = PADDING
return utils.stack_with_padding(samples, padding)
class IterableAdapter:
......@@ -241,8 +248,8 @@ class Transformations:
return ret
@staticmethod
def collate(reenterable_iterable, batch_size):
def reenterable_iterable_factory():
def collate(reenterable_iterable, batch_size, padding=None):
def reenterable_iterable_factory(padding=None):
batch = []
i = 0
for d in reenterable_iterable:
......@@ -250,10 +257,13 @@ class Transformations:
i += 1
if i == batch_size:
i = 0
yield collate_fn(batch)
yield collate_fn(batch, padding)
batch = []
if len(batch) > 0:
yield collate_fn(batch)
yield collate_fn(batch, padding)
reenterable_iterable_factory = functools.partial(reenterable_iterable_factory,
padding)
try:
length = (len(reenterable_iterable) + batch_size - 1) // batch_size
return IterableAdapterWithLength(reenterable_iterable_factory, length)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment