"docs/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "168a88e57070871eef5a9fcdad3ed1a4d708d7bd"
Unverified Commit c9c165f7 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[GraphBolt] Add fanout for neighbor sampling (#5768)

[Graphbolt] Add fanout for sampling
parent a99095e7
...@@ -120,12 +120,16 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -120,12 +120,16 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* subgraph. * subgraph.
* *
* @param nodes The nodes from which to sample neighbors. * @param nodes The nodes from which to sample neighbors.
* @param fanout The number of edges to be sampled for each node. It should be
* >= 0 or -1. If -1 is given, all neighbors will be selected. Otherwise, it
* will pick the minimum number of neighbors between the fanout value and the
* total number of neighbors.
* *
* @return An intrusive pointer to a SampledSubgraph object containing the * @return An intrusive pointer to a SampledSubgraph object containing the
* sampled graph's information. * sampled graph's information.
*/ */
c10::intrusive_ptr<SampledSubgraph> SampleNeighbors( c10::intrusive_ptr<SampledSubgraph> SampleNeighbors(
const torch::Tensor& nodes) const; const torch::Tensor& nodes, int64_t fanout) const;
/** /**
* @brief Copy the graph to shared memory. * @brief Copy the graph to shared memory.
...@@ -199,6 +203,24 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -199,6 +203,24 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
SharedMemoryPtr tensor_meta_shm_, tensor_data_shm_; SharedMemoryPtr tensor_meta_shm_, tensor_data_shm_;
}; };
/**
 * @brief Picks neighbors for a single node. The candidates are the contiguous
 * edge IDs in the range [offset, offset + num_neighbors), and at most `fanout`
 * of them are returned.
 *
 * @param offset The starting edge ID for the connected neighbors of the sampled
 * node.
 * @param num_neighbors The total number of neighbors of the node, i.e. the
 * number of candidates to pick from (not the number of neighbors returned).
 * @param fanout The number of edges to be sampled for each node. It should be
 * >= 0 or -1. If -1 is given, all neighbors will be selected. Otherwise, it
 * will pick the minimum number of neighbors between the fanout value and the
 * total number of neighbors.
 * @param options Tensor options specifying the desired data type of the result.
 * @return A tensor containing the edge IDs of the picked neighbors.
 */
torch::Tensor Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options);
} // namespace sampling } // namespace sampling
} // namespace graphbolt } // namespace graphbolt
......
...@@ -122,7 +122,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph( ...@@ -122,7 +122,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
} }
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors( c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
const torch::Tensor& nodes) const { const torch::Tensor& nodes, int64_t fanout) const {
const int64_t num_nodes = nodes.size(0); const int64_t num_nodes = nodes.size(0);
std::vector<torch::Tensor> picked_neighbors_per_node(num_nodes); std::vector<torch::Tensor> picked_neighbors_per_node(num_nodes);
...@@ -148,8 +148,9 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors( ...@@ -148,8 +148,9 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
} }
picked_neighbors_per_node[i] = picked_neighbors_per_node[i] =
torch::arange(offset, offset + num_neighbors); Pick(offset, num_neighbors, fanout, indptr_.options());
num_picked_neighbors_per_node[i + 1] = num_neighbors; num_picked_neighbors_per_node[i + 1] =
picked_neighbors_per_node[i].size(0);
} }
}); // End of the thread. }); // End of the thread.
...@@ -195,5 +196,18 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::LoadFromSharedMemory( ...@@ -195,5 +196,18 @@ c10::intrusive_ptr<CSCSamplingGraph> CSCSamplingGraph::LoadFromSharedMemory(
return BuildGraphFromSharedMemoryTensors(std::move(shared_memory_tensors)); return BuildGraphFromSharedMemoryTensors(std::move(shared_memory_tensors));
} }
/**
 * @brief Picks neighbors for a single node out of the contiguous edge ID range
 * [offset, offset + num_neighbors).
 *
 * @param offset The starting edge ID of the node's neighbors.
 * @param num_neighbors The total number of neighbors of the node (the number
 * of candidates, not the number of neighbors returned).
 * @param fanout The maximum number of neighbors to pick. Must be >= 0, or -1
 * to select every neighbor. When it is not smaller than `num_neighbors`, all
 * neighbors are returned in order; otherwise `fanout` distinct neighbors are
 * drawn uniformly at random without replacement.
 * @param options Tensor options (data type, device) applied to the result in
 * both the "take all" and the random branch.
 * @return A 1-D tensor holding the edge IDs of the picked neighbors.
 */
torch::Tensor Pick(
    int64_t offset, int64_t num_neighbors, int64_t fanout,
    const torch::TensorOptions& options) {
  // A fanout below -1 would otherwise reach `slice` as a negative end index,
  // which wraps around and silently returns `num_neighbors + fanout` entries.
  TORCH_CHECK(
      fanout >= -1, "Fanout should have value >= 0 or -1, but got ", fanout);
  if (fanout == -1 || num_neighbors <= fanout) {
    return torch::arange(offset, offset + num_neighbors, options);
  }
  // Build the permutation with `options` so that both branches agree on the
  // result's dtype/device, and shift only the `fanout` entries that are kept.
  return torch::randperm(num_neighbors, options).slice(0, 0, fanout) + offset;
}
} // namespace sampling } // namespace sampling
} // namespace graphbolt } // namespace graphbolt
...@@ -194,7 +194,9 @@ class CSCSamplingGraph: ...@@ -194,7 +194,9 @@ class CSCSamplingGraph:
), "Nodes cannot have duplicate values." ), "Nodes cannot have duplicate values."
return self._c_csc_graph.in_subgraph(nodes) return self._c_csc_graph.in_subgraph(nodes)
def sample_neighbors(self, nodes: torch.Tensor) -> torch.ScriptObject: def sample_neighbors(
self, nodes: torch.Tensor, fanout: int
) -> torch.ScriptObject:
"""Sample neighboring edges of the given nodes and return the induced """Sample neighboring edges of the given nodes and return the induced
subgraph. subgraph.
...@@ -202,10 +204,16 @@ class CSCSamplingGraph: ...@@ -202,10 +204,16 @@ class CSCSamplingGraph:
---------- ----------
nodes: torch.Tensor nodes: torch.Tensor
IDs of the given seed nodes. IDs of the given seed nodes.
fanout: int
The number of edges to be sampled for each node. It should be
>= 0 or -1. If -1 is given, all neighbors will be selected.
Otherwise, it will pick the minimum number of neighbors between
the fanout value and the total number of neighbors.
""" """
# Ensure nodes is 1-D tensor. # Ensure nodes is 1-D tensor.
assert nodes.dim() == 1, "Nodes should be 1-D tensor." assert nodes.dim() == 1, "Nodes should be 1-D tensor."
return self._c_csc_graph.sample_neighbors(nodes) assert fanout >= 0 or fanout == -1, "Fanout shoud have value >= 0 or -1"
return self._c_csc_graph.sample_neighbors(nodes, fanout)
def copy_to_shared_memory(self, shared_memory_name: str): def copy_to_shared_memory(self, shared_memory_name: str):
"""Copy the graph to shared memory. """Copy the graph to shared memory.
......
...@@ -410,7 +410,8 @@ def test_sample_neighbors(): ...@@ -410,7 +410,8 @@ def test_sample_neighbors():
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(nodes) fanout = -1
subgraph = graph.sample_neighbors(nodes, fanout)
# Verify in subgraph. # Verify in subgraph.
assert torch.equal(subgraph.indptr, torch.LongTensor([0, 2, 4, 7])) assert torch.equal(subgraph.indptr, torch.LongTensor([0, 2, 4, 7]))
...@@ -423,6 +424,42 @@ def test_sample_neighbors(): ...@@ -423,6 +424,42 @@ def test_sample_neighbors():
assert subgraph.type_per_edge is None assert subgraph.type_per_edge is None
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize(
    "fanout, expected_sampled_num",
    [(0, 0), (1, 3), (2, 6), (3, 7), (4, 7), (-1, 7)],
)
def test_sample_neighbors_fanout(fanout, expected_sampled_num):
    """Check the number of edges returned by `sample_neighbors` per fanout.

    Adjacency matrix of the original graph:
    1   0   1   0   1
    1   0   1   1   0
    0   1   0   1   0
    0   1   0   0   1
    1   0   0   0   1

    The CSC pointers below give the seed nodes 1, 3 and 4 a total of
    2 + 2 + 3 = 7 neighbor entries, which is the upper bound of the
    expected sample sizes in the parametrization.
    """
    # CSC representation of the graph above.
    num_nodes = 5
    num_edges = 12
    col_ptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    row_ids = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    # Sanity-check the fixture before building the graph from it.
    assert col_ptr[-1] == num_edges
    assert col_ptr[-1] == len(row_ids)

    csc_graph = gb.from_csc(col_ptr, row_ids)

    # Sample around the seed nodes with the parametrized fanout.
    seeds = torch.LongTensor([1, 3, 4])
    sampled = csc_graph.sample_neighbors(seeds, fanout)

    # Only the number of sampled edges is deterministic, so compare that.
    assert sampled.indices.size(0) == expected_sampled_num
def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor): def check_tensors_on_the_same_shared_memory(t1: torch.Tensor, t2: torch.Tensor):
"""Check if two tensors are on the same shared memory. """Check if two tensors are on the same shared memory.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment