Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
32be4a8e
Unverified
Commit
32be4a8e
authored
Oct 09, 2023
by
LastWhisper
Committed by
GitHub
Oct 09, 2023
Browse files
[GraphBolt] Store `ORIGINAL_EDGE_ID` in CSCSamplingGraph's `edge_attributes`. (#6399)
parent
241760a5
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
2 deletions
+48
-2
python/dgl/graphbolt/impl/csc_sampling_graph.py
python/dgl/graphbolt/impl/csc_sampling_graph.py
+5
-2
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
.../python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
+43
-0
No files found.
python/dgl/graphbolt/impl/csc_sampling_graph.py
View file @
32be4a8e
...
...
@@ -8,7 +8,7 @@ from typing import Dict, Optional, Union
import
torch
from
...base
import
ETYPE
from
...base
import
EID
,
ETYPE
from
...convert
import
to_homogeneous
from
...heterograph
import
DGLGraph
from
..base
import
etype_str_to_tuple
,
etype_tuple_to_str
,
ORIGINAL_EDGE_ID
...
...
@@ -810,13 +810,16 @@ def from_dglgraph(g: DGLGraph, is_homogeneous=False) -> CSCSamplingGraph:
# Assign edge type according to the order of CSC matrix.
type_per_edge
=
None
if
is_homogeneous
else
homo_g
.
edata
[
ETYPE
][
edge_ids
]
# Assign edge attributes according to the original eids mapping.
edge_attributes
=
{
ORIGINAL_EDGE_ID
:
homo_g
.
edata
[
EID
][
edge_ids
]}
return
CSCSamplingGraph
(
torch
.
ops
.
graphbolt
.
from_csc
(
indptr
,
indices
,
node_type_offset
,
type_per_edge
,
None
,
edge_attributes
,
),
metadata
,
)
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
View file @
32be4a8e
...
...
@@ -1296,6 +1296,15 @@ def test_from_dglgraph_homogeneous():
dgl_g
=
dgl
.
rand_graph
(
1000
,
10
*
1000
)
gb_g
=
gb
.
from_dglgraph
(
dgl_g
,
is_homogeneous
=
True
)
# Get the COO representation of the CSCSamplingGraph.
num_columns
=
gb_g
.
csc_indptr
[
1
:]
-
gb_g
.
csc_indptr
[:
-
1
]
rows
=
gb_g
.
indices
columns
=
torch
.
arange
(
gb_g
.
total_num_nodes
).
repeat_interleave
(
num_columns
)
original_edge_ids
=
gb_g
.
edge_attributes
[
gb
.
ORIGINAL_EDGE_ID
]
assert
torch
.
all
(
dgl_g
.
edges
()[
0
][
original_edge_ids
]
==
rows
)
assert
torch
.
all
(
dgl_g
.
edges
()[
1
][
original_edge_ids
]
==
columns
)
assert
gb_g
.
total_num_nodes
==
dgl_g
.
num_nodes
()
assert
gb_g
.
total_num_edges
==
dgl_g
.
num_edges
()
assert
torch
.
equal
(
gb_g
.
node_type_offset
,
torch
.
tensor
([
0
,
1000
]))
...
...
@@ -1328,6 +1337,40 @@ def test_from_dglgraph_heterogeneous():
)
gb_g
=
gb
.
from_dglgraph
(
dgl_g
,
is_homogeneous
=
False
)
# `reverse_node_id` is used to map the node id in CSCSamplingGraph to the
# node id in Hetero-DGLGraph.
num_ntypes
=
gb_g
.
node_type_offset
[
1
:]
-
gb_g
.
node_type_offset
[:
-
1
]
reverse_node_id
=
torch
.
cat
([
torch
.
arange
(
num
)
for
num
in
num_ntypes
])
# Get the COO representation of the CSCSamplingGraph.
num_columns
=
gb_g
.
csc_indptr
[
1
:]
-
gb_g
.
csc_indptr
[:
-
1
]
rows
=
reverse_node_id
[
gb_g
.
indices
]
columns
=
reverse_node_id
[
torch
.
arange
(
gb_g
.
total_num_nodes
).
repeat_interleave
(
num_columns
)
]
# Check the order of etypes in DGLGraph is the same as CSCSamplingGraph.
assert
(
# Since the etypes in CSCSamplingGraph is "srctype:etype:dsttype",
# we need to split the string and get the middle part.
list
(
map
(
lambda
ss
:
ss
.
split
(
":"
)[
1
],
gb_g
.
metadata
.
edge_type_to_id
.
keys
(),
)
)
==
dgl_g
.
etypes
)
# Use ORIGINAL_EDGE_ID to check if the edge mapping is correct.
for
edge_idx
in
range
(
gb_g
.
total_num_edges
):
hetero_graph_idx
=
gb_g
.
type_per_edge
[
edge_idx
]
original_edge_id
=
gb_g
.
edge_attributes
[
gb
.
ORIGINAL_EDGE_ID
][
edge_idx
]
edge_type
=
dgl_g
.
etypes
[
hetero_graph_idx
]
dgl_edge_pairs
=
dgl_g
.
edges
(
etype
=
edge_type
)
assert
dgl_edge_pairs
[
0
][
original_edge_id
]
==
rows
[
edge_idx
]
assert
dgl_edge_pairs
[
1
][
original_edge_id
]
==
columns
[
edge_idx
]
assert
gb_g
.
total_num_nodes
==
dgl_g
.
num_nodes
()
assert
gb_g
.
total_num_edges
==
dgl_g
.
num_edges
()
assert
torch
.
equal
(
gb_g
.
node_type_offset
,
torch
.
tensor
([
0
,
6
,
12
,
18
,
25
]))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment