Unverified Commit ae97049e authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[GraphBolt] Add return_eid for neighbor sampling (#5772)

parent 5f490d19
......@@ -135,13 +135,15 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
*
* @return An intrusive pointer to a SampledSubgraph object containing the
* sampled graph's information.
*/
c10::intrusive_ptr<SampledSubgraph> SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace) const;
bool replace, bool return_eids) const;
/**
* @brief Copy the graph to shared memory.
......
......@@ -123,7 +123,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace) const {
bool replace, bool return_eids) const {
const int64_t num_nodes = nodes.size(0);
// If true, perform sampling for each edge type of each node, otherwise just
// sample once for each node with no regard of edge types.
......@@ -169,10 +169,11 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
torch::Tensor picked_eids = torch::cat(picked_neighbors_per_node);
torch::Tensor subgraph_indices =
torch::index_select(indices_, 0, picked_eids);
torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
return c10::make_intrusive<SampledSubgraph>(
subgraph_indptr, subgraph_indices, nodes, torch::nullopt, torch::nullopt,
torch::nullopt);
subgraph_indptr, subgraph_indices, nodes, torch::nullopt,
subgraph_reverse_edge_ids, torch::nullopt);
}
c10::intrusive_ptr<CSCSamplingGraph>
......
......@@ -199,6 +199,7 @@ class CSCSamplingGraph:
nodes: torch.Tensor,
fanouts: torch.Tensor,
replace: bool = False,
return_eids: bool = False,
) -> torch.ScriptObject:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
......@@ -227,6 +228,10 @@ class CSCSamplingGraph:
Boolean indicating whether the sample is preformed with or
without replacement. If True, a value can be selected multiple
times. Otherwise, each value can be selected only once.
return_eids: bool
Boolean indicating whether the edge IDs of sampled edges,
represented as a 1D tensor, should be returned. This is
typically used when edge features are required
"""
# Ensure nodes is 1-D tensor.
assert nodes.dim() == 1, "Nodes should be 1-D tensor."
......@@ -241,7 +246,7 @@ class CSCSamplingGraph:
), "Fanouts should consist of values that are either -1 or \
greater than or equal to 0."
return self._c_csc_graph.sample_neighbors(
nodes, fanouts.tolist(), replace
nodes, fanouts.tolist(), replace, return_eids
)
def copy_to_shared_memory(self, shared_memory_name: str):
......
......@@ -412,7 +412,7 @@ def test_sample_neighbors():
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
fanouts = torch.tensor([2, 2, 3])
subgraph = graph.sample_neighbors(nodes, fanouts)
subgraph = graph.sample_neighbors(nodes, fanouts, return_eids=True)
# Verify in subgraph.
assert torch.equal(subgraph.indptr, torch.LongTensor([0, 2, 4, 7]))
......@@ -421,8 +421,10 @@ def test_sample_neighbors():
torch.sort(torch.LongTensor([2, 3, 1, 2, 0, 3, 4]))[0],
)
assert torch.equal(subgraph.reverse_column_node_ids, nodes)
assert torch.equal(
subgraph.reverse_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11])
)
assert subgraph.reverse_row_node_ids is None
assert subgraph.reverse_edge_ids is None
assert subgraph.type_per_edge is None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment