Unverified Commit 308f8ca3 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] enable more dtypes for sample_neighbors (#6523)

parent ba2ca4be
...@@ -335,13 +335,12 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -335,13 +335,12 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
auto num_picked_neighbors_data_ptr = auto num_picked_neighbors_data_ptr =
num_picked_neighbors_per_node.data_ptr<scalar_t>(); num_picked_neighbors_per_node.data_ptr<scalar_t>();
num_picked_neighbors_data_ptr[0] = 0; num_picked_neighbors_data_ptr[0] = 0;
const auto nodes_data_ptr = nodes.data_ptr<int64_t>();
// Step 1. Calculate pick number of each node. // Step 1. Calculate pick number of each node.
torch::parallel_for( torch::parallel_for(
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) { 0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) { for (int64_t i = begin; i < end; ++i) {
const auto nid = nodes_data_ptr[i]; const auto nid = nodes[i].item<int64_t>();
TORCH_CHECK( TORCH_CHECK(
nid >= 0 && nid < NumNodes(), nid >= 0 && nid < NumNodes(),
"The seed nodes' IDs should fall within the range of the " "The seed nodes' IDs should fall within the range of the "
...@@ -356,7 +355,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -356,7 +355,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
// Step 2. Calculate prefix sum to get total length and offsets of each // Step 2. Calculate prefix sum to get total length and offsets of each
// node. It's also the indptr of the generated subgraph. // node. It's also the indptr of the generated subgraph.
subgraph_indptr = torch::cumsum(num_picked_neighbors_per_node, 0); subgraph_indptr =
num_picked_neighbors_per_node.cumsum(0, indptr_.scalar_type());
// Step 3. Allocate the tensor for picked neighbors. // Step 3. Allocate the tensor for picked neighbors.
const auto total_length = const auto total_length =
...@@ -374,7 +374,7 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -374,7 +374,7 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
torch::parallel_for( torch::parallel_for(
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) { 0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) { for (int64_t i = begin; i < end; ++i) {
const auto nid = nodes_data_ptr[i]; const auto nid = nodes[i].item<int64_t>();
const auto offset = indptr_data[nid]; const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset; const auto num_neighbors = indptr_data[nid + 1] - offset;
const auto picked_number = num_picked_neighbors_data_ptr[i + 1]; const auto picked_number = num_picked_neighbors_data_ptr[i + 1];
......
...@@ -604,7 +604,10 @@ def test_in_subgraph_heterogeneous(): ...@@ -604,7 +604,10 @@ def test_in_subgraph_heterogeneous():
F._default_context_str == "gpu", F._default_context_str == "gpu",
reason="Graph is CPU only at present.", reason="Graph is CPU only at present.",
) )
def test_sample_neighbors_homo(): @pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype):
"""Original graph in COO: """Original graph in COO:
1 0 1 0 1 1 0 1 0 1
1 0 1 1 0 1 0 1 1 0
...@@ -615,17 +618,21 @@ def test_sample_neighbors_homo(): ...@@ -615,17 +618,21 @@ def test_sample_neighbors_homo():
# Initialize data. # Initialize data.
total_num_nodes = 5 total_num_nodes = 5
total_num_edges = 12 total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12]) indptr = torch.tensor([0, 3, 5, 7, 9, 12], dtype=indptr_dtype)
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4]) indices = torch.tensor(
[0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4], dtype=indices_dtype
)
assert indptr[-1] == total_num_edges assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices) assert indptr[-1] == len(indices)
assert len(indptr) == total_num_nodes + 1
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices) graph = gb.from_fused_csc(indptr, indices)
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4]) nodes = torch.tensor([1, 3, 4], dtype=indices_dtype)
subgraph = graph.sample_neighbors(nodes, fanouts=torch.LongTensor([2])) sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts=torch.LongTensor([2]))
# Verify in subgraph. # Verify in subgraph.
sampled_num = subgraph.node_pairs[0].size(0) sampled_num = subgraph.node_pairs[0].size(0)
...@@ -640,7 +647,9 @@ def test_sample_neighbors_homo(): ...@@ -640,7 +647,9 @@ def test_sample_neighbors_homo():
reason="Graph is CPU only at present.", reason="Graph is CPU only at present.",
) )
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_hetero(labor): @pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
"""Original graph in COO: """Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2] "n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0] "n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
...@@ -656,8 +665,8 @@ def test_sample_neighbors_hetero(labor): ...@@ -656,8 +665,8 @@ def test_sample_neighbors_hetero(labor):
metadata = gb.GraphMetadata(ntypes, etypes) metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5 total_num_nodes = 5
total_num_edges = 9 total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9]) indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype)
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1]) indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype)
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0]) type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5]) node_type_offset = torch.LongTensor([0, 2, 5])
assert indptr[-1] == total_num_edges assert indptr[-1] == total_num_edges
...@@ -673,7 +682,10 @@ def test_sample_neighbors_hetero(labor): ...@@ -673,7 +682,10 @@ def test_sample_neighbors_hetero(labor):
) )
# Sample on both node types. # Sample on both node types.
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])} nodes = {
"n1": torch.tensor([0], dtype=indices_dtype),
"n2": torch.tensor([0], dtype=indices_dtype),
}
fanouts = torch.tensor([-1, -1]) fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts) subgraph = sampler(nodes, fanouts)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment