Unverified Commit 1aa77d8d authored by Ignacio Pickering's avatar Ignacio Pickering Committed by GitHub
Browse files

Add possibility to use custom padding (address issue #458) (#489)

* add possibility to use custom padding

* Add padding_dict to actual function

* Simplify arguments and fix docs

* typo
parent c2517397
...@@ -35,7 +35,11 @@ Available transformations are listed below: ...@@ -35,7 +35,11 @@ Available transformations are listed below:
- `shuffle` - `shuffle`
- `cache` cache the result of previous transformations. - `cache` cache the result of previous transformations.
- `collate` pad the dataset, convert it to tensor, and stack them - `collate` pad the dataset, convert it to tensor, and stack them
together to get a batch. together to get a batch. `collate` uses a default padding dictionary
``{'species': -1, 'coordinates': 0.0, 'forces': 0.0, 'energies': 0.0}`` for
padding, but a custom padding dictionary can be passed as an optional
parameter, which overrides this default padding.
- `pin_memory` copy the tensor to pinned memory so that later transfer - `pin_memory` copy the tensor to pinned memory so that later transfer
to cuda could be faster. to cuda could be faster.
...@@ -94,8 +98,8 @@ if PKBAR_INSTALLED: ...@@ -94,8 +98,8 @@ if PKBAR_INSTALLED:
verbose = True verbose = True
PROPERTIES = ('energies',) PROPERTIES = ('energies',)
PADDING = { PADDING = {
'species': -1, 'species': -1,
'coordinates': 0.0, 'coordinates': 0.0,
...@@ -104,8 +108,11 @@ PADDING = { ...@@ -104,8 +108,11 @@ PADDING = {
} }
def collate_fn(samples): def collate_fn(samples, padding=None):
return utils.stack_with_padding(samples, PADDING) if padding is None:
padding = PADDING
return utils.stack_with_padding(samples, padding)
class IterableAdapter: class IterableAdapter:
...@@ -241,8 +248,8 @@ class Transformations: ...@@ -241,8 +248,8 @@ class Transformations:
return ret return ret
@staticmethod @staticmethod
def collate(reenterable_iterable, batch_size): def collate(reenterable_iterable, batch_size, padding=None):
def reenterable_iterable_factory(): def reenterable_iterable_factory(padding=None):
batch = [] batch = []
i = 0 i = 0
for d in reenterable_iterable: for d in reenterable_iterable:
...@@ -250,10 +257,13 @@ class Transformations: ...@@ -250,10 +257,13 @@ class Transformations:
i += 1 i += 1
if i == batch_size: if i == batch_size:
i = 0 i = 0
yield collate_fn(batch) yield collate_fn(batch, padding)
batch = [] batch = []
if len(batch) > 0: if len(batch) > 0:
yield collate_fn(batch) yield collate_fn(batch, padding)
reenterable_iterable_factory = functools.partial(reenterable_iterable_factory,
padding)
try: try:
length = (len(reenterable_iterable) + batch_size - 1) // batch_size length = (len(reenterable_iterable) + batch_size - 1) // batch_size
return IterableAdapterWithLength(reenterable_iterable_factory, length) return IterableAdapterWithLength(reenterable_iterable_factory, length)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment