Unverified Commit 95cf6924 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Return etypes in the neighbor sampling (#5773)

parent ae97049e
...@@ -169,11 +169,15 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors( ...@@ -169,11 +169,15 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighbors(
torch::Tensor picked_eids = torch::cat(picked_neighbors_per_node); torch::Tensor picked_eids = torch::cat(picked_neighbors_per_node);
torch::Tensor subgraph_indices = torch::Tensor subgraph_indices =
torch::index_select(indices_, 0, picked_eids); torch::index_select(indices_, 0, picked_eids);
torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
if (type_per_edge_.has_value())
subgraph_type_per_edge =
torch::index_select(type_per_edge_.value(), 0, picked_eids);
torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt; torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids); if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
return c10::make_intrusive<SampledSubgraph>( return c10::make_intrusive<SampledSubgraph>(
subgraph_indptr, subgraph_indices, nodes, torch::nullopt, subgraph_indptr, subgraph_indices, nodes, torch::nullopt,
subgraph_reverse_edge_ids, torch::nullopt); subgraph_reverse_edge_ids, subgraph_type_per_edge);
} }
c10::intrusive_ptr<CSCSamplingGraph> c10::intrusive_ptr<CSCSamplingGraph>
......
...@@ -232,6 +232,30 @@ class CSCSamplingGraph: ...@@ -232,6 +232,30 @@ class CSCSamplingGraph:
Boolean indicating whether the edge IDs of sampled edges, Boolean indicating whether the edge IDs of sampled edges,
represented as a 1D tensor, should be returned. This is represented as a 1D tensor, should be returned. This is
typically used when edge features are required typically used when edge features are required
Returns
-------
SampledSubgraph
The sampled subgraph.
Examples
--------
>>> indptr = torch.LongTensor([0, 3, 5, 7])
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1])
>>> type_per_edge = torch.LongTensor([0, 0, 1, 0, 1, 0, 1])
>>> graph = gb.from_csc(indptr, indices, type_per_edge=type_per_edge)
>>> nodes = torch.LongTensor([1, 2])
>>> fanouts = torch.tensor([1, 1])
>>> subgraph = graph.sample_neighbors(nodes, fanouts, return_eids=True)
>>> print(subgraph.indptr)
tensor([0, 2, 4])
>>> print(subgraph.indices)
tensor([2, 3, 0, 1])
>>> print(subgraph.reverse_column_node_ids)
tensor([1, 2])
>>> print(subgraph.reverse_edge_ids)
tensor([3, 4, 5, 6])
>>> print(subgraph.type_per_edge)
tensor([0, 1, 0, 1])
""" """
# Ensure nodes is 1-D tensor. # Ensure nodes is 1-D tensor.
assert nodes.dim() == 1, "Nodes should be 1-D tensor." assert nodes.dim() == 1, "Nodes should be 1-D tensor."
......
...@@ -411,7 +411,7 @@ def test_sample_neighbors(): ...@@ -411,7 +411,7 @@ def test_sample_neighbors():
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.LongTensor([1, 3, 4])
fanouts = torch.tensor([2, 2, 3]) fanouts = torch.tensor([2, 2])
subgraph = graph.sample_neighbors(nodes, fanouts, return_eids=True) subgraph = graph.sample_neighbors(nodes, fanouts, return_eids=True)
# Verify in subgraph. # Verify in subgraph.
...@@ -424,8 +424,10 @@ def test_sample_neighbors(): ...@@ -424,8 +424,10 @@ def test_sample_neighbors():
assert torch.equal( assert torch.equal(
subgraph.reverse_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11]) subgraph.reverse_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11])
) )
assert torch.equal(
subgraph.type_per_edge, torch.LongTensor([0, 1, 0, 1, 0, 0, 1])
)
assert subgraph.reverse_row_node_ids is None assert subgraph.reverse_row_node_ids is None
assert subgraph.type_per_edge is None
@unittest.skipIf( @unittest.skipIf(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment