Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
02443df1
"tests/vscode:/vscode.git/clone" did not exist on "099b173f6f678d576a727ee4ad170599ec466f4e"
Unverified
Commit
02443df1
authored
Oct 24, 2023
by
Ramon Zhou
Committed by
GitHub
Oct 24, 2023
Browse files
[GraphBolt] Add to function for SampledSubgraph (#6480)
parent
aff6b685
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
76 additions
and
3 deletions
+76
-3
python/dgl/graphbolt/base.py
python/dgl/graphbolt/base.py
+4
-2
python/dgl/graphbolt/sampled_subgraph.py
python/dgl/graphbolt/sampled_subgraph.py
+19
-1
tests/python/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
...thon/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
+53
-0
No files found.
python/dgl/graphbolt/base.py
View file @
02443df1
...
...
@@ -76,7 +76,9 @@ def etype_str_to_tuple(c_etype):
return
ret
def
_to
(
x
,
device
):
def
apply_to
(
x
,
device
):
"""Apply `to` function to object x only if it has `to`."""
return
x
.
to
(
device
)
if
hasattr
(
x
,
"to"
)
else
x
...
...
@@ -107,5 +109,5 @@ class CopyTo(IterDataPipe):
def
__iter__
(
self
):
for
data
in
self
.
datapipe
:
data
=
recursive_apply
(
data
,
_to
,
self
.
device
)
data
=
recursive_apply
(
data
,
apply
_to
,
self
.
device
)
yield
data
python/dgl/graphbolt/sampled_subgraph.py
View file @
02443df1
...
...
@@ -4,7 +4,9 @@ from typing import Dict, Tuple, Union
import
torch
from
.base
import
etype_str_to_tuple
,
isin
from
dgl.utils
import
recursive_apply
from
.base
import
apply_to
,
etype_str_to_tuple
,
isin
__all__
=
[
"SampledSubgraph"
]
...
...
@@ -189,6 +191,22 @@ class SampledSubgraph:
)
return
calling_class
(
*
_slice_subgraph
(
self
,
index
))
def
to
(
self
,
device
:
torch
.
device
)
->
None
:
# pylint: disable=invalid-name
"""Copy `SampledSubgraph` to the specified device using reflection."""
for
attr
in
dir
(
self
):
# Only copy member variables.
if
not
callable
(
getattr
(
self
,
attr
))
and
not
attr
.
startswith
(
"__"
):
setattr
(
self
,
attr
,
recursive_apply
(
getattr
(
self
,
attr
),
lambda
x
:
apply_to
(
x
,
device
)
),
)
return
self
def
_to_reverse_ids
(
node_pair
,
original_row_node_ids
,
original_column_node_ids
):
u
,
v
=
node_pair
...
...
tests/python/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
View file @
02443df1
import
unittest
import
backend
as
F
import
pytest
import
torch
...
...
@@ -132,3 +135,53 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
)
_assert_container_equal
(
result
.
original_row_node_ids
,
expected_row_node_ids
)
_assert_container_equal
(
result
.
original_edge_ids
,
expected_edge_ids
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"cpu"
,
reason
=
"`to` function needs GPU to test."
,
)
def
test_sampled_subgraph_to_device
():
# Initialize data.
node_pairs
=
{
"A:relation:B"
:
(
torch
.
tensor
([
0
,
1
,
2
]),
torch
.
tensor
([
2
,
1
,
0
]),
)
}
original_row_node_ids
=
{
"A"
:
torch
.
tensor
([
13
,
14
,
15
]),
}
src_to_exclude
=
torch
.
tensor
([
15
,
13
])
original_column_node_ids
=
{
"B"
:
torch
.
tensor
([
10
,
11
,
12
]),
}
dst_to_exclude
=
torch
.
tensor
([
10
,
12
])
original_edge_ids
=
{
"A:relation:B"
:
torch
.
tensor
([
19
,
20
,
21
])}
subgraph
=
SampledSubgraphImpl
(
node_pairs
=
node_pairs
,
original_column_node_ids
=
original_column_node_ids
,
original_row_node_ids
=
original_row_node_ids
,
original_edge_ids
=
original_edge_ids
,
)
edges_to_exclude
=
{
"A:relation:B"
:
(
src_to_exclude
,
dst_to_exclude
,
)
}
graph
=
subgraph
.
exclude_edges
(
edges_to_exclude
)
# Copy to device.
graph
=
graph
.
to
(
"cuda"
)
# Check.
for
key
in
graph
.
node_pairs
:
assert
graph
.
node_pairs
[
key
][
0
].
device
.
type
==
"cuda"
assert
graph
.
node_pairs
[
key
][
1
].
device
.
type
==
"cuda"
for
key
in
graph
.
original_column_node_ids
:
assert
graph
.
original_column_node_ids
[
key
].
device
.
type
==
"cuda"
for
key
in
graph
.
original_row_node_ids
:
assert
graph
.
original_row_node_ids
[
key
].
device
.
type
==
"cuda"
for
key
in
graph
.
original_edge_ids
:
assert
graph
.
original_edge_ids
[
key
].
device
.
type
==
"cuda"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment