Unverified Commit 807a753c authored by Hongzhi (Steve), Chen, committed by GitHub

Revert "[Graphbolt] Add pickle serialization support for Graphbolt SampledSubgraph" (#5997)

parent ed0d6416
...@@ -113,18 +113,6 @@ struct SampledSubgraph : torch::CustomClassHolder {
   * subgraph.
   */
  torch::optional<torch::Tensor> type_per_edge;
/**
* @brief Get graph state (for pickle serialization).
* @return A vector of Tensors.
*/
std::vector<torch::Tensor> GetState();
/**
* @brief Set graph state (for pickle deserialization).
* @param state A vector of Tensors.
*/
void SetState(std::vector<torch::Tensor>& state);
};
}  // namespace sampling
...
...@@ -20,18 +20,7 @@ TORCH_LIBRARY(graphbolt, m) {
      .def_readwrite(
          "reverse_column_node_ids", &SampledSubgraph::reverse_column_node_ids)
      .def_readwrite("reverse_edge_ids", &SampledSubgraph::reverse_edge_ids)
      .def_readwrite("type_per_edge", &SampledSubgraph::type_per_edge)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<SampledSubgraph>& self)
-> std::vector<torch::Tensor> { return self->GetState(); },
// __setstate__
[](std::vector<torch::Tensor> state)
-> c10::intrusive_ptr<SampledSubgraph> {
auto g = c10::make_intrusive<SampledSubgraph>();
g->SetState(state);
return g;
});
  m.class_<CSCSamplingGraph>("CSCSamplingGraph")
      .def("num_nodes", &CSCSamplingGraph::NumNodes)
      .def("num_edges", &CSCSamplingGraph::NumEdges)
...
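For context, the def_pickle binding removed above is what exposed __getstate__/__setstate__ to Python's pickle module, so a SampledSubgraph could travel through pickle-based channels such as multiprocessing queues (as exercised by the test deleted below). A minimal sketch of that round trip on a pre-revert build, mirroring the removed test's setup; this is illustrative only, and after this revert the custom class is no longer pickled this way.

import pickle

import torch
import dgl.graphbolt as gb

# Build the same toy graph used in the removed test.
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
type_per_edge = torch.LongTensor([0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1])
metadata = gb.GraphMetadata(
    {"n1": 0, "n2": 1, "n3": 2},
    {("n1", "e1", "n2"): 0, ("n1", "e2", "n3"): 1},
)
graph = gb.from_csc(
    indptr, indices, type_per_edge=type_per_edge, metadata=metadata
)
sg = graph.sample_neighbors(torch.arange(5), torch.LongTensor([2]))

# Round trip through pickle; this relied on the def_pickle hooks above.
restored = pickle.loads(pickle.dumps(sg))
assert torch.equal(sg.indptr, restored.indptr)
assert torch.equal(sg.indices, restored.indices)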
/**
* Copyright (c) 2023 by Contributors
* @file sampled_subgraph.cc
* @brief Source file of sampled subgraph.
*/
#include <graphbolt/sampled_subgraph.h>
#include <graphbolt/serialize.h>
#include <torch/torch.h>
#include <vector>
namespace graphbolt {
namespace sampling {
/**
* @brief Version number to indicate graph version in serialization and
* deserialization.
*/
static constexpr int64_t kSampledSubgraphSerializeVersionNumber = 1;
std::vector<torch::Tensor> SampledSubgraph::GetState() {
std::vector<torch::Tensor> state;
// Version number.
torch::Tensor version_num_tensor =
torch::ones(1, torch::TensorOptions().dtype(torch::kInt64)) *
kSampledSubgraphSerializeVersionNumber;
state.push_back(version_num_tensor);
// Tensors.
state.push_back(indptr);
state.push_back(indices);
state.push_back(reverse_column_node_ids);
// Optional tensors.
static torch::Tensor true_tensor =
torch::ones(1, torch::TensorOptions().dtype(torch::kInt32));
static torch::Tensor false_tensor =
torch::zeros(1, torch::TensorOptions().dtype(torch::kInt32));
if (reverse_row_node_ids.has_value()) {
state.push_back(true_tensor);
state.push_back(reverse_row_node_ids.value());
} else {
state.push_back(false_tensor);
}
if (reverse_edge_ids.has_value()) {
state.push_back(true_tensor);
state.push_back(reverse_edge_ids.value());
} else {
state.push_back(false_tensor);
}
if (type_per_edge.has_value()) {
state.push_back(true_tensor);
state.push_back(type_per_edge.value());
} else {
state.push_back(false_tensor);
}
return state;
}
void SampledSubgraph::SetState(std::vector<torch::Tensor>& state) {
// Iterator.
uint32_t i = 0;
// Version number.
torch::Tensor& version_num_tensor = state[i++];
torch::Tensor current_version_num_tensor =
torch::ones(1, torch::TensorOptions().dtype(torch::kInt64)) *
kSampledSubgraphSerializeVersionNumber;
TORCH_CHECK(
version_num_tensor.equal(current_version_num_tensor),
"Version number mismatch when deserializing SampledSubgraph.");
// Tensors.
indptr = state[i++];
indices = state[i++];
reverse_column_node_ids = state[i++];
// Optional tensors.
static torch::Tensor true_tensor =
torch::ones(1, torch::TensorOptions().dtype(torch::kInt32));
reverse_row_node_ids = torch::nullopt;
reverse_edge_ids = torch::nullopt;
type_per_edge = torch::nullopt;
if (state[i++].equal(true_tensor)) {
reverse_row_node_ids = state[i++];
}
if (state[i++].equal(true_tensor)) {
reverse_edge_ids = state[i++];
}
if (state[i++].equal(true_tensor)) {
type_per_edge = state[i++];
}
}
} // namespace sampling
} // namespace graphbolt
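For reference, GetState() and SetState() above serialize the subgraph as a flat vector of tensors: a one-element int64 version tensor, then the three required tensors (indptr, indices, reverse_column_node_ids), then a one-element int32 flag (1 or 0) before each of the three optional tensors, with the tensor itself following only when the flag is 1. A rough Python sketch of the same layout, purely for illustration (encode_state/decode_state are hypothetical helpers, not library API):

import torch

VERSION = 1  # mirrors kSampledSubgraphSerializeVersionNumber

def encode_state(indptr, indices, reverse_column_node_ids,
                 reverse_row_node_ids=None, reverse_edge_ids=None,
                 type_per_edge=None):
    # Version tensor followed by the required tensors.
    state = [torch.tensor([VERSION], dtype=torch.int64),
             indptr, indices, reverse_column_node_ids]
    # Each optional tensor is preceded by a 1/0 presence flag.
    for opt in (reverse_row_node_ids, reverse_edge_ids, type_per_edge):
        if opt is not None:
            state += [torch.ones(1, dtype=torch.int32), opt]
        else:
            state.append(torch.zeros(1, dtype=torch.int32))
    return state

def decode_state(state):
    i = 0
    assert int(state[i]) == VERSION, "Version number mismatch."
    i += 1
    indptr, indices, reverse_column_node_ids = state[i : i + 3]
    i += 3
    optionals = []
    for _ in range(3):
        flag = int(state[i])
        i += 1
        if flag == 1:
            optionals.append(state[i])
            i += 1
        else:
            optionals.append(None)
    return (indptr, indices, reverse_column_node_ids, *optionals)

# Example: only type_per_edge present among the optional tensors.
s = encode_state(torch.tensor([0, 2]), torch.tensor([1, 0]),
                 torch.tensor([5, 6]), type_per_edge=torch.tensor([0, 1]))
decoded = decode_state(s)
assert decoded[3] is None and torch.equal(decoded[5], torch.tensor([0, 1]))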
import multiprocessing as mp
import unittest
import backend as F
import dgl
import dgl.graphbolt as gb
import torch
def subprocess_entry(queue, barrier):
num_nodes = 5
num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
type_per_edge = torch.LongTensor([0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1])
assert indptr[-1] == num_edges
assert indptr[-1] == len(indices)
ntypes = {"n1": 0, "n2": 1, "n3": 2}
etypes = {("n1", "e1", "n2"): 0, ("n1", "e2", "n3"): 1}
metadata = gb.GraphMetadata(ntypes, etypes)
# Construct CSCSamplingGraph.
graph = gb.from_csc(
indptr, indices, type_per_edge=type_per_edge, metadata=metadata
)
adjs = []
seeds = torch.arange(5)
# Sampling.
for hop in range(2):
sg = graph.sample_neighbors(seeds, torch.LongTensor([2]))
seeds = sg.indices
adjs.append(sg)
# Send the data twice (back and forth) and then verify.
    # The get() and put() methods of mp.Queue are blocking by default.
# Step 1. Put the data.
queue.put(adjs)
# Step 2. Another process gets the data.
# Step 3. Barrier. Wait for another process to get the data.
barrier.wait()
# Step 4. Another process puts the data.
# Step 5. Get the data.
result = queue.get()
# Step 6. Verification.
for hop in range(2):
# Tensors.
assert torch.equal(adjs[hop].indptr, result[hop].indptr)
assert torch.equal(adjs[hop].indices, result[hop].indices)
assert torch.equal(
adjs[hop].reverse_column_node_ids,
result[hop].reverse_column_node_ids,
)
# Optional tensors.
assert (
adjs[hop].reverse_row_node_ids is None
            and result[hop].reverse_row_node_ids is None
) or torch.equal(
adjs[hop].reverse_row_node_ids, result[hop].reverse_row_node_ids
)
assert (
adjs[hop].reverse_edge_ids is None
and result[hop].reverse_edge_ids is None
) or torch.equal(
adjs[hop].reverse_edge_ids, result[hop].reverse_edge_ids
)
assert (
adjs[hop].type_per_edge is None
and result[hop].type_per_edge is None
) or torch.equal(adjs[hop].type_per_edge, result[hop].type_per_edge)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
def test_subgraph_serialization():
# Create a sub-process.
queue = mp.Queue()
barrier = mp.Barrier(2)
proc = mp.Process(target=subprocess_entry, args=(queue, barrier))
proc.start()
# Send the data twice (back and forth) and then verify.
    # The get() and put() methods of mp.Queue are blocking by default.
# Step 1. Another process puts the data.
# Step 2. Get the data. This operation will block if the queue is empty.
items = queue.get()
# Step 3. Barrier.
barrier.wait()
# Step 4. Put the data again.
queue.put(items)
# Step 5. Another process gets the final data.
    # Step 6. Wait for the other process to finish.
proc.join()