Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
49e336d2
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b47f5115da50769e554501560540969e02585adc"
Unverified
Commit
49e336d2
authored
Jul 04, 2023
by
Rhett Ying
Committed by
GitHub
Jul 04, 2023
Browse files
[GraphBolt] enable to convert DGLGraph to CSCSamplingGraph (#5948)
parent
cdf65f4d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
94 additions
and
0 deletions
+94
-0
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
+28
-0
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
+66
-0
No files found.
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
View file @
49e336d2
...
@@ -7,6 +7,10 @@ from typing import Dict, Optional, Tuple
...
@@ -7,6 +7,10 @@ from typing import Dict, Optional, Tuple
import
torch
import
torch
from
...base
import
ETYPE
from
...convert
import
to_homogeneous
from
...heterograph
import
DGLGraph
class
GraphMetadata
:
class
GraphMetadata
:
r
"""Class for metadata of csc sampling graph."""
r
"""Class for metadata of csc sampling graph."""
...
@@ -452,3 +456,27 @@ def save_csc_sampling_graph(graph, filename):
...
@@ -452,3 +456,27 @@ def save_csc_sampling_graph(graph, filename):
metadata_filename
,
arcname
=
os
.
path
.
basename
(
metadata_filename
)
metadata_filename
,
arcname
=
os
.
path
.
basename
(
metadata_filename
)
)
)
print
(
f
"CSCSamplingGraph has been saved to
{
filename
}
."
)
print
(
f
"CSCSamplingGraph has been saved to
{
filename
}
."
)
def
from_dglgraph
(
g
:
DGLGraph
)
->
CSCSamplingGraph
:
"""Convert a DGLGraph to CSCSamplingGraph."""
homo_g
,
ntype_count
,
_
=
to_homogeneous
(
g
,
return_count
=
True
)
# Initialize metadata.
node_type_to_id
=
{
ntype
:
g
.
get_ntype_id
(
ntype
)
for
ntype
in
g
.
ntypes
}
edge_type_to_id
=
{
etype
:
g
.
get_etype_id
(
etype
)
for
etype
in
g
.
canonical_etypes
}
metadata
=
GraphMetadata
(
node_type_to_id
,
edge_type_to_id
)
# Obtain CSC matrix.
indptr
,
indices
,
_
=
homo_g
.
adj_tensors
(
"csc"
)
ntype_count
.
insert
(
0
,
0
)
node_type_offset
=
torch
.
cumsum
(
torch
.
LongTensor
(
ntype_count
),
0
)
type_per_edge
=
homo_g
.
edata
[
ETYPE
]
return
CSCSamplingGraph
(
torch
.
ops
.
graphbolt
.
from_csc
(
indptr
,
indices
,
node_type_offset
,
type_per_edge
),
metadata
,
)
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
View file @
49e336d2
...
@@ -4,10 +4,12 @@ import unittest
...
@@ -4,10 +4,12 @@ import unittest
import
backend
as
F
import
backend
as
F
import
dgl
import
dgl.graphbolt
as
gb
import
dgl.graphbolt
as
gb
import
pytest
import
pytest
import
torch
import
torch
from
scipy
import
sparse
as
spsp
torch
.
manual_seed
(
3407
)
torch
.
manual_seed
(
3407
)
...
@@ -733,6 +735,70 @@ def test_hetero_graph_on_shared_memory(
...
@@ -733,6 +735,70 @@ def test_hetero_graph_on_shared_memory(
assert
metadata
.
edge_type_to_id
==
graph2
.
metadata
.
edge_type_to_id
assert
metadata
.
edge_type_to_id
==
graph2
.
metadata
.
edge_type_to_id
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
reason
=
"Graph on GPU is not supported yet."
,
)
def
test_from_dglgraph_homogeneous
():
dgl_g
=
dgl
.
rand_graph
(
1000
,
10
*
1000
)
gb_g
=
gb
.
from_dglgraph
(
dgl_g
)
assert
gb_g
.
num_nodes
==
dgl_g
.
num_nodes
()
assert
gb_g
.
num_edges
==
dgl_g
.
num_edges
()
assert
torch
.
equal
(
gb_g
.
node_type_offset
,
torch
.
tensor
([
0
,
1000
]))
assert
torch
.
all
(
gb_g
.
type_per_edge
==
0
)
assert
gb_g
.
metadata
.
node_type_to_id
==
{
"_N"
:
0
}
assert
gb_g
.
metadata
.
edge_type_to_id
==
{(
"_N"
,
"_E"
,
"_N"
):
0
}
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
reason
=
"Graph on GPU is not supported yet."
,
)
def
test_from_dglgraph_heterogeneous
():
def
create_random_hetero
():
num_nodes
=
{
"n1"
:
1000
,
"n2"
:
1010
,
"n3"
:
1020
}
etypes
=
[
(
"n1"
,
"r12"
,
"n2"
),
(
"n2"
,
"r21"
,
"n1"
),
(
"n1"
,
"r13"
,
"n3"
),
(
"n2"
,
"r23"
,
"n3"
),
]
edges
=
{}
for
etype
in
etypes
:
src_ntype
,
_
,
dst_ntype
=
etype
arr
=
spsp
.
random
(
num_nodes
[
src_ntype
],
num_nodes
[
dst_ntype
],
density
=
0.001
,
format
=
"coo"
,
random_state
=
100
,
)
edges
[
etype
]
=
(
arr
.
row
,
arr
.
col
)
return
dgl
.
heterograph
(
edges
,
num_nodes
)
dgl_g
=
create_random_hetero
()
gb_g
=
gb
.
from_dglgraph
(
dgl_g
)
assert
gb_g
.
num_nodes
==
dgl_g
.
num_nodes
()
assert
gb_g
.
num_edges
==
dgl_g
.
num_edges
()
assert
torch
.
equal
(
gb_g
.
node_type_offset
,
torch
.
tensor
([
0
,
1000
,
2010
,
3030
])
)
assert
torch
.
all
(
gb_g
.
type_per_edge
[:
-
1
]
<=
gb_g
.
type_per_edge
[
1
:])
assert
gb_g
.
metadata
.
node_type_to_id
==
{
"n1"
:
0
,
"n2"
:
1
,
"n3"
:
2
,
}
assert
gb_g
.
metadata
.
edge_type_to_id
==
{
(
"n1"
,
"r12"
,
"n2"
):
0
,
(
"n1"
,
"r13"
,
"n3"
):
1
,
(
"n2"
,
"r21"
,
"n1"
):
2
,
(
"n2"
,
"r23"
,
"n3"
):
3
,
}
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_sample_neighbors
()
test_sample_neighbors
()
test_sample_neighbors_replace
(
True
,
12
)
test_sample_neighbors_replace
(
True
,
12
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment