Unverified Commit 2c325b2d authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] add node/edge_type_to_id into pickle (#6701)

parent cbb6f502
......@@ -232,6 +232,12 @@ void FusedCSCSamplingGraph::SetState(
if (independent_tensors.find("type_per_edge") != independent_tensors.end()) {
type_per_edge_ = independent_tensors.at("type_per_edge");
}
if (state.find("node_type_to_id") != state.end()) {
node_type_to_id_ = DetensorizeDict(state.at("node_type_to_id"));
}
if (state.find("edge_type_to_id") != state.end()) {
edge_type_to_id_ = DetensorizeDict(state.at("edge_type_to_id"));
}
if (state.find("edge_attributes") != state.end()) {
edge_attributes_ = state.at("edge_attributes");
}
......@@ -256,6 +262,12 @@ FusedCSCSamplingGraph::GetState() const {
independent_tensors.insert("type_per_edge", type_per_edge_.value());
}
state.insert("independent_tensors", independent_tensors);
if (node_type_to_id_.has_value()) {
state.insert("node_type_to_id", TensorizeDict(node_type_to_id_).value());
}
if (edge_type_to_id_.has_value()) {
state.insert("edge_type_to_id", TensorizeDict(edge_type_to_id_).value());
}
if (edge_attributes_.has_value()) {
state.insert("edge_attributes", edge_attributes_.value());
}
......
......@@ -35,11 +35,15 @@ TORCH_LIBRARY(graphbolt, m) {
.def("indices", &FusedCSCSamplingGraph::Indices)
.def("node_type_offset", &FusedCSCSamplingGraph::NodeTypeOffset)
.def("type_per_edge", &FusedCSCSamplingGraph::TypePerEdge)
.def("node_type_to_id", &FusedCSCSamplingGraph::NodeTypeToID)
.def("edge_type_to_id", &FusedCSCSamplingGraph::EdgeTypeToID)
.def("edge_attributes", &FusedCSCSamplingGraph::EdgeAttributes)
.def("set_csc_indptr", &FusedCSCSamplingGraph::SetCSCIndptr)
.def("set_indices", &FusedCSCSamplingGraph::SetIndices)
.def("set_node_type_offset", &FusedCSCSamplingGraph::SetNodeTypeOffset)
.def("set_type_per_edge", &FusedCSCSamplingGraph::SetTypePerEdge)
.def("set_node_type_to_id", &FusedCSCSamplingGraph::SetNodeTypeToID)
.def("set_edge_type_to_id", &FusedCSCSamplingGraph::SetEdgeTypeToID)
.def("set_edge_attributes", &FusedCSCSamplingGraph::SetEdgeAttributes)
.def("in_subgraph", &FusedCSCSamplingGraph::InSubgraph)
.def("sample_neighbors", &FusedCSCSamplingGraph::SampleNeighbors)
......
......@@ -255,6 +255,44 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""Sets the edge type tensor if present."""
self._c_csc_graph.set_type_per_edge(type_per_edge)
@property
def node_type_to_id(self) -> Optional[Dict[str, int]]:
"""Returns the node type to id dictionary if present.
Returns
-------
Dict[str, int] or None
If present, returns a dictionary mapping node type to node type
id.
"""
return self._c_csc_graph.node_type_to_id()
@node_type_to_id.setter
def node_type_to_id(
self, node_type_to_id: Optional[Dict[str, int]]
) -> None:
"""Sets the node type to id dictionary if present."""
self._c_csc_graph.set_node_type_to_id(node_type_to_id)
@property
def edge_type_to_id(self) -> Optional[Dict[str, int]]:
"""Returns the edge type to id dictionary if present.
Returns
-------
Dict[str, int] or None
If present, returns a dictionary mapping edge type to edge type
id.
"""
return self._c_csc_graph.edge_type_to_id()
@edge_type_to_id.setter
def edge_type_to_id(
self, edge_type_to_id: Optional[Dict[str, int]]
) -> None:
"""Sets the edge type to id dictionary if present."""
self._c_csc_graph.set_edge_type_to_id(edge_type_to_id)
@property
def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
"""Returns the edge attributes dictionary.
......
......@@ -376,9 +376,11 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges):
assert torch.equal(graph.csc_indptr, graph2.csc_indptr)
assert torch.equal(graph.indices, graph2.indices)
assert graph.metadata is None and graph2.metadata is None
assert graph.node_type_offset is None and graph2.node_type_offset is None
assert graph.type_per_edge is None and graph2.type_per_edge is None
assert graph.node_type_to_id is None and graph2.node_type_to_id is None
assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None
assert graph.edge_attributes is None and graph2.edge_attributes is None
@unittest.skipIf(
......@@ -425,8 +427,12 @@ def test_pickle_hetero_graph(
assert torch.equal(graph.indices, graph2.indices)
assert torch.equal(graph.node_type_offset, graph2.node_type_offset)
assert torch.equal(graph.type_per_edge, graph2.type_per_edge)
assert graph.metadata.node_type_to_id == graph2.metadata.node_type_to_id
assert graph.metadata.edge_type_to_id == graph2.metadata.edge_type_to_id
assert graph.node_type_to_id.keys() == graph2.node_type_to_id.keys()
for i in graph.node_type_to_id.keys():
assert graph.node_type_to_id[i] == graph2.node_type_to_id[i]
assert graph.edge_type_to_id.keys() == graph2.edge_type_to_id.keys()
for i in graph.edge_type_to_id.keys():
assert graph.edge_type_to_id[i] == graph2.edge_type_to_id[i]
assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()
for i in graph.edge_attributes.keys():
assert torch.equal(graph.edge_attributes[i], graph2.edge_attributes[i])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment