Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
09e5aa96
Unverified
Commit
09e5aa96
authored
Jul 12, 2023
by
Ramon Zhou
Committed by
GitHub
Jul 12, 2023
Browse files
[Graphbolt] Add pickle serialization support for Graphbolt SampledSubgraph (#5979)
parent
ced802d0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
221 additions
and
1 deletion
+221
-1
graphbolt/include/graphbolt/sampled_subgraph.h
graphbolt/include/graphbolt/sampled_subgraph.h
+12
-0
graphbolt/src/python_binding.cc
graphbolt/src/python_binding.cc
+12
-1
graphbolt/src/sampled_subgraph.cc
graphbolt/src/sampled_subgraph.cc
+99
-0
tests/python/pytorch/graphbolt/impl/test_c_sampled_subgraph.py
.../python/pytorch/graphbolt/impl/test_c_sampled_subgraph.py
+98
-0
No files found.
graphbolt/include/graphbolt/sampled_subgraph.h
View file @
09e5aa96
...
@@ -113,6 +113,18 @@ struct SampledSubgraph : torch::CustomClassHolder {
...
@@ -113,6 +113,18 @@ struct SampledSubgraph : torch::CustomClassHolder {
* subgraph.
* subgraph.
*/
*/
torch
::
optional
<
torch
::
Tensor
>
type_per_edge
;
torch
::
optional
<
torch
::
Tensor
>
type_per_edge
;
/**
* @brief Get graph state (for pickle serialization).
* @return A vector of Tensors.
*/
std
::
vector
<
torch
::
Tensor
>
GetState
();
/**
* @brief Set graph state (for pickle deserialization).
* @param state A vector of Tensors.
*/
void
SetState
(
std
::
vector
<
torch
::
Tensor
>&
state
);
};
};
}
// namespace sampling
}
// namespace sampling
...
...
graphbolt/src/python_binding.cc
View file @
09e5aa96
...
@@ -20,7 +20,18 @@ TORCH_LIBRARY(graphbolt, m) {
...
@@ -20,7 +20,18 @@ TORCH_LIBRARY(graphbolt, m) {
.
def_readwrite
(
.
def_readwrite
(
"reverse_column_node_ids"
,
&
SampledSubgraph
::
reverse_column_node_ids
)
"reverse_column_node_ids"
,
&
SampledSubgraph
::
reverse_column_node_ids
)
.
def_readwrite
(
"reverse_edge_ids"
,
&
SampledSubgraph
::
reverse_edge_ids
)
.
def_readwrite
(
"reverse_edge_ids"
,
&
SampledSubgraph
::
reverse_edge_ids
)
.
def_readwrite
(
"type_per_edge"
,
&
SampledSubgraph
::
type_per_edge
);
.
def_readwrite
(
"type_per_edge"
,
&
SampledSubgraph
::
type_per_edge
)
.
def_pickle
(
// __getstate__
[](
const
c10
::
intrusive_ptr
<
SampledSubgraph
>&
self
)
->
std
::
vector
<
torch
::
Tensor
>
{
return
self
->
GetState
();
},
// __setstate__
[](
std
::
vector
<
torch
::
Tensor
>
state
)
->
c10
::
intrusive_ptr
<
SampledSubgraph
>
{
auto
g
=
c10
::
make_intrusive
<
SampledSubgraph
>
();
g
->
SetState
(
state
);
return
g
;
});
m
.
class_
<
CSCSamplingGraph
>
(
"CSCSamplingGraph"
)
m
.
class_
<
CSCSamplingGraph
>
(
"CSCSamplingGraph"
)
.
def
(
"num_nodes"
,
&
CSCSamplingGraph
::
NumNodes
)
.
def
(
"num_nodes"
,
&
CSCSamplingGraph
::
NumNodes
)
.
def
(
"num_edges"
,
&
CSCSamplingGraph
::
NumEdges
)
.
def
(
"num_edges"
,
&
CSCSamplingGraph
::
NumEdges
)
...
...
graphbolt/src/sampled_subgraph.cc
0 → 100644
View file @
09e5aa96
/**
* Copyright (c) 2023 by Contributors
* @file sampled_subgraph.cc
* @brief Source file of sampled subgraph.
*/
#include <graphbolt/sampled_subgraph.h>
#include <graphbolt/serialize.h>
#include <torch/torch.h>
#include <vector>
namespace
graphbolt
{
namespace
sampling
{
/**
* @brief Version number to indicate graph version in serialization and
* deserialization.
*/
static
constexpr
int64_t
kSampledSubgraphSerializeVersionNumber
=
1
;
std
::
vector
<
torch
::
Tensor
>
SampledSubgraph
::
GetState
()
{
std
::
vector
<
torch
::
Tensor
>
state
;
// Version number.
torch
::
Tensor
version_num_tensor
=
torch
::
ones
(
1
,
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
))
*
kSampledSubgraphSerializeVersionNumber
;
state
.
push_back
(
version_num_tensor
);
// Tensors.
state
.
push_back
(
indptr
);
state
.
push_back
(
indices
);
state
.
push_back
(
reverse_column_node_ids
);
// Optional tensors.
static
torch
::
Tensor
true_tensor
=
torch
::
ones
(
1
,
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
));
static
torch
::
Tensor
false_tensor
=
torch
::
zeros
(
1
,
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
));
if
(
reverse_row_node_ids
.
has_value
())
{
state
.
push_back
(
true_tensor
);
state
.
push_back
(
reverse_row_node_ids
.
value
());
}
else
{
state
.
push_back
(
false_tensor
);
}
if
(
reverse_edge_ids
.
has_value
())
{
state
.
push_back
(
true_tensor
);
state
.
push_back
(
reverse_edge_ids
.
value
());
}
else
{
state
.
push_back
(
false_tensor
);
}
if
(
type_per_edge
.
has_value
())
{
state
.
push_back
(
true_tensor
);
state
.
push_back
(
type_per_edge
.
value
());
}
else
{
state
.
push_back
(
false_tensor
);
}
return
state
;
}
void
SampledSubgraph
::
SetState
(
std
::
vector
<
torch
::
Tensor
>&
state
)
{
// Iterator.
uint32_t
i
=
0
;
// Version number.
torch
::
Tensor
&
version_num_tensor
=
state
[
i
++
];
torch
::
Tensor
current_version_num_tensor
=
torch
::
ones
(
1
,
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
))
*
kSampledSubgraphSerializeVersionNumber
;
TORCH_CHECK
(
version_num_tensor
.
equal
(
current_version_num_tensor
),
"Version number mismatch when deserializing SampledSubgraph."
);
// Tensors.
indptr
=
state
[
i
++
];
indices
=
state
[
i
++
];
reverse_column_node_ids
=
state
[
i
++
];
// Optional tensors.
static
torch
::
Tensor
true_tensor
=
torch
::
ones
(
1
,
torch
::
TensorOptions
().
dtype
(
torch
::
kInt32
));
reverse_row_node_ids
=
torch
::
nullopt
;
reverse_edge_ids
=
torch
::
nullopt
;
type_per_edge
=
torch
::
nullopt
;
if
(
state
[
i
++
].
equal
(
true_tensor
))
{
reverse_row_node_ids
=
state
[
i
++
];
}
if
(
state
[
i
++
].
equal
(
true_tensor
))
{
reverse_edge_ids
=
state
[
i
++
];
}
if
(
state
[
i
++
].
equal
(
true_tensor
))
{
type_per_edge
=
state
[
i
++
];
}
}
}
// namespace sampling
}
// namespace graphbolt
tests/python/pytorch/graphbolt/impl/test_c_sampled_subgraph.py
0 → 100644
View file @
09e5aa96
import
multiprocessing
as
mp
import
unittest
import
backend
as
F
import
dgl
import
dgl.graphbolt
as
gb
import
torch
def
subprocess_entry
(
queue
,
barrier
):
num_nodes
=
5
num_edges
=
12
indptr
=
torch
.
LongTensor
([
0
,
3
,
5
,
7
,
9
,
12
])
indices
=
torch
.
LongTensor
([
0
,
1
,
4
,
2
,
3
,
0
,
1
,
1
,
2
,
0
,
3
,
4
])
type_per_edge
=
torch
.
LongTensor
([
0
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
1
,
0
,
0
,
1
])
assert
indptr
[
-
1
]
==
num_edges
assert
indptr
[
-
1
]
==
len
(
indices
)
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
,
"n3"
:
2
}
etypes
=
{(
"n1"
,
"e1"
,
"n2"
):
0
,
(
"n1"
,
"e2"
,
"n3"
):
1
}
metadata
=
gb
.
GraphMetadata
(
ntypes
,
etypes
)
# Construct CSCSamplingGraph.
graph
=
gb
.
from_csc
(
indptr
,
indices
,
type_per_edge
=
type_per_edge
,
metadata
=
metadata
)
adjs
=
[]
seeds
=
torch
.
arange
(
5
)
# Sampling.
for
hop
in
range
(
2
):
sg
=
graph
.
sample_neighbors
(
seeds
,
torch
.
LongTensor
([
2
]))
seeds
=
sg
.
indices
adjs
.
append
(
sg
)
# Send the data twice (back and forth) and then verify.
# Method get() and put() of mp.Queue is blocking by default.
# Step 1. Put the data.
queue
.
put
(
adjs
)
# Step 2. Another process gets the data.
# Step 3. Barrier. Wait for another process to get the data.
barrier
.
wait
()
# Step 4. Another process puts the data.
# Step 5. Get the data.
result
=
queue
.
get
()
# Step 6. Verification.
for
hop
in
range
(
2
):
# Tensors.
assert
torch
.
equal
(
adjs
[
hop
].
indptr
,
result
[
hop
].
indptr
)
assert
torch
.
equal
(
adjs
[
hop
].
indices
,
result
[
hop
].
indices
)
assert
torch
.
equal
(
adjs
[
hop
].
reverse_column_node_ids
,
result
[
hop
].
reverse_column_node_ids
,
)
# Optional tensors.
assert
(
adjs
[
hop
].
reverse_row_node_ids
is
None
and
adjs
[
hop
].
reverse_row_node_ids
is
None
)
or
torch
.
equal
(
adjs
[
hop
].
reverse_row_node_ids
,
result
[
hop
].
reverse_row_node_ids
)
assert
(
adjs
[
hop
].
reverse_edge_ids
is
None
and
result
[
hop
].
reverse_edge_ids
is
None
)
or
torch
.
equal
(
adjs
[
hop
].
reverse_edge_ids
,
result
[
hop
].
reverse_edge_ids
)
assert
(
adjs
[
hop
].
type_per_edge
is
None
and
result
[
hop
].
type_per_edge
is
None
)
or
torch
.
equal
(
adjs
[
hop
].
type_per_edge
,
result
[
hop
].
type_per_edge
)
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"gpu"
,
reason
=
"Graph is CPU only at present."
,
)
def
test_subgraph_serialization
():
# Create a sub-process.
queue
=
mp
.
Queue
()
barrier
=
mp
.
Barrier
(
2
)
proc
=
mp
.
Process
(
target
=
subprocess_entry
,
args
=
(
queue
,
barrier
))
proc
.
start
()
# Send the data twice (back and forth) and then verify.
# Method get() and put() of mp.Queue is blocking by default.
# Step 1. Another process puts the data.
# Step 2. Get the data. This operation will block if the queue is empty.
items
=
queue
.
get
()
# Step 3. Barrier.
barrier
.
wait
()
# Step 4. Put the data again.
queue
.
put
(
items
)
# Step 5. Another process gets the final data.
# Step 6. Wait for another process to end
proc
.
join
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment