Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
19d63943
"vscode:/vscode.git/clone" did not exist on "6b02babbadce55093b3de0f47a144c5574162f31"
Unverified
Commit
19d63943
authored
Sep 05, 2023
by
Rhett Ying
Committed by
GitHub
Sep 05, 2023
Browse files
[GraphBolt] avoid warning of pickle local function (#6284)
parent
29949322
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
18 deletions
+20
-18
python/dgl/graphbolt/item_sampler.py
python/dgl/graphbolt/item_sampler.py
+20
-18
No files found.
python/dgl/graphbolt/item_sampler.py
View file @
19d63943
...
@@ -282,24 +282,9 @@ class ItemSampler(IterDataPipe):
...
@@ -282,24 +282,9 @@ class ItemSampler(IterDataPipe):
self
.
_drop_last
=
drop_last
self
.
_drop_last
=
drop_last
self
.
_shuffle
=
shuffle
self
.
_shuffle
=
shuffle
def
__iter__
(
self
)
->
Iterator
:
@
staticmethod
data_pipe
=
IterableWrapper
(
self
.
_item_set
)
# Shuffle before batch.
if
self
.
_shuffle
:
# `torchdata.datapipes.iter.Shuffler` works with stream too.
# To ensure randomness, make sure the buffer size is at least 10
# times the batch size.
buffer_size
=
max
(
10000
,
10
*
self
.
_batch_size
)
data_pipe
=
data_pipe
.
shuffle
(
buffer_size
=
buffer_size
)
# Batch.
data_pipe
=
data_pipe
.
batch
(
batch_size
=
self
.
_batch_size
,
drop_last
=
self
.
_drop_last
,
)
# Collate.
def
_collate
(
batch
):
def
_collate
(
batch
):
"""Collate items into a batch. For internal use only."""
data
=
next
(
iter
(
batch
))
data
=
next
(
iter
(
batch
))
if
isinstance
(
data
,
DGLGraph
):
if
isinstance
(
data
,
DGLGraph
):
return
dgl_batch
(
batch
)
return
dgl_batch
(
batch
)
...
@@ -316,7 +301,24 @@ class ItemSampler(IterDataPipe):
...
@@ -316,7 +301,24 @@ class ItemSampler(IterDataPipe):
}
}
return
default_collate
(
batch
)
return
default_collate
(
batch
)
data_pipe
=
data_pipe
.
collate
(
collate_fn
=
partial
(
_collate
))
def
__iter__
(
self
)
->
Iterator
:
data_pipe
=
IterableWrapper
(
self
.
_item_set
)
# Shuffle before batch.
if
self
.
_shuffle
:
# `torchdata.datapipes.iter.Shuffler` works with stream too.
# To ensure randomness, make sure the buffer size is at least 10
# times the batch size.
buffer_size
=
max
(
10000
,
10
*
self
.
_batch_size
)
data_pipe
=
data_pipe
.
shuffle
(
buffer_size
=
buffer_size
)
# Batch.
data_pipe
=
data_pipe
.
batch
(
batch_size
=
self
.
_batch_size
,
drop_last
=
self
.
_drop_last
,
)
# Collate.
data_pipe
=
data_pipe
.
collate
(
collate_fn
=
self
.
_collate
)
# Map to minibatch.
# Map to minibatch.
data_pipe
=
data_pipe
.
map
(
data_pipe
=
data_pipe
.
map
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment