"docs/vscode:/vscode.git/clone" did not exist on "f960468f50b732c0b8684cc69879232e094f1814"
Unverified Commit d387e8e3 authored by Ramon Zhou, committed by GitHub

[Graphbolt] Add pickle methods for CSCSamplingGraph (#6199)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent ec8225df
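
In Python, this makes a CSCSamplingGraph round-trip through pickle directly. A minimal usage sketch, based on the gb.from_csc constructor exercised in the tests below (the tiny two-node graph is illustrative only):

import pickle

import torch
import dgl.graphbolt as gb

# A tiny homogeneous graph in CSC form: 2 nodes, 2 edges.
csc_indptr = torch.tensor([0, 1, 2])
indices = torch.tensor([1, 0])
graph = gb.from_csc(csc_indptr, indices)

# The new __getstate__/__setstate__ hooks make this round trip work.
graph2 = pickle.loads(pickle.dumps(graph))
assert torch.equal(graph.csc_indptr, graph2.csc_indptr)
assert torch.equal(graph.indices, graph2.indices)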
@@ -129,6 +129,21 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
*/
void Save(torch::serialize::OutputArchive& archive) const;
/**
* @brief Pickle method for deserializing.
* @param state The state of serialized CSCSamplingGraph.
*/
void SetState(
const torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>&
state);
/**
* @brief Pickle method for serializing.
* @returns The state of this CSCSamplingGraph.
*/
torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>> GetState()
const;
/**
* @brief Return the subgraph induced on the inbound edges of the given nodes.
* @param nodes Type agnostic node IDs to form the subgraph.
...
@@ -22,6 +22,8 @@
namespace graphbolt {
namespace sampling {
static const int kPickleVersion = 6199;
CSCSamplingGraph::CSCSamplingGraph(
const torch::Tensor& indptr, const torch::Tensor& indices,
const torch::optional<torch::Tensor>& node_type_offset,
@@ -97,6 +99,56 @@ void CSCSamplingGraph::Save(torch::serialize::OutputArchive& archive) const {
}
}
void CSCSamplingGraph::SetState(
const torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>&
state) {
// State is a dict of dicts. The tensor-type attributes are stored in the dict
// with key "independent_tensors". The dict-type attributes (edge_attributes)
// are stored directly with their name as the key.
const auto& independent_tensors = state.at("independent_tensors");
TORCH_CHECK(
independent_tensors.at("version_number")
.equal(torch::tensor({kPickleVersion})),
"Version number mismatches when loading pickled CSCSamplingGraph.")
indptr_ = independent_tensors.at("indptr");
indices_ = independent_tensors.at("indices");
if (independent_tensors.find("node_type_offset") !=
independent_tensors.end()) {
node_type_offset_ = independent_tensors.at("node_type_offset");
}
if (independent_tensors.find("type_per_edge") != independent_tensors.end()) {
type_per_edge_ = independent_tensors.at("type_per_edge");
}
if (state.find("edge_attributes") != state.end()) {
edge_attributes_ = state.at("edge_attributes");
}
}
torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>
CSCSamplingGraph::GetState() const {
// State is a dict of dicts. The tensor-type attributes are stored in the dict
// with key "independent_tensors". The dict-type attributes (edge_attributes)
// are stored directly with their name as the key.
torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>> state;
torch::Dict<std::string, torch::Tensor> independent_tensors;
// Serialization version number. It indicates the serialization method of the
// whole state.
independent_tensors.insert("version_number", torch::tensor({kPickleVersion}));
independent_tensors.insert("indptr", indptr_);
independent_tensors.insert("indices", indices_);
if (node_type_offset_.has_value()) {
independent_tensors.insert("node_type_offset", node_type_offset_.value());
}
if (type_per_edge_.has_value()) {
independent_tensors.insert("type_per_edge", type_per_edge_.value());
}
state.insert("independent_tensors", independent_tensors);
if (edge_attributes_.has_value()) {
state.insert("edge_attributes", edge_attributes_.value());
}
return state;
}
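
Viewed from Python, the state produced by GetState() is roughly the following two-level dictionary (a sketch; the keys are taken from the implementation above, tensor values elided):

import torch

# Sketch of the state layout produced by GetState().
state = {
    "independent_tensors": {
        "version_number": torch.tensor([6199]),  # kPickleVersion
        "indptr": ...,  # always present
        "indices": ...,  # always present
        "node_type_offset": ...,  # present only when set
        "type_per_edge": ...,  # present only when set
    },
    # Present only when the graph carries edge attributes.
    "edge_attributes": {"<attribute name>": ...},
}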
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
const torch::Tensor& nodes) const {
using namespace torch::indexing;
...
@@ -35,7 +35,21 @@ TORCH_LIBRARY(graphbolt, m) {
.def(
"sample_negative_edges_uniform",
&CSCSamplingGraph::SampleNegativeEdgesUniform)
.def("copy_to_shared_memory", &CSCSamplingGraph::CopyToSharedMemory);
.def("copy_to_shared_memory", &CSCSamplingGraph::CopyToSharedMemory)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<CSCSamplingGraph>& self)
-> torch::Dict<
std::string, torch::Dict<std::string, torch::Tensor>> {
return self->GetState();
},
// __setstate__
[](torch::Dict<std::string, torch::Dict<std::string, torch::Tensor>>
state) -> c10::intrusive_ptr<CSCSamplingGraph> {
auto g = c10::make_intrusive<CSCSamplingGraph>();
g->SetState(state);
return g;
});
m.def("from_csc", &CSCSamplingGraph::FromCSC);
m.def("load_csc_sampling_graph", &LoadCSCSamplingGraph);
m.def("save_csc_sampling_graph", &SaveCSCSamplingGraph);
...
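
Registering def_pickle on the bound class is also what lets torch.multiprocessing ship the graph to worker processes, since process arguments travel via pickle. A hedged sketch mirroring the test_multiprocessing test below (the worker function and toy graph are illustrative):

import torch
import torch.multiprocessing as mp
import dgl.graphbolt as gb

def worker(graph):
    # The graph arrives here through the pickle path registered above.
    assert graph.num_nodes == 2

if __name__ == "__main__":
    graph = gb.from_csc(torch.tensor([0, 1, 2]), torch.tensor([1, 0]))
    p = mp.Process(target=worker, args=(graph,))
    p.start()
    p.join()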
import os
import pickle
import tempfile
import unittest
@@ -9,9 +11,11 @@ import dgl.graphbolt as gb
import gb_test_utils as gbt
import pytest
import torch
import torch.multiprocessing as mp
from scipy import sparse as spsp
torch.manual_seed(3407)
mp.set_sharing_strategy("file_system")
@unittest.skipIf(
@@ -251,6 +255,116 @@ def test_load_save_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
assert graph.metadata.edge_type_to_id == graph2.metadata.edge_type_to_id
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
"num_nodes, num_edges", [(1, 1), (100, 1), (10, 50), (1000, 50000)]
)
def test_pickle_homo_graph(num_nodes, num_edges):
csc_indptr, indices = gbt.random_homo_graph(num_nodes, num_edges)
graph = gb.from_csc(csc_indptr, indices)
serialized = pickle.dumps(graph)
graph2 = pickle.loads(serialized)
assert graph.num_nodes == graph2.num_nodes
assert graph.num_edges == graph2.num_edges
assert torch.equal(graph.csc_indptr, graph2.csc_indptr)
assert torch.equal(graph.indices, graph2.indices)
assert graph.metadata is None and graph2.metadata is None
assert graph.node_type_offset is None and graph2.node_type_offset is None
assert graph.type_per_edge is None and graph2.type_per_edge is None
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
"num_nodes, num_edges", [(1, 1), (100, 1), (10, 50), (1000, 50000)]
)
@pytest.mark.parametrize("num_ntypes, num_etypes", [(1, 1), (3, 5), (100, 1)])
def test_pickle_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
(
csc_indptr,
indices,
node_type_offset,
type_per_edge,
metadata,
) = gbt.random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes)
edge_attributes = {
"a": torch.randn((num_edges,)),
"b": torch.randint(1, 10, (num_edges,)),
}
graph = gb.from_csc(
csc_indptr,
indices,
node_type_offset,
type_per_edge,
edge_attributes,
metadata,
)
serialized = pickle.dumps(graph)
graph2 = pickle.loads(serialized)
assert graph.num_nodes == graph2.num_nodes
assert graph.num_edges == graph2.num_edges
assert torch.equal(graph.csc_indptr, graph2.csc_indptr)
assert torch.equal(graph.indices, graph2.indices)
assert torch.equal(graph.node_type_offset, graph2.node_type_offset)
assert torch.equal(graph.type_per_edge, graph2.type_per_edge)
assert graph.metadata.node_type_to_id == graph2.metadata.node_type_to_id
assert graph.metadata.edge_type_to_id == graph2.metadata.edge_type_to_id
assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()
for i in graph.edge_attributes.keys():
assert torch.equal(graph.edge_attributes[i], graph2.edge_attributes[i])
def process_csc_sampling_graph_multiprocessing(graph):
return graph.num_nodes
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
def test_multiprocessing():
num_nodes = 5
num_edges = 10
num_ntypes = 2
num_etypes = 3
(
csc_indptr,
indices,
node_type_offset,
type_per_edge,
metadata,
) = gbt.random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes)
edge_attributes = {
"a": torch.randn((num_edges,)),
}
graph = gb.from_csc(
csc_indptr,
indices,
node_type_offset,
type_per_edge,
edge_attributes,
metadata,
)
p = mp.Process(
target=process_csc_sampling_graph_multiprocessing, args=(graph,)
)
p.start()
p.join()
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
@@ -786,6 +900,108 @@ def test_hetero_graph_on_shared_memory(
assert metadata.edge_type_to_id == graph2.metadata.edge_type_to_id
def process_csc_sampling_graph_on_shared_memory(graph, data_queue, flag_queue):
# Backup the attributes.
csc_indptr = graph.csc_indptr.clone()
indices = graph.indices.clone()
node_type_offset = graph.node_type_offset.clone()
type_per_edge = graph.type_per_edge.clone()
# Change the value to random integers. Send the new value to the main
# process.
v = torch.randint_like(graph.csc_indptr, 100)
graph.csc_indptr[:] = v
data_queue.put(v.clone())
v = torch.randint_like(graph.indices, 100)
graph.indices[:] = v
data_queue.put(v.clone())
v = torch.randint_like(graph.node_type_offset, 100)
graph.node_type_offset[:] = v
data_queue.put(v.clone())
v = torch.randint_like(graph.type_per_edge, 100)
graph.type_per_edge[:] = v
data_queue.put(v.clone())
# Wait for the main process to finish.
flag_queue.get()
graph.csc_indptr[:] = csc_indptr
graph.indices[:] = indices
graph.node_type_offset[:] = node_type_offset
graph.type_per_edge[:] = type_per_edge
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
def test_multiprocessing_with_shared_memory():
"""Test if two CSCSamplingGraphs are on the same shared memory after
spawning.
For now this code only works when the sharing strategy of
torch.multiprocessing is set to `file_system` at the beginning.
The cause is still yet to be found.
"""
num_nodes = 5
num_edges = 10
num_ntypes = 2
num_etypes = 3
(
csc_indptr,
indices,
node_type_offset,
type_per_edge,
metadata,
) = gbt.random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes)
csc_indptr.share_memory_()
indices.share_memory_()
node_type_offset.share_memory_()
type_per_edge.share_memory_()
graph = gb.from_csc(
csc_indptr,
indices,
node_type_offset,
type_per_edge,
None,
metadata,
)
ctx = mp.get_context("spawn") # Use spawn method.
data_queue = ctx.Queue() # Used for sending graph.
flag_queue = ctx.Queue() # Used for sending finish signal.
p = ctx.Process(
target=process_csc_sampling_graph_on_shared_memory,
args=(graph, data_queue, flag_queue),
)
p.start()
try:
# Get data from the other process. Then check if the tensors here have
# the same data.
csc_indptr2 = data_queue.get()
assert torch.equal(graph.csc_indptr, csc_indptr2)
indices2 = data_queue.get()
assert torch.equal(graph.indices, indices2)
node_type_offset2 = data_queue.get()
assert torch.equal(graph.node_type_offset, node_type_offset2)
type_per_edge2 = data_queue.get()
assert torch.equal(graph.type_per_edge, type_per_edge2)
finally:
# Send a finish signal to end sub-process.
flag_queue.put(None)
p.join()
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph on GPU is not supported yet.",
...
@@ -11,8 +11,6 @@ import torch
from torchdata.datapipes.iter import Mapper
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
# TODO (peizhou): Will enable Windows test once CSCSamplingGraph is pickleable.
def test_DataLoader():
N = 40
B = 4
...