Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
19d63943
Unverified
Commit
19d63943
authored
Sep 05, 2023
by
Rhett Ying
Committed by
GitHub
Sep 05, 2023
Browse files
[GraphBolt] avoid warning of pickle local function (#6284)
parent
29949322
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
18 deletions
+20
-18
python/dgl/graphbolt/item_sampler.py
python/dgl/graphbolt/item_sampler.py
+20
-18
No files found.
python/dgl/graphbolt/item_sampler.py
View file @
19d63943
...
@@ -282,6 +282,25 @@ class ItemSampler(IterDataPipe):
...
@@ -282,6 +282,25 @@ class ItemSampler(IterDataPipe):
self
.
_drop_last
=
drop_last
self
.
_drop_last
=
drop_last
self
.
_shuffle
=
shuffle
self
.
_shuffle
=
shuffle
@
staticmethod
def
_collate
(
batch
):
"""Collate items into a batch. For internal use only."""
data
=
next
(
iter
(
batch
))
if
isinstance
(
data
,
DGLGraph
):
return
dgl_batch
(
batch
)
elif
isinstance
(
data
,
Mapping
):
assert
len
(
data
)
==
1
,
"Only one type of data is allowed."
# Collect all the keys.
keys
=
{
key
for
item
in
batch
for
key
in
item
.
keys
()}
# Collate each key.
return
{
key
:
default_collate
(
[
item
[
key
]
for
item
in
batch
if
key
in
item
]
)
for
key
in
keys
}
return
default_collate
(
batch
)
def
__iter__
(
self
)
->
Iterator
:
def
__iter__
(
self
)
->
Iterator
:
data_pipe
=
IterableWrapper
(
self
.
_item_set
)
data_pipe
=
IterableWrapper
(
self
.
_item_set
)
# Shuffle before batch.
# Shuffle before batch.
...
@@ -299,24 +318,7 @@ class ItemSampler(IterDataPipe):
...
@@ -299,24 +318,7 @@ class ItemSampler(IterDataPipe):
)
)
# Collate.
# Collate.
def
_collate
(
batch
):
data_pipe
=
data_pipe
.
collate
(
collate_fn
=
self
.
_collate
)
data
=
next
(
iter
(
batch
))
if
isinstance
(
data
,
DGLGraph
):
return
dgl_batch
(
batch
)
elif
isinstance
(
data
,
Mapping
):
assert
len
(
data
)
==
1
,
"Only one type of data is allowed."
# Collect all the keys.
keys
=
{
key
for
item
in
batch
for
key
in
item
.
keys
()}
# Collate each key.
return
{
key
:
default_collate
(
[
item
[
key
]
for
item
in
batch
if
key
in
item
]
)
for
key
in
keys
}
return
default_collate
(
batch
)
data_pipe
=
data_pipe
.
collate
(
collate_fn
=
partial
(
_collate
))
# Map to minibatch.
# Map to minibatch.
data_pipe
=
data_pipe
.
map
(
data_pipe
=
data_pipe
.
map
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment