Unverified Commit ee00729b authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Define API for Neighbor sampling (#5766)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-25-242.ap-northeast-1.compute.internal>
parent f0b61222
...@@ -115,6 +115,18 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -115,6 +115,18 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
c10::intrusive_ptr<SampledSubgraph> InSubgraph( c10::intrusive_ptr<SampledSubgraph> InSubgraph(
const torch::Tensor& nodes) const; const torch::Tensor& nodes) const;
/**
* @brief Sample neighboring edges of the given nodes and return the induced
* subgraph.
*
* @param nodes The seed nodes from which to sample neighbors. The
* implementation reads `nodes.size(0)`, so the tensor needs at least one
* dimension; the Python wrapper asserts that it is exactly 1-D.
*
* @return An intrusive pointer to a SampledSubgraph object containing the
* sampled graph's information.
*
* @note The implementation is currently a placeholder (see TODO #5692): it
* builds an all-zero indptr with `nodes.size(0) + 1` entries, a single-zero
* indices tensor, passes `nodes` through unchanged, and leaves the remaining
* optional fields unset.
*/
c10::intrusive_ptr<SampledSubgraph> SampleNeighbors(
const torch::Tensor& nodes) const;
/** /**
* @brief Copy the graph to shared memory. * @brief Copy the graph to shared memory.
* @param shared_memory_name The name of the shared memory. * @param shared_memory_name The name of the shared memory.
......
...@@ -121,6 +121,15 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph( ...@@ -121,6 +121,15 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
: torch::nullopt); : torch::nullopt);
} }
/**
 * Placeholder for neighbor sampling: no edges are sampled yet.
 *
 * Builds a SampledSubgraph from an all-zero indptr with one entry per seed
 * node plus one, a single-zero indices tensor, the seed nodes themselves, and
 * three unset optional fields. Both generated tensors reuse the tensor options
 * of this graph's indptr.
 */
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
    const torch::Tensor& nodes) const {
  // TODO(#5692): implement this.
  const auto tensor_options = indptr_.options();
  // All offsets are zero, i.e. every seed node has an empty neighbor range.
  auto placeholder_indptr =
      torch::zeros({nodes.size(0) + 1}, tensor_options);
  // A one-element dummy; none of the zero-length ranges above refer to it.
  auto placeholder_indices = torch::zeros({1}, tensor_options);
  return c10::make_intrusive<SampledSubgraph>(
      placeholder_indptr, placeholder_indices, nodes, torch::nullopt,
      torch::nullopt, torch::nullopt);
}
c10::intrusive_ptr<CSCSamplingGraph> c10::intrusive_ptr<CSCSamplingGraph>
CSCSamplingGraph::BuildGraphFromSharedMemoryTensors( CSCSamplingGraph::BuildGraphFromSharedMemoryTensors(
std::tuple< std::tuple<
......
...@@ -29,6 +29,7 @@ TORCH_LIBRARY(graphbolt, m) { ...@@ -29,6 +29,7 @@ TORCH_LIBRARY(graphbolt, m) {
.def("node_type_offset", &CSCSamplingGraph::NodeTypeOffset) .def("node_type_offset", &CSCSamplingGraph::NodeTypeOffset)
.def("type_per_edge", &CSCSamplingGraph::TypePerEdge) .def("type_per_edge", &CSCSamplingGraph::TypePerEdge)
.def("in_subgraph", &CSCSamplingGraph::InSubgraph) .def("in_subgraph", &CSCSamplingGraph::InSubgraph)
.def("sample_neighbors", &CSCSamplingGraph::SampleNeighbors)
.def("copy_to_shared_memory", &CSCSamplingGraph::CopyToSharedMemory); .def("copy_to_shared_memory", &CSCSamplingGraph::CopyToSharedMemory);
m.def("from_csc", &CSCSamplingGraph::FromCSC); m.def("from_csc", &CSCSamplingGraph::FromCSC);
m.def("load_csc_sampling_graph", &LoadCSCSamplingGraph); m.def("load_csc_sampling_graph", &LoadCSCSamplingGraph);
......
...@@ -194,6 +194,19 @@ class CSCSamplingGraph: ...@@ -194,6 +194,19 @@ class CSCSamplingGraph:
), "Nodes cannot have duplicate values." ), "Nodes cannot have duplicate values."
return self._c_csc_graph.in_subgraph(nodes) return self._c_csc_graph.in_subgraph(nodes)
def sample_neighbors(self, nodes: torch.Tensor) -> torch.ScriptObject:
    """Sample neighboring edges of the given nodes and return the induced
    subgraph.

    Parameters
    ----------
    nodes : torch.Tensor
        IDs of the given seed nodes. Must be a 1-D tensor.

    Returns
    -------
    torch.ScriptObject
        The object returned by the underlying C++ graph's
        ``sample_neighbors`` call, passed through unchanged.

    Raises
    ------
    AssertionError
        If ``nodes`` is not 1-D.
    """
    # Only a flat vector of node IDs is accepted.
    is_flat = nodes.dim() == 1
    assert is_flat, "Nodes should be 1-D tensor."
    backend_graph = self._c_csc_graph
    return backend_graph.sample_neighbors(nodes)
def copy_to_shared_memory(self, shared_memory_name: str): def copy_to_shared_memory(self, shared_memory_name: str):
"""Copy the graph to shared memory. """Copy the graph to shared memory.
......
...@@ -385,6 +385,42 @@ def test_in_subgraph_heterogeneous(): ...@@ -385,6 +385,42 @@ def test_in_subgraph_heterogeneous():
) )
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
def test_sample_neighbors():
    """Check the result of ``sample_neighbors`` on a small homogeneous graph.

    Dense adjacency matrix of the original graph. The non-zero rows of column
    ``j`` are ``indices[indptr[j]:indptr[j + 1]]``:
    1   0   1   0   1
    1   0   1   1   0
    0   1   0   1   0
    0   1   0   0   1
    1   0   0   0   1

    The expected values below are those of the current placeholder
    implementation: an all-zero indptr, a single-zero indices tensor and the
    seed nodes passed through as ``reverse_column_node_ids``.
    """
    # Initialize data.
    num_nodes = 5
    num_edges = 12
    indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    # Sanity-check the fixture itself: one offset per node plus the leading 0.
    assert len(indptr) == num_nodes + 1
    assert indptr[-1] == num_edges
    assert indptr[-1] == len(indices)

    # Construct CSCSamplingGraph.
    graph = gb.from_csc(indptr, indices)

    # Generate subgraph via sample neighbors.
    nodes = torch.LongTensor([1, 3, 4])
    subgraph = graph.sample_neighbors(nodes)

    # Verify the sampled subgraph.
    assert torch.equal(subgraph.indptr, torch.LongTensor([0, 0, 0, 0]))
    assert torch.equal(subgraph.indices, torch.LongTensor([0]))
    assert torch.equal(subgraph.reverse_column_node_ids, nodes)
    assert subgraph.reverse_row_node_ids is None
    assert subgraph.reverse_edge_ids is None
    assert subgraph.type_per_edge is None
def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor): def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
"""Check if two tensors are on the same shared memory. """Check if two tensors are on the same shared memory.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment