Unverified Commit 19d63943 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] avoid warning of pickle local function (#6284)

parent 29949322
......@@ -282,6 +282,25 @@ class ItemSampler(IterDataPipe):
self._drop_last = drop_last
self._shuffle = shuffle
@staticmethod
def _collate(batch):
"""Collate items into a batch. For internal use only."""
data = next(iter(batch))
if isinstance(data, DGLGraph):
return dgl_batch(batch)
elif isinstance(data, Mapping):
assert len(data) == 1, "Only one type of data is allowed."
# Collect all the keys.
keys = {key for item in batch for key in item.keys()}
# Collate each key.
return {
key: default_collate(
[item[key] for item in batch if key in item]
)
for key in keys
}
return default_collate(batch)
def __iter__(self) -> Iterator:
data_pipe = IterableWrapper(self._item_set)
# Shuffle before batch.
......@@ -299,24 +318,7 @@ class ItemSampler(IterDataPipe):
)
# Collate.
def _collate(batch):
data = next(iter(batch))
if isinstance(data, DGLGraph):
return dgl_batch(batch)
elif isinstance(data, Mapping):
assert len(data) == 1, "Only one type of data is allowed."
# Collect all the keys.
keys = {key for item in batch for key in item.keys()}
# Collate each key.
return {
key: default_collate(
[item[key] for item in batch if key in item]
)
for key in keys
}
return default_collate(batch)
data_pipe = data_pipe.collate(collate_fn=partial(_collate))
data_pipe = data_pipe.collate(collate_fn=self._collate)
# Map to minibatch.
data_pipe = data_pipe.map(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment