Unverified commit f86212ed, authored by Muhammed Fatih BALIN and committed by GitHub
Browse files

[GraphBolt][CUDA] Enable tests for weighted sampling (#6919)

parent 47a1d6a8
...@@ -43,7 +43,7 @@ template < ...@@ -43,7 +43,7 @@ template <
__global__ void _ComputeRandoms( __global__ void _ComputeRandoms(
const int64_t num_edges, const indptr_t* const sliced_indptr, const int64_t num_edges, const indptr_t* const sliced_indptr,
const indptr_t* const sub_indptr, const indices_t* const csr_rows, const indptr_t* const sub_indptr, const indices_t* const csr_rows,
const weights_t* const weights, const indices_t* const indices, const weights_t* const sliced_weights, const indices_t* const indices,
const uint64_t random_seed, float_t* random_arr, edge_id_t* edge_ids) { const uint64_t random_seed, float_t* random_arr, edge_id_t* edge_ids) {
int64_t i = blockIdx.x * blockDim.x + threadIdx.x; int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
const int stride = gridDim.x * blockDim.x; const int stride = gridDim.x * blockDim.x;
...@@ -65,7 +65,8 @@ __global__ void _ComputeRandoms( ...@@ -65,7 +65,8 @@ __global__ void _ComputeRandoms(
} }
const auto rnd = curand_uniform(&rng); const auto rnd = curand_uniform(&rng);
const auto prob = weights ? weights[in_idx] : static_cast<weights_t>(1); const auto prob =
sliced_weights ? sliced_weights[i] : static_cast<weights_t>(1);
const auto exp_rnd = -__logf(rnd); const auto exp_rnd = -__logf(rnd);
const float_t adjusted_rnd = prob > 0 const float_t adjusted_rnd = prob > 0
? static_cast<float_t>(exp_rnd / prob) ? static_cast<float_t>(exp_rnd / prob)
...@@ -77,6 +78,13 @@ __global__ void _ComputeRandoms( ...@@ -77,6 +78,13 @@ __global__ void _ComputeRandoms(
} }
} }
/**
 * @brief Stateless unary predicate that tests whether a value is strictly
 * greater than zero.
 *
 * In this file it is wrapped in a thrust transform iterator over the sliced
 * probabilities/mask so that a segmented sum counts, per row, the entries with
 * a positive value.
 */
struct IsPositive {
  /**
   * @brief Evaluates `x > 0`.
   *
   * @tparam probs_t Type of the probability or mask element; it must be
   * comparable against the integer literal 0.
   * @param x The value to test.
   * @return The result of `x > 0` (zero and negative values yield false).
   */
  template <typename probs_t>
  __host__ __device__ auto operator()(probs_t x) const {
    // The functor holds no state, so the call operator is const-qualified;
    // this lets it be invoked through const functor objects and iterators.
    return x > 0;
  }
};
template <typename indptr_t> template <typename indptr_t>
struct MinInDegreeFanout { struct MinInDegreeFanout {
const indptr_t* in_degree; const indptr_t* in_degree;
...@@ -152,7 +160,18 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -152,7 +160,18 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes); auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
auto in_degree = std::get<0>(in_degree_and_sliced_indptr); auto in_degree = std::get<0>(in_degree_and_sliced_indptr);
auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr); auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr);
auto sub_indptr = ExclusiveCumSum(in_degree); torch::Tensor sub_indptr;
// @todo mfbalin, refactor IndexSelectCSCImpl so that it does not have to take
// nodes as input
torch::optional<torch::Tensor> sliced_probs_or_mask;
if (probs_or_mask.has_value()) {
torch::Tensor sliced_probs_or_mask_tensor;
std::tie(sub_indptr, sliced_probs_or_mask_tensor) =
IndexSelectCSCImpl(indptr, probs_or_mask.value(), nodes);
sliced_probs_or_mask = sliced_probs_or_mask_tensor;
} else {
sub_indptr = ExclusiveCumSum(in_degree);
}
if (fanouts.size() > 1) { if (fanouts.size() > 1) {
torch::Tensor sliced_type_per_edge; torch::Tensor sliced_type_per_edge;
std::tie(sub_indptr, sliced_type_per_edge) = std::tie(sub_indptr, sliced_type_per_edge) =
...@@ -187,6 +206,29 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -187,6 +206,29 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
AT_DISPATCH_INDEX_TYPES( AT_DISPATCH_INDEX_TYPES(
indptr.scalar_type(), "SampleNeighborsIndptr", ([&] { indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
using indptr_t = index_t; using indptr_t = index_t;
if (probs_or_mask.has_value()) { // Count nonzero probs into in_degree.
GRAPHBOLT_DISPATCH_ALL_TYPES(
probs_or_mask.value().scalar_type(),
"SampleNeighborsPositiveProbs", ([&] {
using probs_t = scalar_t;
auto is_nonzero = thrust::make_transform_iterator(
sliced_probs_or_mask.value().data_ptr<probs_t>(),
IsPositive{});
size_t tmp_storage_size = 0;
cub::DeviceSegmentedReduce::Sum(
nullptr, tmp_storage_size, is_nonzero,
in_degree.data_ptr<indptr_t>(), num_rows,
sub_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>() + 1, stream);
auto tmp_storage =
allocator.AllocateStorage<char>(tmp_storage_size);
cub::DeviceSegmentedReduce::Sum(
tmp_storage.get(), tmp_storage_size, is_nonzero,
in_degree.data_ptr<indptr_t>(), num_rows,
sub_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>() + 1, stream);
}));
}
thrust::counting_iterator<int64_t> iota(0); thrust::counting_iterator<int64_t> iota(0);
auto sampled_degree = thrust::make_transform_iterator( auto sampled_degree = thrust::make_transform_iterator(
iota, MinInDegreeFanout<indptr_t>{ iota, MinInDegreeFanout<indptr_t>{
...@@ -246,10 +288,10 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -246,10 +288,10 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
probs_or_mask_scalar_type, "SampleNeighborsProbs", probs_or_mask_scalar_type, "SampleNeighborsProbs",
([&] { ([&] {
using probs_t = scalar_t; using probs_t = scalar_t;
probs_t* probs_ptr = nullptr; probs_t* sliced_probs_ptr = nullptr;
if (probs_or_mask.has_value()) { if (sliced_probs_or_mask.has_value()) {
probs_ptr = sliced_probs_ptr = sliced_probs_or_mask.value()
probs_or_mask.value().data_ptr<probs_t>(); .data_ptr<probs_t>();
} }
const indices_t* indices_ptr = const indices_t* indices_ptr =
layer ? indices.data_ptr<indices_t>() : nullptr; layer ? indices.data_ptr<indices_t>() : nullptr;
...@@ -261,7 +303,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -261,7 +303,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
_ComputeRandoms, grid, block, 0, stream, _ComputeRandoms, grid, block, 0, stream,
num_edges, sliced_indptr.data_ptr<indptr_t>(), num_edges, sliced_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>(), sub_indptr.data_ptr<indptr_t>(),
coo_rows.data_ptr<indices_t>(), probs_ptr, coo_rows.data_ptr<indices_t>(), sliced_probs_ptr,
indices_ptr, random_seed, randoms.get(), indices_ptr, random_seed, randoms.get(),
edge_id_segments.get()); edge_id_segments.get());
})); }));
......
...@@ -1797,10 +1797,6 @@ def test_sample_neighbors_fanouts( ...@@ -1797,10 +1797,6 @@ def test_sample_neighbors_fanouts(
assert subgraph.sampled_csc["n2:e2:n1"].indptr.size(0) == 2 assert subgraph.sampled_csc["n2:e2:n1"].indptr.size(0) == 2
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Sampling with replacement not yet supported on GPU.",
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"replace, expected_sampled_num1, expected_sampled_num2", "replace, expected_sampled_num1, expected_sampled_num2",
[(False, 2, 2), (True, 4, 4)], [(False, 2, 2), (True, 4, 4)],
...@@ -1808,6 +1804,8 @@ def test_sample_neighbors_fanouts( ...@@ -1808,6 +1804,8 @@ def test_sample_neighbors_fanouts(
def test_sample_neighbors_replace( def test_sample_neighbors_replace(
replace, expected_sampled_num1, expected_sampled_num2 replace, expected_sampled_num1, expected_sampled_num2
): ):
if F._default_context_str == "gpu" and replace == True:
pytest.skip("Sampling with replacement not yet supported on GPU.")
"""Original graph in COO: """Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2] "n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0] "n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
...@@ -1966,14 +1964,12 @@ def test_sample_neighbors_return_eids_hetero(labor): ...@@ -1966,14 +1964,12 @@ def test_sample_neighbors_return_eids_hetero(labor):
) )
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Sampling with replacement not yet supported on GPU.",
)
@pytest.mark.parametrize("replace", [True, False]) @pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("probs_name", ["weight", "mask"]) @pytest.mark.parametrize("probs_name", ["weight", "mask"])
def test_sample_neighbors_probs(replace, labor, probs_name): def test_sample_neighbors_probs(replace, labor, probs_name):
if F._default_context_str == "gpu" and replace == True:
pytest.skip("Sampling with replacement not yet supported on GPU.")
"""Original graph in COO: """Original graph in COO:
1 0 1 0 1 1 0 1 0 1
1 0 1 1 0 1 0 1 1 0
...@@ -2020,10 +2016,6 @@ def test_sample_neighbors_probs(replace, labor, probs_name): ...@@ -2020,10 +2016,6 @@ def test_sample_neighbors_probs(replace, labor, probs_name):
assert sampled_num == 4 assert sampled_num == 4
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Sampling with replacement not yet supported on GPU.",
)
@pytest.mark.parametrize("replace", [True, False]) @pytest.mark.parametrize("replace", [True, False])
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -2034,6 +2026,8 @@ def test_sample_neighbors_probs(replace, labor, probs_name): ...@@ -2034,6 +2026,8 @@ def test_sample_neighbors_probs(replace, labor, probs_name):
], ],
) )
def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask): def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
if F._default_context_str == "gpu" and replace == True:
pytest.skip("Sampling with replacement not yet supported on GPU.")
# Initialize data. # Initialize data.
total_num_nodes = 5 total_num_nodes = 5
total_num_edges = 12 total_num_edges = 12
...@@ -2065,10 +2059,6 @@ def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask): ...@@ -2065,10 +2059,6 @@ def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
assert sampled_num == 0 assert sampled_num == 0
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Sampling with replacement not yet supported on GPU.",
)
@pytest.mark.parametrize("replace", [False, True]) @pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -2089,6 +2079,8 @@ def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask): ...@@ -2089,6 +2079,8 @@ def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
], ],
) )
def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name): def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name):
if F._default_context_str == "gpu" and replace == True:
pytest.skip("Sampling with replacement not yet supported on GPU.")
"""Original graph in COO: """Original graph in COO:
1 1 1 1 1 1 1 1 1 1 1 1
0 0 0 0 0 0 0 0 0 0 0 0
...@@ -2150,10 +2142,6 @@ def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name): ...@@ -2150,10 +2142,6 @@ def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name):
assert sampled_num == min(fanouts[0], 6) assert sampled_num == min(fanouts[0], 6)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Sampling with replacement not yet supported on GPU.",
)
@pytest.mark.parametrize("replace", [False, True]) @pytest.mark.parametrize("replace", [False, True])
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -2171,6 +2159,8 @@ def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name): ...@@ -2171,6 +2159,8 @@ def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name):
def test_sample_neighbors_hetero_pick_number( def test_sample_neighbors_hetero_pick_number(
fanouts, replace, labor, probs_name fanouts, replace, labor, probs_name
): ):
if F._default_context_str == "gpu" and replace == True:
pytest.skip("Sampling with replacement not yet supported on GPU.")
# Initialize data. # Initialize data.
total_num_nodes = 10 total_num_nodes = 10
total_num_edges = 9 total_num_edges = 9
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment