Unverified commit 437139f5 authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] `gb.isin` implementation (#6829)

parent 7d5d576e
@@ -24,6 +24,23 @@ namespace ops {
 std::pair<torch::Tensor, torch::Tensor> Sort(
     torch::Tensor input, int num_bits = 0);
+
+/**
+ * @brief Tests whether each element of elements is in test_elements. Returns
+ * a boolean tensor of the same shape as elements that is True for elements
+ * in test_elements and False otherwise. Enhances torch.isin with
+ * multi-threaded searching; see the documentation at
+ * https://pytorch.org/docs/stable/generated/torch.isin.html.
+ *
+ * @param elements Input elements.
+ * @param test_elements Values against which to test each input element.
+ *
+ * @return
+ * A boolean tensor of the same shape as elements that is True for elements
+ * in test_elements and False otherwise.
+ */
+torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements);
+
 /**
  * @brief Select columns for a sparse matrix in a CSC format according to nodes
  * tensor.
...
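From Python, the operator is exposed as `gb.isin`. A minimal usage sketch, mirroring the unit tests further below and assuming a CUDA-enabled build (the CPU path accepts the same call with CPU tensors):

import torch
import dgl.graphbolt as gb

elements = torch.tensor([2, 3, 5, 5, 20, 13, 11], device="cuda")
test_elements = torch.tensor([2, 5], device="cuda")
# Elementwise membership test; output shape follows `elements`.
print(gb.isin(elements, test_elements))
# tensor([ True, False,  True,  True, False, False, False], device='cuda:0')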
/**
 * Copyright (c) 2023 by Contributors
 * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/isin.cu
 * @brief IsIn operator implementation on CUDA.
 */
#include <graphbolt/cuda_ops.h>
#include <thrust/binary_search.h>

#include <cub/cub.cuh>

#include "./common.h"

namespace graphbolt {
namespace ops {
torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements) {
  // Sort the test elements once so that each query can be answered with a
  // logarithmic binary search instead of a linear scan.
  auto sorted_test_elements = Sort(test_elements).first;
  auto allocator = cuda::GetAllocator();
  auto stream = cuda::GetCurrentStream();
  const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
  auto result = torch::empty_like(elements, torch::kBool);
  AT_DISPATCH_INTEGRAL_TYPES(
      elements.scalar_type(), "IsInOperation", ([&] {
        // Vectorized binary search: for each entry of `elements`, write a
        // bool indicating membership in `sorted_test_elements`. The search
        // is launched asynchronously on the current CUDA stream.
        thrust::binary_search(
            exec_policy, sorted_test_elements.data_ptr<scalar_t>(),
            sorted_test_elements.data_ptr<scalar_t>() +
                sorted_test_elements.size(0),
            elements.data_ptr<scalar_t>(),
            elements.data_ptr<scalar_t>() + elements.size(0),
            result.data_ptr<bool>());
      }));
  return result;
}

}  // namespace ops
}  // namespace graphbolt
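The same sort-then-binary-search strategy can be written in a few lines of PyTorch. The sketch below is only a reference for the algorithm above, using `torch.searchsorted` in place of `thrust::binary_search`; it is not GraphBolt API:

import torch

def isin_reference(elements: torch.Tensor, test_elements: torch.Tensor) -> torch.Tensor:
    # Sort once, then binary-search every query, as in cuda/isin.cu.
    if test_elements.numel() == 0:
        return torch.zeros_like(elements, dtype=torch.bool)
    sorted_test, _ = test_elements.sort()
    # Leftmost insertion position of each query element.
    pos = torch.searchsorted(sorted_test, elements)
    pos = pos.clamp(max=sorted_test.numel() - 1)
    # A query is a member iff the value at its insertion position matches it.
    return sorted_test[pos] == elements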
@@ -5,8 +5,12 @@
  * @brief Isin op.
  */
+#include <graphbolt/cuda_ops.h>
 #include <graphbolt/isin.h>
+
+#include "./macro.h"
+#include "./utils.h"
+
 namespace {
 static constexpr int kSearchGrainSize = 4096;
 }  // namespace
@@ -14,7 +18,7 @@ static constexpr int kSearchGrainSize = 4096;
 namespace graphbolt {
 namespace sampling {
 
-torch::Tensor IsIn(
+torch::Tensor IsInCPU(
     const torch::Tensor& elements, const torch::Tensor& test_elements) {
   torch::Tensor sorted_test_elements;
   std::tie(sorted_test_elements, std::ignore) = test_elements.sort(
@@ -41,5 +45,17 @@ torch::Tensor IsIn(
   }));
   return result;
 }
+
+torch::Tensor IsIn(
+    const torch::Tensor& elements, const torch::Tensor& test_elements) {
+  if (utils::is_accessible_from_gpu(elements) &&
+      utils::is_accessible_from_gpu(test_elements)) {
+    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
+        c10::DeviceType::CUDA, "IsInOperation",
+        { return ops::IsIn(elements, test_elements); });
+  } else {
+    return IsInCPU(elements, test_elements);
+  }
+}
 }  // namespace sampling
 }  // namespace graphbolt
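The renamed `IsInCPU` keeps the original multi-threaded path, while the new `IsIn` wrapper routes GPU-accessible inputs to the CUDA kernel. A hypothetical Python-level analogue of that routing, with `isin_gpu`/`isin_cpu` as stand-in names, assuming `utils::is_accessible_from_gpu` means CUDA-resident or pinned host memory:

import torch

def isin_dispatch(elements: torch.Tensor, test_elements: torch.Tensor):
    # Stand-in for utils::is_accessible_from_gpu (assumption: a tensor is
    # GPU-accessible if it lives on a CUDA device or in pinned host memory).
    def gpu_accessible(t: torch.Tensor) -> bool:
        return t.is_cuda or t.is_pinned()

    if gpu_accessible(elements) and gpu_accessible(test_elements):
        return isin_gpu(elements, test_elements)  # hypothetical CUDA path
    return isin_cpu(elements, test_elements)  # hypothetical CPU path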
@@ -124,28 +124,30 @@ def test_etype_str_to_tuple():
 def test_isin():
-    elements = torch.tensor([2, 3, 5, 5, 20, 13, 11])
-    test_elements = torch.tensor([2, 5])
+    elements = torch.tensor([2, 3, 5, 5, 20, 13, 11], device=F.ctx())
+    test_elements = torch.tensor([2, 5], device=F.ctx())
     res = gb.isin(elements, test_elements)
-    expected = torch.tensor([True, False, True, True, False, False, False])
+    expected = torch.tensor(
+        [True, False, True, True, False, False, False], device=F.ctx()
+    )
     assert torch.equal(res, expected)
 
 
 def test_isin_big_data():
-    elements = torch.randint(0, 10000, (10000000,))
-    test_elements = torch.randint(0, 10000, (500000,))
+    elements = torch.randint(0, 10000, (10000000,), device=F.ctx())
+    test_elements = torch.randint(0, 10000, (500000,), device=F.ctx())
     res = gb.isin(elements, test_elements)
     expected = torch.isin(elements, test_elements)
     assert torch.equal(res, expected)
 
 
 def test_isin_non_1D_dim():
-    elements = torch.tensor([[2, 3], [5, 5], [20, 13]])
-    test_elements = torch.tensor([2, 5])
+    elements = torch.tensor([[2, 3], [5, 5], [20, 13]], device=F.ctx())
+    test_elements = torch.tensor([2, 5], device=F.ctx())
     with pytest.raises(Exception):
         gb.isin(elements, test_elements)
-    elements = torch.tensor([2, 3, 5, 5, 20, 13])
-    test_elements = torch.tensor([[2, 5]])
+    elements = torch.tensor([2, 3, 5, 5, 20, 13], device=F.ctx())
+    test_elements = torch.tensor([[2, 5]], device=F.ctx())
     with pytest.raises(Exception):
         gb.isin(elements, test_elements)
...