"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "97ed428d3c82010c652686bd8115959163283d33"
Unverified commit 437139f5 authored by Muhammed Fatih BALIN, committed by GitHub
Browse files

[GraphBolt][CUDA] `gb.isin` implementation (#6829)

parent 7d5d576e
......@@ -24,6 +24,23 @@ namespace ops {
std::pair<torch::Tensor, torch::Tensor> Sort(
torch::Tensor input, int num_bits = 0);
/**
 * @brief Tests if each element of elements is in test_elements. Returns a
 * boolean tensor of the same shape as elements that is True for elements
 * in test_elements and False otherwise. Enhances torch.isin by implementing
 * multi-threaded searching, as detailed in the documentation at
 * https://pytorch.org/docs/stable/generated/torch.isin.html
 *
 * @param elements Input elements.
 * @param test_elements Values against which to test for each input element.
 *
 * @return
 * A boolean tensor of the same shape as elements that is True for elements
 * in test_elements and False otherwise.
 *
 */
torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements);
/**
* @brief Select columns for a sparse matrix in a CSC format according to nodes
* tensor.
......
/**
* Copyright (c) 2023 by Contributors
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* @file cuda/isin.cu
* @brief IsIn operator implementation on CUDA.
*/
#include <graphbolt/cuda_ops.h>
#include <thrust/binary_search.h>
#include <cub/cub.cuh>
#include "./common.h"
namespace graphbolt {
namespace ops {
/**
 * @brief CUDA membership test: for each entry of `elements`, reports whether
 * it occurs anywhere in `test_elements`.
 *
 * The needle set is sorted once, after which each query reduces to a
 * device-side binary search run under the current CUDA stream.
 *
 * @param elements Integral tensor of query values.
 * @param test_elements Integral tensor of candidate values.
 * @return Boolean tensor shaped like `elements`; True where the value is
 * present in `test_elements`.
 */
torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements) {
  const auto needles = Sort(test_elements).first;
  auto allocator = cuda::GetAllocator();
  auto stream = cuda::GetCurrentStream();
  // par_nosync: no host sync here; completion is ordered by the stream.
  const auto policy = thrust::cuda::par_nosync(allocator).on(stream);
  auto membership_mask = torch::empty_like(elements, torch::kBool);
  AT_DISPATCH_INTEGRAL_TYPES(
      elements.scalar_type(), "IsInOperation", ([&] {
        const auto needles_begin = needles.data_ptr<scalar_t>();
        const auto queries_begin = elements.data_ptr<scalar_t>();
        // Vectorized binary_search writes one bool per query element.
        thrust::binary_search(
            policy, needles_begin, needles_begin + needles.size(0),
            queries_begin, queries_begin + elements.size(0),
            membership_mask.data_ptr<bool>());
      }));
  return membership_mask;
}
} // namespace ops
} // namespace graphbolt
......@@ -5,8 +5,12 @@
* @brief Isin op.
*/
#include <graphbolt/cuda_ops.h>
#include <graphbolt/isin.h>
#include "./macro.h"
#include "./utils.h"
namespace {
static constexpr int kSearchGrainSize = 4096;
} // namespace
......@@ -14,7 +18,7 @@ static constexpr int kSearchGrainSize = 4096;
namespace graphbolt {
namespace sampling {
torch::Tensor IsIn(
torch::Tensor IsInCPU(
const torch::Tensor& elements, const torch::Tensor& test_elements) {
torch::Tensor sorted_test_elements;
std::tie(sorted_test_elements, std::ignore) = test_elements.sort(
......@@ -41,5 +45,17 @@ torch::Tensor IsIn(
}));
return result;
}
/**
 * @brief Device-dispatching membership test. Routes to the CUDA
 * implementation when both tensors are GPU-accessible, otherwise to the
 * CPU implementation.
 *
 * @param elements Input elements.
 * @param test_elements Values against which to test for each input element.
 * @return Boolean tensor shaped like `elements`; True where the value is
 * present in `test_elements`.
 */
torch::Tensor IsIn(
    const torch::Tensor& elements, const torch::Tensor& test_elements) {
  const bool gpu_eligible = utils::is_accessible_from_gpu(elements) &&
                            utils::is_accessible_from_gpu(test_elements);
  // Guard clause: anything not GPU-visible takes the CPU path.
  if (!gpu_eligible) return IsInCPU(elements, test_elements);
  GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
      c10::DeviceType::CUDA, "IsInOperation",
      { return ops::IsIn(elements, test_elements); });
}
} // namespace sampling
} // namespace graphbolt
......@@ -124,28 +124,30 @@ def test_etype_str_to_tuple():
def test_isin():
    """Check gb.isin against a hand-computed membership mask.

    Fix: the scraped diff had fused the pre- and post-change assignment
    lines, leaving duplicate (shadowed) definitions of ``elements``,
    ``test_elements`` and ``expected``; only the device-aware versions
    are kept.
    """
    # Allocate on the backend device so both CPU and CUDA paths are exercised.
    elements = torch.tensor([2, 3, 5, 5, 20, 13, 11], device=F.ctx())
    test_elements = torch.tensor([2, 5], device=F.ctx())
    res = gb.isin(elements, test_elements)
    expected = torch.tensor(
        [True, False, True, True, False, False, False], device=F.ctx()
    )
    assert torch.equal(res, expected)
def test_isin_big_data():
    """Cross-check gb.isin against torch.isin on a large random input.

    Fix: removed the duplicate (shadowed) tensor assignments left over
    from the fused old/new diff lines; only the device-aware versions
    are kept.
    """
    elements = torch.randint(0, 10000, (10000000,), device=F.ctx())
    test_elements = torch.randint(0, 10000, (500000,), device=F.ctx())
    res = gb.isin(elements, test_elements)
    # torch.isin is the reference implementation.
    expected = torch.isin(elements, test_elements)
    assert torch.equal(res, expected)
def test_isin_non_1D_dim():
    """gb.isin must reject non-1D inputs on either argument.

    Fix: removed the duplicate (shadowed) tensor assignments left over
    from the fused old/new diff lines; only the device-aware versions
    are kept.
    """
    # 2D elements tensor must be rejected.
    elements = torch.tensor([[2, 3], [5, 5], [20, 13]], device=F.ctx())
    test_elements = torch.tensor([2, 5], device=F.ctx())
    with pytest.raises(Exception):
        gb.isin(elements, test_elements)
    # 2D test_elements tensor must be rejected too.
    elements = torch.tensor([2, 3, 5, 5, 20, 13], device=F.ctx())
    test_elements = torch.tensor([[2, 5]], device=F.ctx())
    with pytest.raises(Exception):
        gb.isin(elements, test_elements)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment