Unverified Commit 7439b7e7 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Add to function for CSCSamplingGraph (#6465)

parent ea58090e
...@@ -111,6 +111,30 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -111,6 +111,30 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
return edge_attributes_; return edge_attributes_;
} }
/** @brief Set the csc index pointer tensor. */
inline void SetCSCIndptr(const torch::Tensor& indptr) { indptr_ = indptr; }
/** @brief Set the index tensor. */
inline void SetIndices(const torch::Tensor& indices) { indices_ = indices; }
/** @brief Set the node type offset tensor for a heterogeneous graph. */
inline void SetNodeTypeOffset(
const torch::optional<torch::Tensor>& node_type_offset) {
node_type_offset_ = node_type_offset;
}
/** @brief Set the edge type tensor for a heterogeneous graph. */
inline void SetTypePerEdge(
const torch::optional<torch::Tensor>& type_per_edge) {
type_per_edge_ = type_per_edge;
}
/** @brief Set the edge attributes dictionary. */
inline void SetEdgeAttributes(
const torch::optional<EdgeAttrMap>& edge_attributes) {
edge_attributes_ = edge_attributes;
}
/** /**
* @brief Magic number to indicate graph version in serialize/deserialize * @brief Magic number to indicate graph version in serialize/deserialize
* stage. * stage.
......
...@@ -32,6 +32,11 @@ TORCH_LIBRARY(graphbolt, m) { ...@@ -32,6 +32,11 @@ TORCH_LIBRARY(graphbolt, m) {
.def("node_type_offset", &CSCSamplingGraph::NodeTypeOffset) .def("node_type_offset", &CSCSamplingGraph::NodeTypeOffset)
.def("type_per_edge", &CSCSamplingGraph::TypePerEdge) .def("type_per_edge", &CSCSamplingGraph::TypePerEdge)
.def("edge_attributes", &CSCSamplingGraph::EdgeAttributes) .def("edge_attributes", &CSCSamplingGraph::EdgeAttributes)
.def("set_csc_indptr", &CSCSamplingGraph::SetCSCIndptr)
.def("set_indices", &CSCSamplingGraph::SetIndices)
.def("set_node_type_offset", &CSCSamplingGraph::SetNodeTypeOffset)
.def("set_type_per_edge", &CSCSamplingGraph::SetTypePerEdge)
.def("set_edge_attributes", &CSCSamplingGraph::SetEdgeAttributes)
.def("in_subgraph", &CSCSamplingGraph::InSubgraph) .def("in_subgraph", &CSCSamplingGraph::InSubgraph)
.def("sample_neighbors", &CSCSamplingGraph::SampleNeighbors) .def("sample_neighbors", &CSCSamplingGraph::SampleNeighbors)
.def( .def(
......
...@@ -8,6 +8,8 @@ from typing import Dict, Optional, Union ...@@ -8,6 +8,8 @@ from typing import Dict, Optional, Union
import torch import torch
from dgl.utils import recursive_apply
from ...base import EID, ETYPE from ...base import EID, ETYPE
from ...convert import to_homogeneous from ...convert import to_homogeneous
from ...heterograph import DGLGraph from ...heterograph import DGLGraph
...@@ -181,6 +183,11 @@ class CSCSamplingGraph(SamplingGraph): ...@@ -181,6 +183,11 @@ class CSCSamplingGraph(SamplingGraph):
""" """
return self._c_csc_graph.csc_indptr() return self._c_csc_graph.csc_indptr()
@csc_indptr.setter
def csc_indptr(self, csc_indptr: torch.tensor) -> None:
"""Sets the indices pointer in the CSC graph."""
self._c_csc_graph.set_csc_indptr(csc_indptr)
@property @property
def indices(self) -> torch.tensor: def indices(self) -> torch.tensor:
"""Returns the indices in the CSC graph. """Returns the indices in the CSC graph.
...@@ -198,6 +205,11 @@ class CSCSamplingGraph(SamplingGraph): ...@@ -198,6 +205,11 @@ class CSCSamplingGraph(SamplingGraph):
""" """
return self._c_csc_graph.indices() return self._c_csc_graph.indices()
@indices.setter
def indices(self, indices: torch.tensor) -> None:
"""Sets the indices in the CSC graph."""
self._c_csc_graph.set_indices(indices)
@property @property
def node_type_offset(self) -> Optional[torch.Tensor]: def node_type_offset(self) -> Optional[torch.Tensor]:
"""Returns the node type offset tensor if present. """Returns the node type offset tensor if present.
...@@ -215,6 +227,13 @@ class CSCSamplingGraph(SamplingGraph): ...@@ -215,6 +227,13 @@ class CSCSamplingGraph(SamplingGraph):
""" """
return self._c_csc_graph.node_type_offset() return self._c_csc_graph.node_type_offset()
@node_type_offset.setter
def node_type_offset(
self, node_type_offset: Optional[torch.Tensor]
) -> None:
"""Sets the node type offset tensor if present."""
self._c_csc_graph.set_node_type_offset(node_type_offset)
@property @property
def type_per_edge(self) -> Optional[torch.Tensor]: def type_per_edge(self) -> Optional[torch.Tensor]:
"""Returns the edge type tensor if present. """Returns the edge type tensor if present.
...@@ -227,6 +246,11 @@ class CSCSamplingGraph(SamplingGraph): ...@@ -227,6 +246,11 @@ class CSCSamplingGraph(SamplingGraph):
""" """
return self._c_csc_graph.type_per_edge() return self._c_csc_graph.type_per_edge()
@type_per_edge.setter
def type_per_edge(self, type_per_edge: Optional[torch.Tensor]) -> None:
"""Sets the edge type tensor if present."""
self._c_csc_graph.set_type_per_edge(type_per_edge)
@property @property
def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]: def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
"""Returns the edge attributes dictionary. """Returns the edge attributes dictionary.
...@@ -241,6 +265,13 @@ class CSCSamplingGraph(SamplingGraph): ...@@ -241,6 +265,13 @@ class CSCSamplingGraph(SamplingGraph):
""" """
return self._c_csc_graph.edge_attributes() return self._c_csc_graph.edge_attributes()
@edge_attributes.setter
def edge_attributes(
self, edge_attributes: Optional[Dict[str, torch.Tensor]]
) -> None:
"""Sets the edge attributes dictionary."""
self._c_csc_graph.set_edge_attributes(edge_attributes)
@property @property
def metadata(self) -> Optional[GraphMetadata]: def metadata(self) -> Optional[GraphMetadata]:
"""Returns the metadata of the graph. """Returns the metadata of the graph.
...@@ -674,6 +705,28 @@ class CSCSamplingGraph(SamplingGraph): ...@@ -674,6 +705,28 @@ class CSCSamplingGraph(SamplingGraph):
self._metadata, self._metadata,
) )
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
"""Copy `CSCSamplingGraph` to the specified device."""
def _to(x, device):
return x.to(device) if hasattr(x, "to") else x
self.csc_indptr = recursive_apply(
self.csc_indptr, lambda x: _to(x, device)
)
self.indices = recursive_apply(self.indices, lambda x: _to(x, device))
self.node_type_offset = recursive_apply(
self.node_type_offset, lambda x: _to(x, device)
)
self.type_per_edge = recursive_apply(
self.type_per_edge, lambda x: _to(x, device)
)
self.edge_attributes = recursive_apply(
self.edge_attributes, lambda x: _to(x, device)
)
return self
def from_csc( def from_csc(
csc_indptr: torch.Tensor, csc_indptr: torch.Tensor,
......
...@@ -1604,3 +1604,56 @@ def test_sample_neighbors_hetero_pick_number( ...@@ -1604,3 +1604,56 @@ def test_sample_neighbors_hetero_pick_number(
else: else:
# Etype 2: 0 valid neighbors. # Etype 2: 0 valid neighbors.
assert pairs[0].size(0) == 0 assert pairs[0].size(0) == 0
@unittest.skipIf(
F._default_context_str == "cpu",
reason="`to` function needs GPU to test.",
)
def test_csc_sampling_graph_to_device():
# Initialize data.
total_num_nodes = 10
total_num_edges = 9
ntypes = {"N0": 0, "N1": 1, "N2": 2, "N3": 3}
etypes = {
"N0:R0:N1": 0,
"N0:R1:N2": 1,
"N0:R2:N3": 2,
}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
indices = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
node_type_offset = torch.LongTensor([0, 1, 4, 7, 10])
type_per_edge = torch.LongTensor([0, 0, 0, 1, 1, 1, 2, 2, 2])
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
assert node_type_offset[-1] == total_num_nodes
assert all(type_per_edge < len(etypes))
edge_attributes = {
"mask": torch.BoolTensor([1, 1, 0, 1, 1, 1, 0, 0, 0]),
"all": torch.BoolTensor([1, 1, 1, 1, 1, 1, 1, 1, 1]),
"zero": torch.BoolTensor([0, 0, 0, 0, 0, 0, 0, 0, 0]),
}
# Construct CSCSamplingGraph.
graph = gb.from_csc(
indptr,
indices,
edge_attributes=edge_attributes,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
)
# Copy to device.
graph = graph.to("cuda")
# Check.
assert graph.csc_indptr.device.type == "cuda"
assert graph.indices.device.type == "cuda"
assert graph.node_type_offset.device.type == "cuda"
assert graph.type_per_edge.device.type == "cuda"
assert graph.csc_indptr.device.type == "cuda"
for key in graph.edge_attributes:
assert graph.edge_attributes[key].device.type == "cuda"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment