Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
70af8f0d
Unverified
Commit
70af8f0d
authored
May 24, 2023
by
Rhett Ying
Committed by
GitHub
May 24, 2023
Browse files
[GraphBolt] add support for load/save CSCSamplingGraph in python level (#5733)
parent
6862e372
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
123 additions
and
4 deletions
+123
-4
graphbolt/src/csc_sampling_graph.cc
graphbolt/src/csc_sampling_graph.cc
+22
-0
graphbolt/src/python_binding.cc
graphbolt/src/python_binding.cc
+3
-0
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
+35
-0
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
+63
-4
No files found.
graphbolt/src/csc_sampling_graph.cc
View file @
70af8f0d
...
@@ -48,12 +48,34 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
...
@@ -48,12 +48,34 @@ void CSCSamplingGraph::Load(torch::serialize::InputArchive& archive) {
"Magic numbers mismatch when loading CSCSamplingGraph."
);
"Magic numbers mismatch when loading CSCSamplingGraph."
);
indptr_
=
read_from_archive
(
archive
,
"CSCSamplingGraph/indptr"
).
toTensor
();
indptr_
=
read_from_archive
(
archive
,
"CSCSamplingGraph/indptr"
).
toTensor
();
indices_
=
read_from_archive
(
archive
,
"CSCSamplingGraph/indices"
).
toTensor
();
indices_
=
read_from_archive
(
archive
,
"CSCSamplingGraph/indices"
).
toTensor
();
if
(
read_from_archive
(
archive
,
"CSCSamplingGraph/has_node_type_offset"
)
.
toBool
())
{
node_type_offset_
=
read_from_archive
(
archive
,
"CSCSamplingGraph/node_type_offset"
)
.
toTensor
();
}
if
(
read_from_archive
(
archive
,
"CSCSamplingGraph/has_type_per_edge"
)
.
toBool
())
{
type_per_edge_
=
read_from_archive
(
archive
,
"CSCSamplingGraph/type_per_edge"
).
toTensor
();
}
}
}
void
CSCSamplingGraph
::
Save
(
torch
::
serialize
::
OutputArchive
&
archive
)
const
{
void
CSCSamplingGraph
::
Save
(
torch
::
serialize
::
OutputArchive
&
archive
)
const
{
archive
.
write
(
"CSCSamplingGraph/magic_num"
,
kCSCSamplingGraphSerializeMagic
);
archive
.
write
(
"CSCSamplingGraph/magic_num"
,
kCSCSamplingGraphSerializeMagic
);
archive
.
write
(
"CSCSamplingGraph/indptr"
,
indptr_
);
archive
.
write
(
"CSCSamplingGraph/indptr"
,
indptr_
);
archive
.
write
(
"CSCSamplingGraph/indices"
,
indices_
);
archive
.
write
(
"CSCSamplingGraph/indices"
,
indices_
);
archive
.
write
(
"CSCSamplingGraph/has_node_type_offset"
,
node_type_offset_
.
has_value
());
if
(
node_type_offset_
)
{
archive
.
write
(
"CSCSamplingGraph/node_type_offset"
,
node_type_offset_
.
value
());
}
archive
.
write
(
"CSCSamplingGraph/has_type_per_edge"
,
type_per_edge_
.
has_value
());
if
(
type_per_edge_
)
{
archive
.
write
(
"CSCSamplingGraph/type_per_edge"
,
type_per_edge_
.
value
());
}
}
}
}
// namespace sampling
}
// namespace sampling
...
...
graphbolt/src/python_binding.cc
View file @
70af8f0d
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
*/
*/
#include <graphbolt/csc_sampling_graph.h>
#include <graphbolt/csc_sampling_graph.h>
#include <graphbolt/serialize.h>
namespace
graphbolt
{
namespace
graphbolt
{
namespace
sampling
{
namespace
sampling
{
...
@@ -18,6 +19,8 @@ TORCH_LIBRARY(graphbolt, m) {
...
@@ -18,6 +19,8 @@ TORCH_LIBRARY(graphbolt, m) {
.
def
(
"node_type_offset"
,
&
CSCSamplingGraph
::
NodeTypeOffset
)
.
def
(
"node_type_offset"
,
&
CSCSamplingGraph
::
NodeTypeOffset
)
.
def
(
"type_per_edge"
,
&
CSCSamplingGraph
::
TypePerEdge
);
.
def
(
"type_per_edge"
,
&
CSCSamplingGraph
::
TypePerEdge
);
m
.
def
(
"from_csc"
,
&
CSCSamplingGraph
::
FromCSC
);
m
.
def
(
"from_csc"
,
&
CSCSamplingGraph
::
FromCSC
);
m
.
def
(
"load_csc_sampling_graph"
,
&
LoadCSCSamplingGraph
);
m
.
def
(
"save_csc_sampling_graph"
,
&
SaveCSCSamplingGraph
);
}
}
}
// namespace sampling
}
// namespace sampling
...
...
python/dgl/graphbolt/graph_storage/csc_sampling_graph.py
View file @
70af8f0d
"""CSC format sampling graph."""
"""CSC format sampling graph."""
# pylint: disable= invalid-name
# pylint: disable= invalid-name
import
os
import
tarfile
import
tempfile
from
typing
import
Dict
,
Optional
,
Tuple
from
typing
import
Dict
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -254,3 +257,35 @@ def _csc_sampling_graph_str(graph: CSCSamplingGraph) -> str:
...
@@ -254,3 +257,35 @@ def _csc_sampling_graph_str(graph: CSCSamplingGraph) -> str:
final_str
=
prefix
+
_add_indent
(
final_str
,
len
(
prefix
))
final_str
=
prefix
+
_add_indent
(
final_str
,
len
(
prefix
))
return
final_str
return
final_str
def
load_csc_sampling_graph
(
filename
):
"""Load CSCSamplingGraph from tar file."""
with
tempfile
.
TemporaryDirectory
()
as
temp_dir
:
with
tarfile
.
open
(
filename
,
"r"
)
as
archive
:
archive
.
extractall
(
temp_dir
)
graph_filename
=
os
.
path
.
join
(
temp_dir
,
"csc_sampling_graph.pt"
)
metadata_filename
=
os
.
path
.
join
(
temp_dir
,
"metadata.pt"
)
return
CSCSamplingGraph
(
torch
.
ops
.
graphbolt
.
load_csc_sampling_graph
(
graph_filename
),
torch
.
load
(
metadata_filename
),
)
def
save_csc_sampling_graph
(
graph
,
filename
):
"""Save CSCSamplingGraph to tar file."""
with
tempfile
.
TemporaryDirectory
()
as
temp_dir
:
graph_filename
=
os
.
path
.
join
(
temp_dir
,
"csc_sampling_graph.pt"
)
torch
.
ops
.
graphbolt
.
save_csc_sampling_graph
(
graph
.
_c_csc_graph
,
graph_filename
)
metadata_filename
=
os
.
path
.
join
(
temp_dir
,
"metadata.pt"
)
torch
.
save
(
graph
.
metadata
,
metadata_filename
)
with
tarfile
.
open
(
filename
,
"w"
)
as
archive
:
archive
.
add
(
graph_filename
,
arcname
=
os
.
path
.
basename
(
graph_filename
)
)
archive
.
add
(
metadata_filename
,
arcname
=
os
.
path
.
basename
(
metadata_filename
)
)
print
(
f
"CSCSamplingGraph has been saved to
{
filename
}
."
)
tests/python/pytorch/graphbolt/test_csc_sampling_graph.py
View file @
70af8f0d
import
os
import
tempfile
import
unittest
import
unittest
import
backend
as
F
import
backend
as
F
...
@@ -211,7 +213,64 @@ def test_node_type_offset_wrong_legnth(node_type_offset):
...
@@ -211,7 +213,64 @@ def test_node_type_offset_wrong_legnth(node_type_offset):
)
)
if
__name__
==
"__main__"
:
@
unittest
.
skipIf
(
test_empty_graph
(
10
)
F
.
_default_context_str
==
"gpu"
,
test_node_type_offset_wrong_legnth
(
torch
.
tensor
([
0
,
1
,
5
]))
reason
=
"Graph is CPU only at present."
,
test_hetero_graph
(
10
,
50
,
3
,
5
)
)
@
pytest
.
mark
.
parametrize
(
"num_nodes, num_edges"
,
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)]
)
def
test_load_save_homo_graph
(
num_nodes
,
num_edges
):
csc_indptr
,
indices
=
random_homo_graph
(
num_nodes
,
num_edges
)
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
)
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
filename
=
os
.
path
.
join
(
test_dir
,
"csc_sampling_graph.tar"
)
gb
.
save_csc_sampling_graph
(
graph
,
filename
)
graph2
=
gb
.
load_csc_sampling_graph
(
filename
)
assert
graph
.
num_nodes
==
graph2
.
num_nodes
assert
graph
.
num_edges
==
graph2
.
num_edges
assert
torch
.
equal
(
graph
.
csc_indptr
,
graph2
.
csc_indptr
)
assert
torch
.
equal
(
graph
.
indices
,
graph2
.
indices
)
assert
graph
.
metadata
is
None
and
graph2
.
metadata
is
None
assert
graph
.
node_type_offset
is
None
and
graph2
.
node_type_offset
is
None
assert
graph
.
type_per_edge
is
None
and
graph2
.
type_per_edge
is
None
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
reason
=
"Graph is CPU only at present."
,
)
@
pytest
.
mark
.
parametrize
(
"num_nodes, num_edges"
,
[(
1
,
1
),
(
100
,
1
),
(
10
,
50
),
(
1000
,
50000
)]
)
@
pytest
.
mark
.
parametrize
(
"num_ntypes, num_etypes"
,
[(
1
,
1
),
(
3
,
5
),
(
100
,
1
)])
def
test_load_save_hetero_graph
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
):
(
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
,
)
=
random_hetero_graph
(
num_nodes
,
num_edges
,
num_ntypes
,
num_etypes
)
graph
=
gb
.
from_csc
(
csc_indptr
,
indices
,
node_type_offset
,
type_per_edge
,
metadata
)
with
tempfile
.
TemporaryDirectory
()
as
test_dir
:
filename
=
os
.
path
.
join
(
test_dir
,
"csc_sampling_graph.tar"
)
gb
.
save_csc_sampling_graph
(
graph
,
filename
)
graph2
=
gb
.
load_csc_sampling_graph
(
filename
)
assert
graph
.
num_nodes
==
graph2
.
num_nodes
assert
graph
.
num_edges
==
graph2
.
num_edges
assert
torch
.
equal
(
graph
.
csc_indptr
,
graph2
.
csc_indptr
)
assert
torch
.
equal
(
graph
.
indices
,
graph2
.
indices
)
assert
torch
.
equal
(
graph
.
node_type_offset
,
graph2
.
node_type_offset
)
assert
torch
.
equal
(
graph
.
type_per_edge
,
graph2
.
type_per_edge
)
assert
graph
.
metadata
.
node_type_to_id
==
graph2
.
metadata
.
node_type_to_id
assert
graph
.
metadata
.
edge_type_to_id
==
graph2
.
metadata
.
edge_type_to_id
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment