"vscode:/vscode.git/clone" did not exist on "d934d3d79519653ad4db5b63a7665c1536b4187c"
Unverified commit 308f8ca3, authored by Rhett Ying and committed by GitHub
Browse files

[GraphBolt] enable more dtypes for sample_neighbors (#6523)

parent ba2ca4be
......@@ -335,13 +335,12 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
auto num_picked_neighbors_data_ptr =
num_picked_neighbors_per_node.data_ptr<scalar_t>();
num_picked_neighbors_data_ptr[0] = 0;
const auto nodes_data_ptr = nodes.data_ptr<int64_t>();
// Step 1. Calculate pick number of each node.
torch::parallel_for(
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
const auto nid = nodes_data_ptr[i];
const auto nid = nodes[i].item<int64_t>();
TORCH_CHECK(
nid >= 0 && nid < NumNodes(),
"The seed nodes' IDs should fall within the range of the "
......@@ -356,7 +355,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
// Step 2. Calculate prefix sum to get total length and offsets of each
// node. It's also the indptr of the generated subgraph.
subgraph_indptr = torch::cumsum(num_picked_neighbors_per_node, 0);
subgraph_indptr =
num_picked_neighbors_per_node.cumsum(0, indptr_.scalar_type());
// Step 3. Allocate the tensor for picked neighbors.
const auto total_length =
......@@ -374,7 +374,7 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
torch::parallel_for(
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
const auto nid = nodes_data_ptr[i];
const auto nid = nodes[i].item<int64_t>();
const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset;
const auto picked_number = num_picked_neighbors_data_ptr[i + 1];
......
......@@ -604,7 +604,10 @@ def test_in_subgraph_heterogeneous():
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
def test_sample_neighbors_homo():
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype):
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
......@@ -615,17 +618,21 @@ def test_sample_neighbors_homo():
# Initialize data.
total_num_nodes = 5
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
indptr = torch.tensor([0, 3, 5, 7, 9, 12], dtype=indptr_dtype)
indices = torch.tensor(
[0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4], dtype=indices_dtype
)
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)
assert len(indptr) == total_num_nodes + 1
# Construct FusedCSCSamplingGraph.
graph = gb.from_fused_csc(indptr, indices)
# Generate subgraph via sample neighbors.
nodes = torch.LongTensor([1, 3, 4])
subgraph = graph.sample_neighbors(nodes, fanouts=torch.LongTensor([2]))
nodes = torch.tensor([1, 3, 4], dtype=indices_dtype)
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts=torch.LongTensor([2]))
# Verify in subgraph.
sampled_num = subgraph.node_pairs[0].size(0)
......@@ -640,7 +647,9 @@ def test_sample_neighbors_homo():
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_hetero(labor):
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
......@@ -656,8 +665,8 @@ def test_sample_neighbors_hetero(labor):
metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5
total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype)
indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype)
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
assert indptr[-1] == total_num_edges
......@@ -673,7 +682,10 @@ def test_sample_neighbors_hetero(labor):
)
# Sample on both node types.
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
nodes = {
"n1": torch.tensor([0], dtype=indices_dtype),
"n2": torch.tensor([0], dtype=indices_dtype),
}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment