"src/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "7ba9cbc63e337f207c0707d1018d0d4e1c83ca51"
Unverified Commit 2c325b2d authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] add node/edge_type_to_id into pickle (#6701)

parent cbb6f502
...@@ -232,6 +232,12 @@ void FusedCSCSamplingGraph::SetState( ...@@ -232,6 +232,12 @@ void FusedCSCSamplingGraph::SetState(
if (independent_tensors.find("type_per_edge") != independent_tensors.end()) { if (independent_tensors.find("type_per_edge") != independent_tensors.end()) {
type_per_edge_ = independent_tensors.at("type_per_edge"); type_per_edge_ = independent_tensors.at("type_per_edge");
} }
if (state.find("node_type_to_id") != state.end()) {
node_type_to_id_ = DetensorizeDict(state.at("node_type_to_id"));
}
if (state.find("edge_type_to_id") != state.end()) {
edge_type_to_id_ = DetensorizeDict(state.at("edge_type_to_id"));
}
if (state.find("edge_attributes") != state.end()) { if (state.find("edge_attributes") != state.end()) {
edge_attributes_ = state.at("edge_attributes"); edge_attributes_ = state.at("edge_attributes");
} }
...@@ -256,6 +262,12 @@ FusedCSCSamplingGraph::GetState() const { ...@@ -256,6 +262,12 @@ FusedCSCSamplingGraph::GetState() const {
independent_tensors.insert("type_per_edge", type_per_edge_.value()); independent_tensors.insert("type_per_edge", type_per_edge_.value());
} }
state.insert("independent_tensors", independent_tensors); state.insert("independent_tensors", independent_tensors);
if (node_type_to_id_.has_value()) {
state.insert("node_type_to_id", TensorizeDict(node_type_to_id_).value());
}
if (edge_type_to_id_.has_value()) {
state.insert("edge_type_to_id", TensorizeDict(edge_type_to_id_).value());
}
if (edge_attributes_.has_value()) { if (edge_attributes_.has_value()) {
state.insert("edge_attributes", edge_attributes_.value()); state.insert("edge_attributes", edge_attributes_.value());
} }
......
...@@ -35,11 +35,15 @@ TORCH_LIBRARY(graphbolt, m) { ...@@ -35,11 +35,15 @@ TORCH_LIBRARY(graphbolt, m) {
.def("indices", &FusedCSCSamplingGraph::Indices) .def("indices", &FusedCSCSamplingGraph::Indices)
.def("node_type_offset", &FusedCSCSamplingGraph::NodeTypeOffset) .def("node_type_offset", &FusedCSCSamplingGraph::NodeTypeOffset)
.def("type_per_edge", &FusedCSCSamplingGraph::TypePerEdge) .def("type_per_edge", &FusedCSCSamplingGraph::TypePerEdge)
.def("node_type_to_id", &FusedCSCSamplingGraph::NodeTypeToID)
.def("edge_type_to_id", &FusedCSCSamplingGraph::EdgeTypeToID)
.def("edge_attributes", &FusedCSCSamplingGraph::EdgeAttributes) .def("edge_attributes", &FusedCSCSamplingGraph::EdgeAttributes)
.def("set_csc_indptr", &FusedCSCSamplingGraph::SetCSCIndptr) .def("set_csc_indptr", &FusedCSCSamplingGraph::SetCSCIndptr)
.def("set_indices", &FusedCSCSamplingGraph::SetIndices) .def("set_indices", &FusedCSCSamplingGraph::SetIndices)
.def("set_node_type_offset", &FusedCSCSamplingGraph::SetNodeTypeOffset) .def("set_node_type_offset", &FusedCSCSamplingGraph::SetNodeTypeOffset)
.def("set_type_per_edge", &FusedCSCSamplingGraph::SetTypePerEdge) .def("set_type_per_edge", &FusedCSCSamplingGraph::SetTypePerEdge)
.def("set_node_type_to_id", &FusedCSCSamplingGraph::SetNodeTypeToID)
.def("set_edge_type_to_id", &FusedCSCSamplingGraph::SetEdgeTypeToID)
.def("set_edge_attributes", &FusedCSCSamplingGraph::SetEdgeAttributes) .def("set_edge_attributes", &FusedCSCSamplingGraph::SetEdgeAttributes)
.def("in_subgraph", &FusedCSCSamplingGraph::InSubgraph) .def("in_subgraph", &FusedCSCSamplingGraph::InSubgraph)
.def("sample_neighbors", &FusedCSCSamplingGraph::SampleNeighbors) .def("sample_neighbors", &FusedCSCSamplingGraph::SampleNeighbors)
......
...@@ -255,6 +255,44 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -255,6 +255,44 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""Sets the edge type tensor if present.""" """Sets the edge type tensor if present."""
self._c_csc_graph.set_type_per_edge(type_per_edge) self._c_csc_graph.set_type_per_edge(type_per_edge)
@property
def node_type_to_id(self) -> Optional[Dict[str, int]]:
"""Returns the node type to id dictionary if present.
Returns
-------
Dict[str, int] or None
If present, returns a dictionary mapping node type to node type
id.
"""
return self._c_csc_graph.node_type_to_id()
@node_type_to_id.setter
def node_type_to_id(
self, node_type_to_id: Optional[Dict[str, int]]
) -> None:
"""Sets the node type to id dictionary if present."""
self._c_csc_graph.set_node_type_to_id(node_type_to_id)
@property
def edge_type_to_id(self) -> Optional[Dict[str, int]]:
"""Returns the edge type to id dictionary if present.
Returns
-------
Dict[str, int] or None
If present, returns a dictionary mapping edge type to edge type
id.
"""
return self._c_csc_graph.edge_type_to_id()
@edge_type_to_id.setter
def edge_type_to_id(
self, edge_type_to_id: Optional[Dict[str, int]]
) -> None:
"""Sets the edge type to id dictionary if present."""
self._c_csc_graph.set_edge_type_to_id(edge_type_to_id)
@property @property
def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]: def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
"""Returns the edge attributes dictionary. """Returns the edge attributes dictionary.
......
...@@ -376,9 +376,11 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges): ...@@ -376,9 +376,11 @@ def test_pickle_homo_graph(total_num_nodes, total_num_edges):
assert torch.equal(graph.csc_indptr, graph2.csc_indptr) assert torch.equal(graph.csc_indptr, graph2.csc_indptr)
assert torch.equal(graph.indices, graph2.indices) assert torch.equal(graph.indices, graph2.indices)
assert graph.metadata is None and graph2.metadata is None
assert graph.node_type_offset is None and graph2.node_type_offset is None assert graph.node_type_offset is None and graph2.node_type_offset is None
assert graph.type_per_edge is None and graph2.type_per_edge is None assert graph.type_per_edge is None and graph2.type_per_edge is None
assert graph.node_type_to_id is None and graph2.node_type_to_id is None
assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None
assert graph.edge_attributes is None and graph2.edge_attributes is None
@unittest.skipIf( @unittest.skipIf(
...@@ -425,8 +427,12 @@ def test_pickle_hetero_graph( ...@@ -425,8 +427,12 @@ def test_pickle_hetero_graph(
assert torch.equal(graph.indices, graph2.indices) assert torch.equal(graph.indices, graph2.indices)
assert torch.equal(graph.node_type_offset, graph2.node_type_offset) assert torch.equal(graph.node_type_offset, graph2.node_type_offset)
assert torch.equal(graph.type_per_edge, graph2.type_per_edge) assert torch.equal(graph.type_per_edge, graph2.type_per_edge)
assert graph.metadata.node_type_to_id == graph2.metadata.node_type_to_id assert graph.node_type_to_id.keys() == graph2.node_type_to_id.keys()
assert graph.metadata.edge_type_to_id == graph2.metadata.edge_type_to_id for i in graph.node_type_to_id.keys():
assert graph.node_type_to_id[i] == graph2.node_type_to_id[i]
assert graph.edge_type_to_id.keys() == graph2.edge_type_to_id.keys()
for i in graph.edge_type_to_id.keys():
assert graph.edge_type_to_id[i] == graph2.edge_type_to_id[i]
assert graph.edge_attributes.keys() == graph2.edge_attributes.keys() assert graph.edge_attributes.keys() == graph2.edge_attributes.keys()
for i in graph.edge_attributes.keys(): for i in graph.edge_attributes.keys():
assert torch.equal(graph.edge_attributes[i], graph2.edge_attributes[i]) assert torch.equal(graph.edge_attributes[i], graph2.edge_attributes[i])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment