"src/array/vscode:/vscode.git/clone" did not exist on "00c27cb2645b70a7b28f2d7f779fe8d254f8d7ab"
Unverified Commit 95cf6924 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Return etypes in the neighbor sampling (#5773)

parent ae97049e
......@@ -169,11 +169,15 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
torch::Tensor picked_eids = torch::cat(picked_neighbors_per_node);
torch::Tensor subgraph_indices =
torch::index_select(indices_, 0, picked_eids);
torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
if (type_per_edge_.has_value())
subgraph_type_per_edge =
torch::index_select(type_per_edge_.value(), 0, picked_eids);
torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
return c10::make_intrusive<SampledSubgraph>(
subgraph_indptr, subgraph_indices, nodes, torch::nullopt,
subgraph_reverse_edge_ids, torch::nullopt);
subgraph_reverse_edge_ids, subgraph_type_per_edge);
}
c10::intrusive_ptr<CSCSamplingGraph>
......
......@@ -232,6 +232,30 @@ class CSCSamplingGraph:
Boolean indicating whether the edge IDs of sampled edges,
represented as a 1D tensor, should be returned. This is
typically used when edge features are required
Returns
-------
SampledSubgraph
The sampled subgraph.
Examples
--------
>>> indptr = torch.LongTensor([0, 3, 5, 7])
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1])
>>> type_per_edge = torch.LongTensor([0, 0, 1, 0, 1, 0, 1])
>>> graph = gb.from_csc(indptr, indices, type_per_edge=type_per_edge)
>>> nodes = torch.LongTensor([1, 2])
>>> fanouts = torch.tensor([1, 1])
>>> subgraph = graph.sample_neighbors(nodes, fanouts, return_eids=True)
>>> print(subgraph.indptr)
tensor([0, 2, 4])
>>> print(subgraph.indices)
tensor([2, 3, 0, 1])
>>> print(subgraph.reverse_column_node_ids)
tensor([1, 2])
>>> print(subgraph.reverse_edge_ids)
tensor([3, 4, 5, 6])
>>> print(subgraph.type_per_edge)
tensor([0, 1, 0, 1])
"""
# Ensure nodes is 1-D tensor.
assert nodes.dim() == 1, "Nodes should be 1-D tensor."
......
......@@ -411,7 +411,7 @@ def test_sample_neighbors():
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
fanouts = torch.tensor([2, 2, 3])
fanouts = torch.tensor([2, 2])
subgraph = graph.sample_neighbors(nodes, fanouts, return_eids=True)
# Verify in subgraph.
......@@ -424,8 +424,10 @@ def test_sample_neighbors():
assert torch.equal(
subgraph.reverse_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11])
)
assert torch.equal(
subgraph.type_per_edge, torch.LongTensor([0, 1, 0, 1, 0, 0, 1])
)
assert subgraph.reverse_row_node_ids is None
assert subgraph.type_per_edge is None
@unittest.skipIf(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment