Unverified Commit ae97049e authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[GraphBolt] Add return_eid for neighbor sampling (#5772)

parent 5f490d19
...@@ -135,13 +135,15 @@ class CSCSamplingGraph : public torch::CustomClassHolder { ...@@ -135,13 +135,15 @@ class CSCSamplingGraph : public torch::CustomClassHolder {
* @param replace Boolean indicating whether the sample is preformed with or * @param replace Boolean indicating whether the sample is preformed with or
* without replacement. If True, a value can be selected multiple times. * without replacement. If True, a value can be selected multiple times.
* Otherwise, each value can be selected only once. * Otherwise, each value can be selected only once.
* @param return_eids Boolean indicating whether edge IDs need to be returned,
* typically used when edge features are required.
* *
* @return An intrusive pointer to a SampledSubgraph object containing the * @return An intrusive pointer to a SampledSubgraph object containing the
* sampled graph's information. * sampled graph's information.
*/ */
c10::intrusive_ptr<SampledSubgraph> SampleNeighbors( c10::intrusive_ptr<SampledSubgraph> SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts, const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace) const; bool replace, bool return_eids) const;
/** /**
* @brief Copy the graph to shared memory. * @brief Copy the graph to shared memory.
......
...@@ -123,7 +123,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph( ...@@ -123,7 +123,7 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::InSubgraph(
c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors( c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts, const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace) const { bool replace, bool return_eids) const {
const int64_t num_nodes = nodes.size(0); const int64_t num_nodes = nodes.size(0);
// If true, perform sampling for each edge type of each node, otherwise just // If true, perform sampling for each edge type of each node, otherwise just
// sample once for each node with no regard of edge types. // sample once for each node with no regard of edge types.
...@@ -169,10 +169,11 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors( ...@@ -169,10 +169,11 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
torch::Tensor picked_eids = torch::cat(picked_neighbors_per_node); torch::Tensor picked_eids = torch::cat(picked_neighbors_per_node);
torch::Tensor subgraph_indices = torch::Tensor subgraph_indices =
torch::index_select(indices_, 0, picked_eids); torch::index_select(indices_, 0, picked_eids);
torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
return c10::make_intrusive<SampledSubgraph>( return c10::make_intrusive<SampledSubgraph>(
subgraph_indptr, subgraph_indices, nodes, torch::nullopt, torch::nullopt, subgraph_indptr, subgraph_indices, nodes, torch::nullopt,
torch::nullopt); subgraph_reverse_edge_ids, torch::nullopt);
} }
c10::intrusive_ptr<CSCSamplingGraph> c10::intrusive_ptr<CSCSamplingGraph>
......
...@@ -199,6 +199,7 @@ class CSCSamplingGraph: ...@@ -199,6 +199,7 @@ class CSCSamplingGraph:
nodes: torch.Tensor, nodes: torch.Tensor,
fanouts: torch.Tensor, fanouts: torch.Tensor,
replace: bool = False, replace: bool = False,
return_eids: bool = False,
) -> torch.ScriptObject: ) -> torch.ScriptObject:
"""Sample neighboring edges of the given nodes and return the induced """Sample neighboring edges of the given nodes and return the induced
subgraph. subgraph.
...@@ -227,6 +228,10 @@ class CSCSamplingGraph: ...@@ -227,6 +228,10 @@ class CSCSamplingGraph:
Boolean indicating whether the sample is preformed with or Boolean indicating whether the sample is preformed with or
without replacement. If True, a value can be selected multiple without replacement. If True, a value can be selected multiple
times. Otherwise, each value can be selected only once. times. Otherwise, each value can be selected only once.
return_eids: bool
Boolean indicating whether the edge IDs of sampled edges,
represented as a 1D tensor, should be returned. This is
typically used when edge features are required
""" """
# Ensure nodes is 1-D tensor. # Ensure nodes is 1-D tensor.
assert nodes.dim() == 1, "Nodes should be 1-D tensor." assert nodes.dim() == 1, "Nodes should be 1-D tensor."
...@@ -241,7 +246,7 @@ class CSCSamplingGraph: ...@@ -241,7 +246,7 @@ class CSCSamplingGraph:
), "Fanouts should consist of values that are either -1 or \ ), "Fanouts should consist of values that are either -1 or \
greater than or equal to 0." greater than or equal to 0."
return self._c_csc_graph.sample_neighbors( return self._c_csc_graph.sample_neighbors(
nodes, fanouts.tolist(), replace nodes, fanouts.tolist(), replace, return_eids
) )
def copy_to_shared_memory(self, shared_memory_name: str): def copy_to_shared_memory(self, shared_memory_name: str):
......
...@@ -412,7 +412,7 @@ def test_sample_neighbors(): ...@@ -412,7 +412,7 @@ def test_sample_neighbors():
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
fanouts = torch.tensor([2, 2, 3]) fanouts = torch.tensor([2, 2, 3])
subgraph = graph.sample_neighbors(nodes, fanouts) subgraph = graph.sample_neighbors(nodes, fanouts, return_eids=True)
# Verify in subgraph. # Verify in subgraph.
assert torch.equal(subgraph.indptr, torch.LongTensor([0, 2, 4, 7])) assert torch.equal(subgraph.indptr, torch.LongTensor([0, 2, 4, 7]))
...@@ -421,8 +421,10 @@ def test_sample_neighbors(): ...@@ -421,8 +421,10 @@ def test_sample_neighbors():
torch.sort(torch.LongTensor([2, 3, 1, 2, 0, 3, 4]))[0], torch.sort(torch.LongTensor([2, 3, 1, 2, 0, 3, 4]))[0],
) )
assert torch.equal(subgraph.reverse_column_node_ids, nodes) assert torch.equal(subgraph.reverse_column_node_ids, nodes)
assert torch.equal(
subgraph.reverse_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11])
)
assert subgraph.reverse_row_node_ids is None assert subgraph.reverse_row_node_ids is None
assert subgraph.reverse_edge_ids is None
assert subgraph.type_per_edge is None assert subgraph.type_per_edge is None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment