// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/isin.cu
 * @brief IsIn operator implementation on CUDA.
 */
#include <graphbolt/cuda_ops.h>
#include <thrust/binary_search.h>

#include "common.h"

namespace graphbolt {
namespace ops {

/**
 * @brief Element-wise membership test: for every entry of `elements`, report
 * whether that value occurs in `test_elements`.
 *
 * `test_elements` is sorted once. Every entry of `elements` is then looked up
 * in the sorted copy with the vectorized overload of `thrust::binary_search`.
 * That overload requires the searched range to be in ascending order.
 * NOTE(review): this relies on Sort<false> returning ascending values; confirm
 * against its definition.
 *
 * NOTE(review): this file is generated by hipify from cuda/isin.cu. Apply the
 * same change there, or it will be lost on the next regeneration.
 *
 * @param elements Integral tensor whose entries are looked up. Any shape is
 * accepted: it is made contiguous and processed as a flat buffer of numel()
 * values.
 * @param test_elements Integral tensor of candidate values. Its dtype must
 * match `elements`: data_ptr<scalar_t>() on the sorted copy rejects a
 * mismatch. NOTE(review): presumably 1-D, since only size(0) entries of the
 * sorted copy are searched. Confirm with callers.
 * @return Bool tensor with the same sizes as `elements`. Each entry is true
 * iff the corresponding entry of `elements` is found in `test_elements`.
 */
torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements) {
  auto sorted_test_elements = Sort<false>(test_elements);

  // The search below walks raw memory from data_ptr(), so a strided view
  // (e.g. a sliced tensor) would be read incorrectly. contiguous() returns
  // the tensor itself when it already is contiguous, so the common case
  // pays nothing.
  auto contiguous_elements = elements.contiguous();
  auto result = torch::empty_like(contiguous_elements, torch::kBool);
  // Cover every entry, not just size(0), so multi-dimensional inputs do not
  // leave part of `result` uninitialized.
  const auto num_elements = contiguous_elements.numel();

  AT_DISPATCH_INTEGRAL_TYPES(
      contiguous_elements.scalar_type(), "IsInOperation", ([&] {
        THRUST_CALL(
            binary_search, sorted_test_elements.data_ptr<scalar_t>(),
            sorted_test_elements.data_ptr<scalar_t>() +
                sorted_test_elements.size(0),
            contiguous_elements.data_ptr<scalar_t>(),
            contiguous_elements.data_ptr<scalar_t>() + num_elements,
            result.data_ptr<bool>());
      }));
  return result;
}

}  // namespace ops
}  // namespace graphbolt