You need to sign in or sign up before continuing.
Unverified Commit 1aa77d8d authored by Ignacio Pickering's avatar Ignacio Pickering Committed by GitHub
Browse files

Add possibility to use custom padding (address issue #458) (#489)

* add possibility to use custom padding

* Add padding_dict to actual function

* Simplify arguments and fix docs

* typo
parent c2517397
......@@ -35,7 +35,11 @@ Available transformations are listed below:
- `shuffle`
- `cache` cache the result of previous transformations.
- `collate` pad the dataset, convert it to tensor, and stack them
together to get a batch.
together to get a batch. By default `collate` pads with the dictionary
``{'species': -1, 'coordinates': 0.0, 'forces': 0.0, 'energies': 0.0}``;
a custom padding dictionary can be passed as an optional parameter, in
which case it is used instead of the default.
- `pin_memory` copy the tensor to pinned memory so that a later transfer
to CUDA can be faster.
......@@ -94,8 +98,8 @@ if PKBAR_INSTALLED:
verbose = True
PROPERTIES = ('energies',)
PADDING = {
'species': -1,
'coordinates': 0.0,
......@@ -104,8 +108,11 @@ PADDING = {
}
def collate_fn(samples):
return utils.stack_with_padding(samples, PADDING)
def collate_fn(samples, padding=None):
    """Stack ``samples`` into one batch via ``utils.stack_with_padding``.

    Args:
        samples: the items to be stacked; forwarded unchanged to
            ``utils.stack_with_padding``.
        padding: optional padding dictionary. When ``None`` (the default),
            the module-level ``PADDING`` dictionary is used instead.

    Returns:
        Whatever ``utils.stack_with_padding`` returns for the given
        samples and padding dictionary.
    """
    # Only an explicit ``None`` selects the module default; any other value
    # (including an empty dict) is forwarded as given.
    fill_values = PADDING if padding is None else padding
    return utils.stack_with_padding(samples, fill_values)
class IterableAdapter:
......@@ -241,8 +248,8 @@ class Transformations:
return ret
@staticmethod
def collate(reenterable_iterable, batch_size):
def reenterable_iterable_factory():
def collate(reenterable_iterable, batch_size, padding=None):
def reenterable_iterable_factory(padding=None):
batch = []
i = 0
for d in reenterable_iterable:
......@@ -250,10 +257,13 @@ class Transformations:
i += 1
if i == batch_size:
i = 0
yield collate_fn(batch)
yield collate_fn(batch, padding)
batch = []
if len(batch) > 0:
yield collate_fn(batch)
yield collate_fn(batch, padding)
reenterable_iterable_factory = functools.partial(reenterable_iterable_factory,
padding)
try:
length = (len(reenterable_iterable) + batch_size - 1) // batch_size
return IterableAdapterWithLength(reenterable_iterable_factory, length)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment