OpenDAS / dgl · Commit 8cf5ad84 (unverified)

Authored Feb 07, 2024 by Ramon Zhou; committed by GitHub on Feb 07, 2024.

[GraphBolt] Add `to_pyg_data` for MiniBatch (#7076)
Parent: 0504bc2c

Showing 2 changed files with 128 additions and 1 deletion (+128, -1):

    python/dgl/graphbolt/minibatch.py (+45, -1)
    tests/python/pytorch/graphbolt/test_minibatch.py (+83, -0)
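For orientation, the new method is meant to be called on each `MiniBatch` yielded by a GraphBolt dataloader, handing the result to a PyG-style model. A minimal sketch of that call pattern follows; `dataloader` and `model` are assumed to exist elsewhere and are not part of this commit.

    # Sketch only: `dataloader` and `model` are assumptions, not part of this diff.
    for minibatch in dataloader:
        data = minibatch.to_pyg_data()        # torch_geometric.data.Data
        data.validate()                       # same sanity check used in the new test
        out = model(data.x, data.edge_index)  # generic PyG model forward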
python/dgl/graphbolt/minibatch.py (view file @ 8cf5ad84)

...
@@ -8,7 +8,7 @@ import torch
 import dgl
 from dgl.utils import recursive_apply
-from .base import etype_str_to_tuple
+from .base import etype_str_to_tuple, expand_indptr
 from .internal import get_attributes
 from .sampled_subgraph import SampledSubgraph
...
@@ -474,6 +474,50 @@ class MiniBatch:
         else:
             return None
 
+    def to_pyg_data(self):
+        """Construct a PyG Data from `MiniBatch`. This function only supports
+        node classification tasks on a homogeneous graph, with at most one
+        node feature.
+        """
+        from torch_geometric.data import Data
+
+        if self.sampled_subgraphs is None:
+            edge_index = None
+        else:
+            col_nodes = []
+            row_nodes = []
+            for subgraph in self.sampled_subgraphs:
+                if subgraph is None:
+                    continue
+                sampled_csc = subgraph.sampled_csc
+                indptr = sampled_csc.indptr
+                indices = sampled_csc.indices
+                expanded_indptr = expand_indptr(
+                    indptr, dtype=indices.dtype, output_size=len(indices)
+                )
+                col_nodes.append(expanded_indptr)
+                row_nodes.append(indices)
+            col_nodes = torch.cat(col_nodes)
+            row_nodes = torch.cat(row_nodes)
+            edge_index = torch.unique(
+                torch.stack((col_nodes, row_nodes)), dim=1
+            )
+
+        if self.node_features is None:
+            node_features = None
+        else:
+            assert (
+                len(self.node_features) == 1
+            ), "`to_pyg_data` only supports single feature homogeneous graph."
+            node_features = next(iter(self.node_features.values()))
+
+        pyg_data = Data(
+            x=node_features,
+            edge_index=edge_index,
+            y=self.labels,
+        )
+        return pyg_data
+
     def to(self, device: torch.device):  # pylint: disable=invalid-name
         """Copy `MiniBatch` to the specified device using reflection."""
...
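The conversion above is a CSC-to-COO expansion: `expand_indptr` repeats each destination (column) node id once per incident edge, and `torch.unique(..., dim=1)` drops edges that appear in more than one sampled layer. The sketch below reproduces that step on the CSC data of `test_subgraph_a` from the test file, using plain `torch.repeat_interleave` as a stand-in for `expand_indptr` (assumed equivalent here, purely for illustration).

    import torch

    # CSC of `test_subgraph_a` in the new test.
    indptr = torch.tensor([0, 1, 3, 5, 6])
    indices = torch.tensor([0, 1, 2, 2, 1, 2])

    # Repeat each column id once per incident edge (the role played by
    # `expand_indptr` in `to_pyg_data`).
    degrees = indptr[1:] - indptr[:-1]
    col_nodes = torch.repeat_interleave(
        torch.arange(len(degrees), dtype=indices.dtype), degrees
    )
    # col_nodes -> tensor([0, 1, 1, 2, 2, 3])

    # Stack into COO form and drop duplicate edges, as `to_pyg_data` does
    # after concatenating all sampled layers.
    edge_index = torch.unique(torch.stack((col_nodes, indices)), dim=1)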
tests/python/pytorch/graphbolt/test_minibatch.py (view file @ 8cf5ad84)

...
@@ -859,3 +859,86 @@ def test_dgl_link_predication_hetero(mode):
             minibatch.negative_node_pairs[etype][1],
             minibatch.compacted_negative_dsts[etype],
         )
+
+
+def test_to_pyg_data():
+    test_subgraph_a = gb.SampledSubgraphImpl(
+        sampled_csc=gb.CSCFormatBase(
+            indptr=torch.tensor([0, 1, 3, 5, 6]),
+            indices=torch.tensor([0, 1, 2, 2, 1, 2]),
+        ),
+        original_column_node_ids=torch.tensor([10, 11, 12, 13]),
+        original_row_node_ids=torch.tensor([19, 20, 21, 22, 25, 30]),
+        original_edge_ids=torch.tensor([10, 11, 12, 13]),
+    )
+    test_subgraph_b = gb.SampledSubgraphImpl(
+        sampled_csc=gb.CSCFormatBase(
+            indptr=torch.tensor([0, 1, 3]),
+            indices=torch.tensor([1, 2, 0]),
+        ),
+        original_row_node_ids=torch.tensor([10, 11, 12]),
+        original_edge_ids=torch.tensor([10, 15, 17]),
+        original_column_node_ids=torch.tensor([10, 11]),
+    )
+    expected_edge_index = torch.tensor(
+        [[0, 0, 1, 1, 1, 2, 2, 3],
+         [0, 1, 0, 1, 2, 1, 2, 2]]
+    )
+    expected_node_features = torch.tensor([[1], [2], [3], [4]])
+    expected_labels = torch.tensor([0, 1])
+
+    test_minibatch = gb.MiniBatch(
+        sampled_subgraphs=[test_subgraph_a, test_subgraph_b],
+        node_features={"feat": expected_node_features},
+        labels=expected_labels,
+    )
+    pyg_data = test_minibatch.to_pyg_data()
+    pyg_data.validate()
+    assert torch.equal(pyg_data.edge_index, expected_edge_index)
+    assert torch.equal(pyg_data.x, expected_node_features)
+    assert torch.equal(pyg_data.y, expected_labels)
+
+    # Test with sampled_csc as None.
+    test_minibatch = gb.MiniBatch(
+        sampled_subgraphs=None,
+        node_features={"feat": expected_node_features},
+        labels=expected_labels,
+    )
+    pyg_data = test_minibatch.to_pyg_data()
+    assert pyg_data.edge_index is None, "Edge index should be none."
+
+    # Test with node_features as None.
+    test_minibatch = gb.MiniBatch(
+        sampled_subgraphs=[test_subgraph_a],
+        node_features=None,
+        labels=expected_labels,
+    )
+    pyg_data = test_minibatch.to_pyg_data()
+    assert pyg_data.x is None, "Node features should be None."
+
+    # Test with labels as None.
+    test_minibatch = gb.MiniBatch(
+        sampled_subgraphs=[test_subgraph_a],
+        node_features={"feat": expected_node_features},
+        labels=None,
+    )
+    pyg_data = test_minibatch.to_pyg_data()
+    assert pyg_data.y is None, "Labels should be None."
+
+    # Test with multiple features.
+    test_minibatch = gb.MiniBatch(
+        sampled_subgraphs=[test_subgraph_a],
+        node_features={
+            "feat": expected_node_features,
+            "extra_feat": torch.tensor([[3], [4]]),
+        },
+        labels=expected_labels,
+    )
+    try:
+        pyg_data = test_minibatch.to_pyg_data()
+        assert False, "Multiple features case should raise an error."
+    except AssertionError as e:
+        assert (
+            str(e)
+            == "`to_pyg_data` only supports single feature homogeneous graph."
+        )
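The try/except block is one way to assert that the multi-feature case fails; the same expectation could be written with `pytest.raises`, sketched below under the assumption that `pytest` is importable in this test module (its imports are not shown in the diff).

    import pytest

    # Equivalent check to the try/except above (sketch).
    with pytest.raises(
        AssertionError,
        match="only supports single feature homogeneous graph",
    ):
        test_minibatch.to_pyg_data()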