OpenDAS / dgl · commit 617979d6 (unverified)

[GraphBolt] Add to function for DGLMiniBatch and MiniBatch (#6413)

Authored Oct 10, 2023 by Ramon Zhou · Committed by GitHub on Oct 10, 2023
Parent commit: 38448dac
Showing 2 changed files with 92 additions and 0 deletions (+92 -0):

  python/dgl/graphbolt/minibatch.py (+39 -0)
  tests/python/pytorch/graphbolt/test_base.py (+53 -0)
python/dgl/graphbolt/minibatch.py

@@ -7,6 +7,7 @@ import torch

import dgl
from dgl.heterograph import DGLBlock
from dgl.utils import recursive_apply

from .base import etype_str_to_tuple
from .sampled_subgraph import SampledSubgraph

@@ -95,6 +96,25 @@ class DGLMiniBatch:
        given type.
    """

    def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
        """Copy `DGLMiniBatch` to the specified device using reflection."""

        def _to(x, device):
            return x.to(device) if hasattr(x, "to") else x

        for attr in dir(self):
            # Only copy member variables.
            if not callable(getattr(self, attr)) and not attr.startswith(
                "__"
            ):
                setattr(
                    self,
                    attr,
                    recursive_apply(
                        getattr(self, attr), lambda x: _to(x, device)
                    ),
                )

        return self


@dataclass
class MiniBatch:

@@ -374,6 +394,25 @@ class MiniBatch:
        }
        return minibatch

    def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
        """Copy `MiniBatch` to the specified device using reflection."""

        def _to(x, device):
            return x.to(device) if hasattr(x, "to") else x

        for attr in dir(self):
            # Only copy member variables.
            if not callable(getattr(self, attr)) and not attr.startswith(
                "__"
            ):
                setattr(
                    self,
                    attr,
                    recursive_apply(
                        getattr(self, attr), lambda x: _to(x, device)
                    ),
                )

        return self


def _minibatch_str(minibatch: MiniBatch) -> str:
    final_str = ""
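A usage note (not part of this commit): once a GraphBolt datapipe yields MiniBatch or DGLMiniBatch objects, the new `to` method moves every tensor-like member variable onto the target device in a single call and returns the minibatch itself. The sketch below assumes such a datapipe has already been built; `iterate_on_gpu` is a hypothetical helper for illustration, not an API of the library.

import torch

def iterate_on_gpu(datapipe, device=torch.device("cuda")):
    """Yield minibatches whose tensor members have been moved to ``device``."""
    for minibatch in datapipe:
        # ``to`` walks the minibatch's member variables via reflection,
        # calls ``.to(device)`` on anything that supports it, and returns
        # the (mutated) minibatch.
        yield minibatch.to(device)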
tests/python/pytorch/graphbolt/test_base.py

@@ -4,6 +4,7 @@ import unittest

import backend as F

import dgl.graphbolt as gb
import gb_test_utils
import pytest
import torch

@@ -23,6 +24,58 @@ def test_CopyTo():
        assert data.device.type == "cuda"


@unittest.skipIf(
    F._default_context_str == "cpu",
    "CopyTo needs GPU to test",
)
def test_CopyToWithMiniBatches():
    N = 16
    B = 2
    itemset = gb.ItemSet(torch.arange(N), names="seed_nodes")
    graph = gb_test_utils.rand_csc_graph(100, 0.15)

    features = {}
    keys = [("node", None, "a"), ("node", None, "b")]
    features[keys[0]] = gb.TorchBasedFeature(torch.randn(200, 4))
    features[keys[1]] = gb.TorchBasedFeature(torch.randn(200, 4))
    feature_store = gb.BasicFeatureStore(features)

    datapipe = gb.ItemSampler(itemset, batch_size=B)
    datapipe = gb.NeighborSampler(
        datapipe,
        graph,
        fanouts=[torch.LongTensor([2]) for _ in range(2)],
    )
    datapipe = gb.FeatureFetcher(
        datapipe,
        feature_store,
        ["a"],
    )

    def test_data_device(datapipe):
        for data in datapipe:
            for attr in dir(data):
                var = getattr(data, attr)
                if (
                    not callable(var)
                    and not attr.startswith("__")
                    and hasattr(var, "device")
                ):
                    assert var.device.type == "cuda"

    # Invoke CopyTo via class constructor.
    test_data_device(gb.CopyTo(datapipe, "cuda"))

    # Invoke CopyTo via functional form.
    test_data_device(datapipe.copy_to("cuda"))

    # Test for DGLMiniBatch.
    datapipe = gb.DGLMiniBatchConverter(datapipe)

    # Invoke CopyTo via class constructor.
    test_data_device(gb.CopyTo(datapipe, "cuda"))

    # Invoke CopyTo via functional form.
    test_data_device(datapipe.copy_to("cuda"))


def test_etype_tuple_to_str():
    """Convert etype from tuple to string."""
    # Test for expected input.
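For context, a minimal end-to-end sketch (not part of this commit) of how the device-transfer step tested above can be attached to a sampling pipeline. It reuses only calls that appear in test_CopyToWithMiniBatches, assumes a CUDA device is available, and the trailing loop is purely illustrative; both forms presumably rely on the new `MiniBatch.to` under the hood.

import torch
import dgl.graphbolt as gb
import gb_test_utils  # test helper providing rand_csc_graph, as used above

# Build a small sampling pipeline, as in test_CopyToWithMiniBatches.
itemset = gb.ItemSet(torch.arange(16), names="seed_nodes")
graph = gb_test_utils.rand_csc_graph(100, 0.15)
datapipe = gb.ItemSampler(itemset, batch_size=2)
datapipe = gb.NeighborSampler(
    datapipe, graph, fanouts=[torch.LongTensor([2]) for _ in range(2)]
)

# Attach device transfer to the end of the pipeline.
gpu_pipe = gb.CopyTo(datapipe, "cuda")  # constructor form
# gpu_pipe = datapipe.copy_to("cuda")   # equivalent functional form

for minibatch in gpu_pipe:
    pass  # tensor-like members of ``minibatch`` now live on the CUDA device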