// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/unique_and_compact_impl.cu
 * @brief Unique and compact operator implementation on CUDA.
 */
#include <hip/hip_runtime.h>
#include <graphbolt/cuda_ops.h>
#include <thrust/binary_search.h>
#include <thrust/functional.h>
#include <thrust/gather.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/logical.h>

#include <mutex>
#include <hipcub/hipcub.hpp>
#include <type_traits>
#include <unordered_map>

#include "./common.h"
#include "./extension/unique_and_compact.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

template <typename scalar_t>
struct EqualityFunc {
  const scalar_t* sorted_order;
  const scalar_t* found_locations;
  const scalar_t* searched_items;
  __host__ __device__ auto operator()(int64_t i) {
    return sorted_order[found_locations[i]] == searched_items[i];
  }
};
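// EqualityFunc is used together with thrust::lower_bound below:
// found_locations[i] is the position lower_bound reported for
// searched_items[i], and the functor returns true only when the searched item
// is actually present at that position. Illustrative example (hypothetical
// values): sorted_order = [3, 5, 9] and searched_items = [5, 6] give
// found_locations = [1, 2] and results [true, false].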

#define DefineCubReductionFunction(cub_reduce_fn, name)           \
  template <typename scalar_iterator_t>                           \
  auto name(const scalar_iterator_t input, int64_t size) {        \
    using scalar_t = std::remove_reference_t<decltype(input[0])>; \
    cuda::CopyScalar<scalar_t> result;                            \
    CUB_CALL(cub_reduce_fn, input, result.get(), size);           \
    return result;                                                \
  }

DefineCubReductionFunction(DeviceReduce::Max, Max);
DefineCubReductionFunction(DeviceReduce::Min, Min);
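
// Max(input, size) and Min(input, size) defined above run a device-side
// reduction (via CUB_CALL, which is assumed to dispatch hipcub on the current
// stream) and return a cuda::CopyScalar holding the result; as used below,
// the host value is read only after the stream has been synchronized.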

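// Sketch of the sort-based batched algorithm implemented below:
//   1. Sort each unique_dst_ids tensor and mark which src ids already appear
//      in it (is_dst).
//   2. Select the remaining src ids (only_src), then sort and unique them.
//   3. real_order = cat(unique_dst_ids, unique_only_src) defines the compact
//      id space; the compact id of a vertex is its position in real_order.
//   4. Sort real_order together with its original positions and use binary
//      search plus gather to translate src_ids and dst_ids into compact ids.
// Each returned tuple is (real_order, compacted src_ids, compacted dst_ids).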
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
UniqueAndCompactBatchedSortBased(
    const std::vector<torch::Tensor>& src_ids,
    const std::vector<torch::Tensor>& dst_ids,
    const std::vector<torch::Tensor>& unique_dst_ids, int num_bits) {
  auto allocator = cuda::GetAllocator();
  auto stream = cuda::GetCurrentStream();
  auto scalar_type = src_ids.at(0).scalar_type();
  return AT_DISPATCH_INDEX_TYPES(
      scalar_type, "unique_and_compact", ([&] {
        std::vector<index_t*> src_ids_ptr, dst_ids_ptr, unique_dst_ids_ptr;
        for (std::size_t i = 0; i < src_ids.size(); i++) {
          src_ids_ptr.emplace_back(src_ids[i].data_ptr<index_t>());
          dst_ids_ptr.emplace_back(dst_ids[i].data_ptr<index_t>());
          unique_dst_ids_ptr.emplace_back(
              unique_dst_ids[i].data_ptr<index_t>());
        }

        // If num_bits is not given, compute the maximum vertex ids now so
        // that num_bits can be derived later, speeding up the expensive sort
        // operations.
        std::vector<cuda::CopyScalar<index_t>> max_id_src;
        std::vector<cuda::CopyScalar<index_t>> max_id_dst;
        for (std::size_t i = 0; num_bits == 0 && i < src_ids.size(); i++) {
          max_id_src.emplace_back(Max(src_ids_ptr[i], src_ids[i].size(0)));
          max_id_dst.emplace_back(
              Max(unique_dst_ids_ptr[i], unique_dst_ids[i].size(0)));
        }

        // Sort the unique_dst_ids tensor.
        std::vector<torch::Tensor> sorted_unique_dst_ids;
        std::vector<index_t*> sorted_unique_dst_ids_ptr;
        for (std::size_t i = 0; i < unique_dst_ids.size(); i++) {
          sorted_unique_dst_ids.emplace_back(Sort<false>(
              unique_dst_ids_ptr[i], unique_dst_ids[i].size(0), num_bits));
          sorted_unique_dst_ids_ptr.emplace_back(
              sorted_unique_dst_ids[i].data_ptr<index_t>());
        }

        // Mark dst nodes in the src_ids tensor.
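        // binary_search writes is_dst[i][j] = true iff src_ids[i][j] is
        // present in sorted_unique_dst_ids[i]. Illustrative example
        // (hypothetical values): sorted_unique_dst_ids = [10, 20] and
        // src_ids = [20, 5, 10, 7] give is_dst = [true, false, true, false].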
        std::vector<decltype(allocator.AllocateStorage<bool>(0))> is_dst;
        for (std::size_t i = 0; i < src_ids.size(); i++) {
          is_dst.emplace_back(
              allocator.AllocateStorage<bool>(src_ids[i].size(0)));
          THRUST_CALL(
              binary_search, sorted_unique_dst_ids_ptr[i],
              sorted_unique_dst_ids_ptr[i] + unique_dst_ids[i].size(0),
              src_ids_ptr[i], src_ids_ptr[i] + src_ids[i].size(0),
              is_dst[i].get());
        }

        // Select the non-dst nodes of the src_ids tensor, hence only_src.
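        // Continuing the illustrative example above, DeviceSelect::Flagged
        // with flags is_src = !is_dst keeps [5, 7], so only_src[i] = [5, 7]
        // and only_src_size[i] = 2 once the result has been copied back.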
        std::vector<torch::Tensor> only_src;
        {
          std::vector<cuda::CopyScalar<int64_t>> only_src_size;
          for (std::size_t i = 0; i < src_ids.size(); i++) {
            only_src.emplace_back(torch::empty(
                src_ids[i].size(0), sorted_unique_dst_ids[i].options()));
            auto is_src = thrust::make_transform_iterator(
                is_dst[i].get(), thrust::logical_not<bool>{});
            only_src_size.emplace_back(cuda::CopyScalar<int64_t>{});
            CUB_CALL(
                DeviceSelect::Flagged, src_ids_ptr[i], is_src,
                only_src[i].data_ptr<index_t>(), only_src_size[i].get(),
                src_ids[i].size(0));
          }
          stream.synchronize();
          for (std::size_t i = 0; i < only_src.size(); i++) {
            only_src[i] =
                only_src[i].slice(0, 0, static_cast<int64_t>(only_src_size[i]));
          }
        }

        // The code block above synchronizes, ensuring safe access to
        // max_id_src and max_id_dst.
        if (num_bits == 0) {
          index_t max_id = 0;
          for (std::size_t i = 0; i < max_id_src.size(); i++) {
            max_id = std::max(max_id, static_cast<index_t>(max_id_src[i]));
            max_id = std::max(max_id, static_cast<index_t>(max_id_dst[i]));
          }
          num_bits = cuda::NumberOfBits(1ll + max_id);
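          // For example, if the largest id seen is 999, num_bits becomes 10
          // (assuming cuda::NumberOfBits(1 + max_id) returns the number of
          // bits needed to represent max_id, and 999 < 2^10), so the radix
          // sorts below only process the 10 least significant bits of each id.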
        }

        // Sort the only_src tensor so that we can unique it later.
        std::vector<torch::Tensor> sorted_only_src;
        for (auto& only_src_i : only_src) {
          sorted_only_src.emplace_back(Sort<false>(
              only_src_i.data_ptr<index_t>(), only_src_i.size(0), num_bits));
        }

        std::vector<torch::Tensor> unique_only_src;
        std::vector<index_t*> unique_only_src_ptr;

        std::vector<cuda::CopyScalar<int64_t>> unique_only_src_size;
        for (std::size_t i = 0; i < src_ids.size(); i++) {
          // Compute the unique operation on the only_src tensor.
          unique_only_src.emplace_back(
              torch::empty(only_src[i].size(0), src_ids[i].options()));
          unique_only_src_ptr.emplace_back(
              unique_only_src[i].data_ptr<index_t>());
          unique_only_src_size.emplace_back(cuda::CopyScalar<int64_t>{});
          CUB_CALL(
              DeviceSelect::Unique, sorted_only_src[i].data_ptr<index_t>(),
              unique_only_src_ptr[i], unique_only_src_size[i].get(),
              only_src[i].size(0));
        }
        stream.synchronize();
        for (std::size_t i = 0; i < unique_only_src.size(); i++) {
          unique_only_src[i] = unique_only_src[i].slice(
              0, 0, static_cast<int64_t>(unique_only_src_size[i]));
        }

        std::vector<torch::Tensor> real_order;
        for (std::size_t i = 0; i < unique_dst_ids.size(); i++) {
          real_order.emplace_back(
              torch::cat({unique_dst_ids[i], unique_only_src[i]}));
        }
        // Sort here so that binary search can be used to lookup new_ids.
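        // Sort(real_order[i], num_bits) is expected to return the sorted ids
        // together with their original positions (new_ids); since the compact
        // id of a vertex is its position in real_order, a binary search into
        // sorted_order followed by a gather from new_ids yields the compact
        // ids. Illustrative example (hypothetical values):
        // real_order = [10, 20, 5, 7] gives sorted_order = [5, 7, 10, 20] and
        // new_ids = [2, 3, 0, 1]; looking up 20 lands at position 3 and
        // new_ids[3] = 1 is its compact id.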
        std::vector<torch::Tensor> sorted_order, new_ids;
        std::vector<index_t*> sorted_order_ptr;
        std::vector<int64_t*> new_ids_ptr;
        for (std::size_t i = 0; i < real_order.size(); i++) {
          auto [sorted_order_i, new_ids_i] = Sort(real_order[i], num_bits);
          sorted_order_ptr.emplace_back(sorted_order_i.data_ptr<index_t>());
          new_ids_ptr.emplace_back(new_ids_i.data_ptr<int64_t>());
          sorted_order.emplace_back(std::move(sorted_order_i));
          new_ids.emplace_back(std::move(new_ids_i));
        }
        // Holds the found locations of the src and dst ids in sorted_order.
        // It is later used to look up the new ids of the src_ids and dst_ids
        // tensors.
        std::vector<decltype(allocator.AllocateStorage<index_t>(0))>
            new_dst_ids_loc;
        for (std::size_t i = 0; i < sorted_order.size(); i++) {
          new_dst_ids_loc.emplace_back(
              allocator.AllocateStorage<index_t>(dst_ids[i].size(0)));
          THRUST_CALL(
              lower_bound, sorted_order_ptr[i],
              sorted_order_ptr[i] + sorted_order[i].size(0), dst_ids_ptr[i],
              dst_ids_ptr[i] + dst_ids[i].size(0), new_dst_ids_loc[i].get());
        }

        std::vector<cuda::CopyScalar<bool>> all_exist;
        at::cuda::CUDAEvent all_exist_event;
        bool should_record = false;
        // Check if unique_dst_ids includes all dst_ids.
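        // For every dst id, lower_bound produced a candidate position in
        // sorted_order; EqualityFunc is true only when the id is actually
        // there, so the Min over these booleans is true iff every dst id
        // appears in unique_dst_ids. The flag is read below, after
        // all_exist_event has been synchronized, and std::out_of_range is
        // thrown if any id is missing.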
        for (std::size_t i = 0; i < dst_ids.size(); i++) {
          if (dst_ids[i].size(0) > 0) {
            thrust::counting_iterator<int64_t> iota(0);
            auto equal_it = thrust::make_transform_iterator(
                iota, EqualityFunc<index_t>{
                          sorted_order_ptr[i], new_dst_ids_loc[i].get(),
                          dst_ids_ptr[i]});
            all_exist.emplace_back(Min(equal_it, dst_ids[i].size(0)));
            should_record = true;
          } else {
            all_exist.emplace_back(cuda::CopyScalar<bool>{});
          }
        }
        if (should_record) all_exist_event.record();

        std::vector<decltype(allocator.AllocateStorage<index_t>(0))>
            new_src_ids_loc;
        for (std::size_t i = 0; i < sorted_order.size(); i++) {
          new_src_ids_loc.emplace_back(
              allocator.AllocateStorage<index_t>(src_ids[i].size(0)));
          THRUST_CALL(
              lower_bound, sorted_order_ptr[i],
              sorted_order_ptr[i] + sorted_order[i].size(0), src_ids_ptr[i],
              src_ids_ptr[i] + src_ids[i].size(0), new_src_ids_loc[i].get());
        }

        // Finally, lookup the new compact ids of the src and dst tensors
        // via gather operations.
        std::vector<torch::Tensor> new_src_ids;
        for (std::size_t i = 0; i < src_ids.size(); i++) {
          new_src_ids.emplace_back(torch::empty_like(src_ids[i]));
          THRUST_CALL(
              gather, new_src_ids_loc[i].get(),
              new_src_ids_loc[i].get() + src_ids[i].size(0),
              new_ids[i].data_ptr<int64_t>(),
              new_src_ids[i].data_ptr<index_t>());
        }
        // Perform the missing-id check before we gather the dst indices.
        for (std::size_t i = 0; i < dst_ids.size(); i++) {
          if (dst_ids[i].size(0) > 0) {
            if (should_record) {
              all_exist_event.synchronize();
              should_record = false;
            }
            if (!static_cast<bool>(all_exist[i])) {
              throw std::out_of_range("Some ids not found.");
            }
          }
        }
        std::vector<torch::Tensor> new_dst_ids;
        for (std::size_t i = 0; i < dst_ids.size(); i++) {
          new_dst_ids.emplace_back(torch::empty_like(dst_ids[i]));
          THRUST_CALL(
              gather, new_dst_ids_loc[i].get(),
              new_dst_ids_loc[i].get() + dst_ids[i].size(0),
              new_ids[i].data_ptr<int64_t>(),
              new_dst_ids[i].data_ptr<index_t>());
        }
        std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
            results;
        for (std::size_t i = 0; i < src_ids.size(); i++) {
          results.emplace_back(
              std::move(real_order[i]), std::move(new_src_ids[i]),
              std::move(new_dst_ids[i]));
        }
        return results;
      }));
}

std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
UniqueAndCompactBatched(
    const std::vector<torch::Tensor>& src_ids,
    const std::vector<torch::Tensor>& dst_ids,
    const std::vector<torch::Tensor>& unique_dst_ids, int num_bits) {
  auto dev_id = cuda::GetCurrentStream().device_index();
  static std::mutex mtx;
  static std::unordered_map<decltype(dev_id), int> compute_capability_cache;
  const auto compute_capability_major = [&] {
    std::lock_guard lock(mtx);
    auto it = compute_capability_cache.find(dev_id);
    if (it != compute_capability_cache.end()) {
      return it->second;
    } else {
      int major;
      CUDA_RUNTIME_CHECK(cudaDeviceGetAttribute(
          &major, cudaDevAttrComputeCapabilityMajor, dev_id));
      return compute_capability_cache[dev_id] = major;
    }
  }();
  if (compute_capability_major >= 7) {
    // Uses a hash-table based implementation; the mapped id of a vertex
    // increases monotonically with its first occurrence index in
    // torch.cat([unique_dst_ids, src_ids]). Thus, it is deterministic.
    return UniqueAndCompactBatchedHashMapBased(
        src_ids, dst_ids, unique_dst_ids);
  }
  // Uses a sort-based algorithm; the mapped id of a vertex that appears in
  // src_ids but not in unique_dst_ids increases monotonically with the
  // actual vertex id. Thus, it is deterministic.
  return UniqueAndCompactBatchedSortBased(
      src_ids, dst_ids, unique_dst_ids, num_bits);
}

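// Illustrative example (hypothetical values) for the single-pair wrapper
// below: with unique_dst_ids = [10, 20], dst_ids = [10, 10, 20] and
// src_ids = [20, 5, 10, 7], the returned tuple is
//   unique ids    = [10, 20, 5, 7]  (dst ids first, then the new src ids),
//   compacted src = [1, 2, 0, 3],
//   compacted dst = [0, 0, 1],
// i.e. every id is replaced by its position in the unique ids tensor.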
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
    const torch::Tensor src_ids, const torch::Tensor dst_ids,
    const torch::Tensor unique_dst_ids, int num_bits) {
  return UniqueAndCompactBatched(
      {src_ids}, {dst_ids}, {unique_dst_ids}, num_bits)[0];
}

}  // namespace ops
}  // namespace graphbolt