/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/unique_and_compact_impl.cu
 * @brief Unique and compact operator implementation on CUDA.
 */
#include <graphbolt/cuda_ops.h>
#include <thrust/binary_search.h>
#include <thrust/functional.h>
#include <thrust/gather.h>
#include <thrust/logical.h>

#include <cub/cub.cuh>
#include <mutex>
#include <type_traits>
#include <unordered_map>

#include "./common.h"
#include "./extension/unique_and_compact.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

// Predicate for the i-th searched item: true iff the element of sorted_order
// stored at position found_locations[i] equals searched_items[i]. Suitable as
// the functor of a transform iterator over [0, n); a min-reduction of the
// resulting booleans is true only if every item was found.
// NOTE(review): found_locations[i] is dereferenced without a bounds check, so
// every location must be a valid index into sorted_order. A lower_bound result
// equal to the length of sorted_order (item greater than all elements) would
// read out of bounds.
template <typename scalar_t>
struct EqualityFunc {
  const scalar_t* sorted_order;
  const scalar_t* found_locations;
  const scalar_t* searched_items;
  // const so the functor can be invoked through const references/copies.
  __host__ __device__ bool operator()(int64_t i) const {
    return sorted_order[found_locations[i]] == searched_items[i];
  }
};

// Defines a host-side helper `name(input, size)` that runs the device-wide
// reduction `cub_reduce_fn` (issued through the project's CUB_CALL macro) over
// the first `size` elements of `input` and returns the result wrapped in a
// cuda::CopyScalar. The value type is taken from `input[0]` with reference and
// cv-qualifiers stripped, so both raw pointers (including pointer-to-const)
// and transform iterators whose operator[] returns by value are accepted.
// The result may be written asynchronously: callers in this file synchronize
// the stream (or an event) before converting the CopyScalar to a host value.
#define DefineCubReductionFunction(cub_reduce_fn, name)    \
  template <typename scalar_iterator_t>                    \
  auto name(const scalar_iterator_t input, int64_t size) { \
    using scalar_t = std::remove_cv_t<                     \
        std::remove_reference_t<decltype(input[0])>>;      \
    cuda::CopyScalar<scalar_t> result;                     \
    CUB_CALL(cub_reduce_fn, input, result.get(), size);    \
    return result;                                         \
  }

DefineCubReductionFunction(DeviceReduce::Max, Max);
DefineCubReductionFunction(DeviceReduce::Min, Min);
// Per-element predicate that checks whether the i-th searched item was found
// by a preceding lower_bound over sorted_order. A location equal to
// num_sorted, which a vectorized lower_bound yields when the item is greater
// than every element (or when sorted_order is empty), is reported as a
// mismatch instead of being dereferenced past the end of sorted_order.
template <typename scalar_t>
struct BoundedEqualityFunc {
  const scalar_t* sorted_order;
  const scalar_t* found_locations;
  const scalar_t* searched_items;
  int64_t num_sorted;
  __host__ __device__ bool operator()(int64_t i) const {
    const auto location = static_cast<int64_t>(found_locations[i]);
    return location < num_sorted &&
           sorted_order[location] == searched_items[i];
  }
};

/**
 * @brief Sort-based implementation of batched unique-and-compact.
 *
 * For every batch entry i it:
 *  1. sorts unique_dst_ids[i] and marks which src_ids[i] are already dst ids,
 *  2. keeps the remaining src ids, then sorts and deduplicates them,
 *  3. forms real_order = cat(unique_dst_ids[i], unique remaining src ids),
 *  4. maps src_ids[i] and dst_ids[i] to their positions in real_order.
 *
 * @param src_ids Per-batch source ids. The dtype of src_ids[0] selects the
 *        index type for every entry (assumed uniform across the batch).
 * @param dst_ids Per-batch destination ids; each must occur in the matching
 *        unique_dst_ids entry.
 * @param unique_dst_ids Per-batch destination ids forming the prefix of the
 *        returned ids (presumably already unique — TODO confirm, the code
 *        does not deduplicate them).
 * @param num_bits Bit width passed to the sorts; 0 means it is derived from
 *        the largest id found in src_ids and unique_dst_ids.
 * @return One (real_order, new_src_ids, new_dst_ids) tuple per batch entry.
 * @throws std::out_of_range if some dst id is missing from unique_dst_ids.
 */
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
UniqueAndCompactBatchedSortBased(
    const std::vector<torch::Tensor>& src_ids,
    const std::vector<torch::Tensor>& dst_ids,
    const std::vector<torch::Tensor>& unique_dst_ids, int num_bits) {
  auto allocator = cuda::GetAllocator();
  auto stream = cuda::GetCurrentStream();
  auto scalar_type = src_ids.at(0).scalar_type();
  return AT_DISPATCH_INDEX_TYPES(
      scalar_type, "unique_and_compact", ([&] {
        std::vector<index_t*> src_ids_ptr, dst_ids_ptr, unique_dst_ids_ptr;
        for (std::size_t i = 0; i < src_ids.size(); i++) {
          src_ids_ptr.emplace_back(src_ids[i].data_ptr<index_t>());
          dst_ids_ptr.emplace_back(dst_ids[i].data_ptr<index_t>());
          unique_dst_ids_ptr.emplace_back(
              unique_dst_ids[i].data_ptr<index_t>());
        }

        // If num_bits is not given, compute maximum vertex ids to compute
        // num_bits later to speedup the expensive sort operations.
        std::vector<cuda::CopyScalar<index_t>> max_id_src;
        std::vector<cuda::CopyScalar<index_t>> max_id_dst;
        for (std::size_t i = 0; num_bits == 0 && i < src_ids.size(); i++) {
          max_id_src.emplace_back(Max(src_ids_ptr[i], src_ids[i].size(0)));
          max_id_dst.emplace_back(
              Max(unique_dst_ids_ptr[i], unique_dst_ids[i].size(0)));
        }

        // Sort the unique_dst_ids tensor.
        std::vector<torch::Tensor> sorted_unique_dst_ids;
        std::vector<index_t*> sorted_unique_dst_ids_ptr;
        for (std::size_t i = 0; i < unique_dst_ids.size(); i++) {
          sorted_unique_dst_ids.emplace_back(Sort<false>(
              unique_dst_ids_ptr[i], unique_dst_ids[i].size(0), num_bits));
          sorted_unique_dst_ids_ptr.emplace_back(
              sorted_unique_dst_ids[i].data_ptr<index_t>());
        }

        // Mark dst nodes in the src_ids tensor.
        std::vector<decltype(allocator.AllocateStorage<bool>(0))> is_dst;
        for (std::size_t i = 0; i < src_ids.size(); i++) {
          is_dst.emplace_back(
              allocator.AllocateStorage<bool>(src_ids[i].size(0)));
          THRUST_CALL(
              binary_search, sorted_unique_dst_ids_ptr[i],
              sorted_unique_dst_ids_ptr[i] + unique_dst_ids[i].size(0),
              src_ids_ptr[i], src_ids_ptr[i] + src_ids[i].size(0),
              is_dst[i].get());
        }

        // Filter the non-dst nodes in the src_ids tensor, hence only_src.
        std::vector<torch::Tensor> only_src;
        {
          std::vector<cuda::CopyScalar<int64_t>> only_src_size;
          for (std::size_t i = 0; i < src_ids.size(); i++) {
            only_src.emplace_back(torch::empty(
                src_ids[i].size(0), sorted_unique_dst_ids[i].options()));
            auto is_src = thrust::make_transform_iterator(
                is_dst[i].get(), thrust::logical_not<bool>{});
            only_src_size.emplace_back(cuda::CopyScalar<int64_t>{});
            CUB_CALL(
                DeviceSelect::Flagged, src_ids_ptr[i], is_src,
                only_src[i].data_ptr<index_t>(), only_src_size[i].get(),
                src_ids[i].size(0));
          }
          // The selected counts are needed on the host to trim only_src.
          stream.synchronize();
          for (std::size_t i = 0; i < only_src.size(); i++) {
            only_src[i] =
                only_src[i].slice(0, 0, static_cast<int64_t>(only_src_size[i]));
          }
        }

        // The code block above synchronizes, ensuring safe access to
        // max_id_src and max_id_dst.
        if (num_bits == 0) {
          index_t max_id = 0;
          for (std::size_t i = 0; i < max_id_src.size(); i++) {
            max_id = std::max(max_id, static_cast<index_t>(max_id_src[i]));
            max_id = std::max(max_id, static_cast<index_t>(max_id_dst[i]));
          }
          num_bits = cuda::NumberOfBits(1ll + max_id);
        }

        // Sort the only_src tensor so that we can unique it later.
        std::vector<torch::Tensor> sorted_only_src;
        for (auto& only_src_i : only_src) {
          sorted_only_src.emplace_back(Sort<false>(
              only_src_i.data_ptr<index_t>(), only_src_i.size(0), num_bits));
        }

        std::vector<torch::Tensor> unique_only_src;
        std::vector<index_t*> unique_only_src_ptr;

        std::vector<cuda::CopyScalar<int64_t>> unique_only_src_size;
        for (std::size_t i = 0; i < src_ids.size(); i++) {
          // Compute the unique operation on the only_src tensor.
          unique_only_src.emplace_back(
              torch::empty(only_src[i].size(0), src_ids[i].options()));
          unique_only_src_ptr.emplace_back(
              unique_only_src[i].data_ptr<index_t>());
          unique_only_src_size.emplace_back(cuda::CopyScalar<int64_t>{});
          CUB_CALL(
              DeviceSelect::Unique, sorted_only_src[i].data_ptr<index_t>(),
              unique_only_src_ptr[i], unique_only_src_size[i].get(),
              only_src[i].size(0));
        }
        // The unique counts are needed on the host to trim unique_only_src.
        stream.synchronize();
        for (std::size_t i = 0; i < unique_only_src.size(); i++) {
          unique_only_src[i] = unique_only_src[i].slice(
              0, 0, static_cast<int64_t>(unique_only_src_size[i]));
        }

        // The returned unique ids: dst ids first, then the new src-only ids.
        std::vector<torch::Tensor> real_order;
        for (std::size_t i = 0; i < unique_dst_ids.size(); i++) {
          real_order.emplace_back(
              torch::cat({unique_dst_ids[i], unique_only_src[i]}));
        }
        // Sort here so that binary search can be used to lookup new_ids.
        std::vector<torch::Tensor> sorted_order, new_ids;
        std::vector<index_t*> sorted_order_ptr;
        std::vector<int64_t*> new_ids_ptr;
        for (std::size_t i = 0; i < real_order.size(); i++) {
          auto [sorted_order_i, new_ids_i] = Sort(real_order[i], num_bits);
          sorted_order_ptr.emplace_back(sorted_order_i.data_ptr<index_t>());
          new_ids_ptr.emplace_back(new_ids_i.data_ptr<int64_t>());
          sorted_order.emplace_back(std::move(sorted_order_i));
          new_ids.emplace_back(std::move(new_ids_i));
        }
        // Holds the found locations of the src and dst ids in the
        // sorted_order. Later is used to lookup the new ids of the src_ids
        // and dst_ids tensors.
        std::vector<decltype(allocator.AllocateStorage<index_t>(0))>
            new_dst_ids_loc;
        for (std::size_t i = 0; i < sorted_order.size(); i++) {
          new_dst_ids_loc.emplace_back(
              allocator.AllocateStorage<index_t>(dst_ids[i].size(0)));
          THRUST_CALL(
              lower_bound, sorted_order_ptr[i],
              sorted_order_ptr[i] + sorted_order[i].size(0), dst_ids_ptr[i],
              dst_ids_ptr[i] + dst_ids[i].size(0), new_dst_ids_loc[i].get());
        }

        std::vector<cuda::CopyScalar<bool>> all_exist;
        at::cuda::CUDAEvent all_exist_event;
        bool should_record = false;
        // Check if unique_dst_ids includes all dst_ids. The bounded
        // predicate rejects lower_bound results that point one past the end
        // of sorted_order instead of reading out of bounds.
        for (std::size_t i = 0; i < dst_ids.size(); i++) {
          if (dst_ids[i].size(0) > 0) {
            thrust::counting_iterator<int64_t> iota(0);
            auto equal_it = thrust::make_transform_iterator(
                iota, BoundedEqualityFunc<index_t>{
                          sorted_order_ptr[i], new_dst_ids_loc[i].get(),
                          dst_ids_ptr[i], sorted_order[i].size(0)});
            all_exist.emplace_back(Min(equal_it, dst_ids[i].size(0)));
            should_record = true;
          } else {
            all_exist.emplace_back(cuda::CopyScalar<bool>{});
          }
        }
        if (should_record) all_exist_event.record();

        std::vector<decltype(allocator.AllocateStorage<index_t>(0))>
            new_src_ids_loc;
        for (std::size_t i = 0; i < sorted_order.size(); i++) {
          new_src_ids_loc.emplace_back(
              allocator.AllocateStorage<index_t>(src_ids[i].size(0)));
          THRUST_CALL(
              lower_bound, sorted_order_ptr[i],
              sorted_order_ptr[i] + sorted_order[i].size(0), src_ids_ptr[i],
              src_ids_ptr[i] + src_ids[i].size(0), new_src_ids_loc[i].get());
        }

        // Finally, lookup the new compact ids of the src and dst tensors
        // via gather operations.
        std::vector<torch::Tensor> new_src_ids;
        for (std::size_t i = 0; i < src_ids.size(); i++) {
          new_src_ids.emplace_back(torch::empty_like(src_ids[i]));
          THRUST_CALL(
              gather, new_src_ids_loc[i].get(),
              new_src_ids_loc[i].get() + src_ids[i].size(0),
              new_ids[i].data_ptr<int64_t>(),
              new_src_ids[i].data_ptr<index_t>());
        }
        // Perform check before we gather for the dst indices.
        for (std::size_t i = 0; i < dst_ids.size(); i++) {
          if (dst_ids[i].size(0) > 0) {
            if (should_record) {
              all_exist_event.synchronize();
              should_record = false;
            }
            if (!static_cast<bool>(all_exist[i])) {
              throw std::out_of_range("Some ids not found.");
            }
          }
        }
        std::vector<torch::Tensor> new_dst_ids;
        for (std::size_t i = 0; i < dst_ids.size(); i++) {
          new_dst_ids.emplace_back(torch::empty_like(dst_ids[i]));
          THRUST_CALL(
              gather, new_dst_ids_loc[i].get(),
              new_dst_ids_loc[i].get() + dst_ids[i].size(0),
              new_ids[i].data_ptr<int64_t>(),
              new_dst_ids[i].data_ptr<index_t>());
        }
        std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
            results;
        for (std::size_t i = 0; i < src_ids.size(); i++) {
          results.emplace_back(
              std::move(real_order[i]), std::move(new_src_ids[i]),
              std::move(new_dst_ids[i]));
        }
        return results;
      }));
}

/**
 * @brief Batched unique-and-compact entry point; picks an implementation from
 * the compute capability of the current stream's device.
 *
 * The major compute capability is queried once per device index and cached in
 * a function-local static map guarded by a mutex. Devices with major >= 7 use
 * the hash-map based implementation; older devices use the sort-based one.
 *
 * @param src_ids Per-batch source ids.
 * @param dst_ids Per-batch destination ids.
 * @param unique_dst_ids Per-batch unique destination ids.
 * @param num_bits Forwarded to the sort-based implementation only.
 * @return One result tuple per batch entry from the chosen implementation.
 */
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
UniqueAndCompactBatched(
    const std::vector<torch::Tensor>& src_ids,
    const std::vector<torch::Tensor>& dst_ids,
    const std::vector<torch::Tensor>& unique_dst_ids, int num_bits) {
  auto dev_id = cuda::GetCurrentStream().device_index();
  static std::mutex mtx;
  static std::unordered_map<decltype(dev_id), int> compute_capability_cache;
  const auto compute_capability_major = [&] {
    std::lock_guard lock(mtx);
    auto it = compute_capability_cache.find(dev_id);
    if (it != compute_capability_cache.end()) {
      return it->second;
    }
    // The driver API identifies devices by a CUdevice handle, which must be
    // obtained from the ordinal via cuDeviceGet rather than assumed equal to
    // it.
    CUdevice device;
    CUDA_DRIVER_CHECK(cuDeviceGet(&device, dev_id));
    int major;
    CUDA_DRIVER_CHECK(cuDeviceGetAttribute(
        &major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device));
    compute_capability_cache.emplace(dev_id, major);
    return major;
  }();
  if (compute_capability_major >= 7) {
    // Utilizes a hash table based implementation, the mapped id of a vertex
    // will be monotonically increasing as the first occurrence index of it in
    // torch.cat([unique_dst_ids, src_ids]). Thus, it is deterministic.
    return UniqueAndCompactBatchedHashMapBased(
        src_ids, dst_ids, unique_dst_ids);
  }
  // Utilizes a sort based algorithm, the mapped id of a vertex part of the
  // src_ids but not part of the unique_dst_ids will be monotonically increasing
  // as the actual vertex id increases. Thus, it is deterministic.
  return UniqueAndCompactBatchedSortBased(
      src_ids, dst_ids, unique_dst_ids, num_bits);
}

/**
 * @brief Single-input convenience wrapper around UniqueAndCompactBatched.
 *
 * Wraps each tensor as a one-element batch, runs the batched operator and
 * returns its only result tuple.
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
    const torch::Tensor src_ids, const torch::Tensor dst_ids,
    const torch::Tensor unique_dst_ids, int num_bits) {
  const std::vector<torch::Tensor> src_batch{src_ids};
  const std::vector<torch::Tensor> dst_batch{dst_ids};
  const std::vector<torch::Tensor> unique_dst_batch{unique_dst_ids};
  auto batched_results = UniqueAndCompactBatched(
      src_batch, dst_batch, unique_dst_batch, num_bits);
  return batched_results.front();
}

}  // namespace ops
}  // namespace graphbolt