/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/unique_and_compact_impl.cu
 * @brief Unique and compact operator implementation on CUDA.
 */
#include <c10/cuda/CUDAStream.h>
#include <graphbolt/cuda_ops.h>
#include <thrust/binary_search.h>
#include <thrust/functional.h>
#include <thrust/gather.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/logical.h>
#include <thrust/reduce.h>
#include <thrust/remove.h>

#include <cub/cub.cuh>

#include <type_traits>

#include "./common.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

// For index i, reports whether the searched item was actually present in the
// sorted array: true iff the element of sorted_order at the lower_bound
// location (found_locations[i]) equals searched_items[i].
template <typename scalar_t>
struct EqualityFunc {
  const scalar_t* sorted_order;
  const scalar_t* found_locations;
  const scalar_t* searched_items;
  __host__ __device__ auto operator()(int64_t i) {
    return sorted_order[found_locations[i]] == searched_items[i];
  }
};

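// Generates a device-side reduction helper named `name`. The generated
// function follows the usual two-phase CUB pattern: the first call with a
// null workspace pointer only queries the required temporary storage size,
// and the second call performs the actual reduction. The result is returned
// as a cuda::CopyScalar; the call sites below synchronize the stream or call
// record() before reading its host value.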
#define DefineReductionFunction(reduce_fn, name)                               \
  template <typename scalar_iterator_t>                                        \
  auto name(const scalar_iterator_t input, int64_t size) {                     \
    auto allocator = cuda::GetAllocator();                                     \
    auto stream = cuda::GetCurrentStream();                                    \
    using scalar_t = std::remove_reference_t<decltype(input[0])>;              \
    cuda::CopyScalar<scalar_t> result;                                         \
    size_t workspace_size = 0;                                                 \
    reduce_fn(nullptr, workspace_size, input, result.get(), size, stream);     \
    auto tmp_storage = allocator.AllocateStorage<char>(workspace_size);        \
    reduce_fn(                                                                 \
        tmp_storage.get(), workspace_size, input, result.get(), size, stream); \
    return result;                                                             \
  }

DefineReductionFunction(cub::DeviceReduce::Max, Max);
DefineReductionFunction(cub::DeviceReduce::Min, Min);

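// Compacts src_ids and dst_ids so that they index into the returned list of
// unique ids, which is laid out as unique_dst_ids followed by the unique
// src-only ids. Returns a tuple of (unique ids, new src ids, new dst ids),
// where each new id is the position of the original vertex id in the unique
// id list. Every dst_id is expected to appear in unique_dst_ids; otherwise
// std::out_of_range is thrown.
//
// Illustrative example (hypothetical values):
//   src_ids        = [20, 10, 30, 10]
//   dst_ids        = [10, 10, 20, 20]
//   unique_dst_ids = [10, 20]
// yields
//   unique ids     = [10, 20, 30]
//   new src ids    = [ 1,  0,  2,  0]
//   new dst ids    = [ 0,  0,  1,  1]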
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
    const torch::Tensor src_ids, const torch::Tensor dst_ids,
    const torch::Tensor unique_dst_ids, int num_bits) {
  TORCH_CHECK(
      src_ids.scalar_type() == dst_ids.scalar_type() &&
          dst_ids.scalar_type() == unique_dst_ids.scalar_type(),
      "Dtypes of tensors passed to UniqueAndCompact need to be identical.");
  auto allocator = cuda::GetAllocator();
  auto stream = cuda::GetCurrentStream();
  const auto exec_policy = thrust::cuda::par_nosync(allocator).on(stream);
  return AT_DISPATCH_INTEGRAL_TYPES(
      src_ids.scalar_type(), "unique_and_compact", ([&] {
        auto src_ids_ptr = src_ids.data_ptr<scalar_t>();
        auto dst_ids_ptr = dst_ids.data_ptr<scalar_t>();
        auto unique_dst_ids_ptr = unique_dst_ids.data_ptr<scalar_t>();

        // If num_bits is not given, compute the maximum vertex ids now so
        // that num_bits can be derived later, speeding up the expensive sort
        // operations below.
        cuda::CopyScalar<scalar_t> max_id_src;
        cuda::CopyScalar<scalar_t> max_id_dst;
        if (num_bits == 0) {
          max_id_src = Max(src_ids_ptr, src_ids.size(0));
          max_id_dst = Max(unique_dst_ids_ptr, unique_dst_ids.size(0));
        }

        // Sort the unique_dst_ids tensor.
        auto sorted_unique_dst_ids =
            Sort<false>(unique_dst_ids_ptr, unique_dst_ids.size(0), num_bits);
        auto sorted_unique_dst_ids_ptr =
            sorted_unique_dst_ids.data_ptr<scalar_t>();

        // Mark dst nodes in the src_ids tensor.
        auto is_dst = allocator.AllocateStorage<bool>(src_ids.size(0));
        thrust::binary_search(
            exec_policy, sorted_unique_dst_ids_ptr,
            sorted_unique_dst_ids_ptr + unique_dst_ids.size(0), src_ids_ptr,
            src_ids_ptr + src_ids.size(0), is_dst.get());
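        // In this vectorized form, thrust::binary_search writes, for every
        // src id, whether it occurs in the sorted unique_dst_ids range.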

        // Keep only the nodes of src_ids that are not dst nodes, hence the
        // name only_src.
        auto only_src =
            torch::empty(src_ids.size(0), sorted_unique_dst_ids.options());
        {
          auto is_src = thrust::make_transform_iterator(
              is_dst.get(), thrust::logical_not<bool>{});
          cuda::CopyScalar<int64_t> only_src_size;
          size_t workspace_size = 0;
          CUDA_CALL(cub::DeviceSelect::Flagged(
              nullptr, workspace_size, src_ids_ptr, is_src,
              only_src.data_ptr<scalar_t>(), only_src_size.get(),
              src_ids.size(0), stream));
          auto tmp_storage = allocator.AllocateStorage<char>(workspace_size);
          CUDA_CALL(cub::DeviceSelect::Flagged(
              tmp_storage.get(), workspace_size, src_ids_ptr, is_src,
              only_src.data_ptr<scalar_t>(), only_src_size.get(),
              src_ids.size(0), stream));
          stream.synchronize();
          only_src = only_src.slice(0, 0, static_cast<int64_t>(only_src_size));
        }

        // The code block above synchronizes, ensuring safe access to max_id_src
        // and max_id_dst.
        if (num_bits == 0) {
          num_bits = cuda::NumberOfBits(
              1 + std::max(
                      static_cast<scalar_t>(max_id_src),
                      static_cast<scalar_t>(max_id_dst)));
        }
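        // num_bits now bounds the number of significant bits of any id;
        // assuming Sort forwards it to the underlying radix sort, the sorts
        // below only need to process that many bits.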

        // Sort the only_src tensor so that we can unique it later.
        auto sorted_only_src = Sort<false>(
            only_src.data_ptr<scalar_t>(), only_src.size(0), num_bits);

        auto unique_only_src =
            torch::empty(only_src.size(0), src_ids.options());
        auto unique_only_src_ptr = unique_only_src.data_ptr<scalar_t>();

        {  // Compute the unique operation on the only_src tensor.
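          // cub::DeviceSelect::Unique only removes consecutive duplicates,
          // which is why only_src was sorted above.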
          cuda::CopyScalar<int64_t> unique_only_src_size;
          size_t workspace_size = 0;
          CUDA_CALL(cub::DeviceSelect::Unique(
              nullptr, workspace_size, sorted_only_src.data_ptr<scalar_t>(),
              unique_only_src_ptr, unique_only_src_size.get(), only_src.size(0),
              stream));
          auto tmp_storage = allocator.AllocateStorage<char>(workspace_size);
          CUDA_CALL(cub::DeviceSelect::Unique(
              tmp_storage.get(), workspace_size,
              sorted_only_src.data_ptr<scalar_t>(), unique_only_src_ptr,
              unique_only_src_size.get(), only_src.size(0), stream));
          stream.synchronize();
          unique_only_src = unique_only_src.slice(
              0, 0, static_cast<int64_t>(unique_only_src_size));
        }

        auto real_order = torch::cat({unique_dst_ids, unique_only_src});
        // Sort here so that binary search can be used to look up new_ids.
        auto [sorted_order, new_ids] = Sort(real_order, num_bits);
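        // new_ids[i] is the position in real_order of sorted_order[i]. Since
        // an id's compact id is its position in real_order, gathering new_ids
        // at a lower_bound location maps an original id to its compact id.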
        auto sorted_order_ptr = sorted_order.data_ptr<scalar_t>();
        auto new_ids_ptr = new_ids.data_ptr<int64_t>();
        // Holds the locations of the src and dst ids found in sorted_order.
        // These locations are later used to look up the new ids of the
        // src_ids and dst_ids tensors.
        auto new_dst_ids_loc =
            allocator.AllocateStorage<scalar_t>(dst_ids.size(0));
        thrust::lower_bound(
            exec_policy, sorted_order_ptr,
            sorted_order_ptr + sorted_order.size(0), dst_ids_ptr,
            dst_ids_ptr + dst_ids.size(0), new_dst_ids_loc.get());

        cuda::CopyScalar<bool> all_exist;
        // Check if unique_dst_ids includes all dst_ids.
        if (dst_ids.size(0) > 0) {
          thrust::counting_iterator<int64_t> iota(0);
          auto equal_it = thrust::make_transform_iterator(
              iota, EqualityFunc<scalar_t>{
                        sorted_order_ptr, new_dst_ids_loc.get(), dst_ids_ptr});
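          // Taking the minimum over these boolean equalities acts as a
          // logical AND: all_exist is true only if every dst_id was found.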
          all_exist = Min(equal_it, dst_ids.size(0));
          all_exist.record();
        }

        auto new_src_ids_loc =
            allocator.AllocateStorage<scalar_t>(src_ids.size(0));
        thrust::lower_bound(
            exec_policy, sorted_order_ptr,
            sorted_order_ptr + sorted_order.size(0), src_ids_ptr,
            src_ids_ptr + src_ids.size(0), new_src_ids_loc.get());

        // Finally, look up the new compact ids of the src and dst tensors via
        // gather operations.
        auto new_src_ids = torch::empty_like(src_ids);
        thrust::gather(
            exec_policy, new_src_ids_loc.get(),
            new_src_ids_loc.get() + src_ids.size(0),
            new_ids.data_ptr<int64_t>(), new_src_ids.data_ptr<scalar_t>());
        // Perform the existence check before gathering the new dst ids.
        if (dst_ids.size(0) > 0 && !static_cast<bool>(all_exist)) {
          throw std::out_of_range("Some ids not found.");
        }
        auto new_dst_ids = torch::empty_like(dst_ids);
        thrust::gather(
            exec_policy, new_dst_ids_loc.get(),
            new_dst_ids_loc.get() + dst_ids.size(0),
            new_ids.data_ptr<int64_t>(), new_dst_ids.data_ptr<scalar_t>());
        return std::make_tuple(real_order, new_src_ids, new_dst_ids);
      }));
}

}  // namespace ops
}  // namespace graphbolt