"vscode:/vscode.git/clone" did not exist on "6b02babbadce55093b3de0f47a144c5574162f31"
Unverified Commit 19d63943 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] avoid warning of pickle local function (#6284)

parent 29949322
...@@ -282,24 +282,9 @@ class ItemSampler(IterDataPipe): ...@@ -282,24 +282,9 @@ class ItemSampler(IterDataPipe):
self._drop_last = drop_last self._drop_last = drop_last
self._shuffle = shuffle self._shuffle = shuffle
def __iter__(self) -> Iterator: @staticmethod
data_pipe = IterableWrapper(self._item_set)
# Shuffle before batch.
if self._shuffle:
# `torchdata.datapipes.iter.Shuffler` works with stream too.
# To ensure randomness, make sure the buffer size is at least 10
# times the batch size.
buffer_size = max(10000, 10 * self._batch_size)
data_pipe = data_pipe.shuffle(buffer_size=buffer_size)
# Batch.
data_pipe = data_pipe.batch(
batch_size=self._batch_size,
drop_last=self._drop_last,
)
# Collate.
def _collate(batch): def _collate(batch):
"""Collate items into a batch. For internal use only."""
data = next(iter(batch)) data = next(iter(batch))
if isinstance(data, DGLGraph): if isinstance(data, DGLGraph):
return dgl_batch(batch) return dgl_batch(batch)
...@@ -316,7 +301,24 @@ class ItemSampler(IterDataPipe): ...@@ -316,7 +301,24 @@ class ItemSampler(IterDataPipe):
} }
return default_collate(batch) return default_collate(batch)
data_pipe = data_pipe.collate(collate_fn=partial(_collate)) def __iter__(self) -> Iterator:
data_pipe = IterableWrapper(self._item_set)
# Shuffle before batch.
if self._shuffle:
# `torchdata.datapipes.iter.Shuffler` works with stream too.
# To ensure randomness, make sure the buffer size is at least 10
# times the batch size.
buffer_size = max(10000, 10 * self._batch_size)
data_pipe = data_pipe.shuffle(buffer_size=buffer_size)
# Batch.
data_pipe = data_pipe.batch(
batch_size=self._batch_size,
drop_last=self._drop_last,
)
# Collate.
data_pipe = data_pipe.collate(collate_fn=self._collate)
# Map to minibatch. # Map to minibatch.
data_pipe = data_pipe.map( data_pipe = data_pipe.map(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment