// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "hip/hip_bf16.h"
/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/neighbor_sampler.cu
 * @brief Neighbor sampling operator implementation on CUDA, hipified for ROCm.
 */
#include <c10/core/ScalarType.h>
#include <hiprand/hiprand_kernel.h>
#include <graphbolt/cuda_ops.h>
#include <graphbolt/cuda_sampling_ops.h>
#include <thrust/gather.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>

#include <algorithm>
#include <array>
#include <hipcub/hipcub.hpp>
#include <limits>
#include <numeric>
#include <type_traits>

#include "../random.h"
#include "common.h"
#include "utils.h"

namespace graphbolt {
namespace ops {

constexpr int BLOCK_SIZE = 128;

/**
 * @brief Fills the random_arr with random numbers and the edge_ids array with
 * original edge ids. When random_arr is sorted along with edge_ids, the first
 * fanout elements of each row give us the sampled edges.
 */
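// Note on the math: each edge draws a key -log(u) / w with u ~ Uniform(0, 1),
// so keeping the smallest `fanout` keys of every segment after the segmented
// sort below is the exponential-race formulation of (weighted) sampling
// without replacement; without weights it reduces to uniform sampling.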
template <
    typename float_t, typename indptr_t, typename indices_t, typename weights_t,
    typename edge_id_t>
__global__ void _ComputeRandoms(
    const int64_t num_edges, const indptr_t* const sliced_indptr,
    const indptr_t* const sub_indptr, const indices_t* const csr_rows,
    const weights_t* const sliced_weights, const indices_t* const indices,
    const uint64_t random_seed, float_t* random_arr, edge_id_t* edge_ids) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;
  hiprandStatePhilox4_32_10_t rng;
  const auto labor = indices != nullptr;
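  // In LABOR mode (indices != nullptr) the generator is re-seeded per edge
  // with the neighbor id inside the loop below, so the same neighbor draws the
  // same random number in every row; otherwise a single Philox stream is
  // initialized per thread.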

  if (!labor) {
    hiprand_init(random_seed, i, 0, &rng);
  }

  while (i < num_edges) {
    const auto row_position = csr_rows[i];
    const auto row_offset = i - sub_indptr[row_position];
    const auto in_idx = sliced_indptr[row_position] + row_offset;

    if (labor) {
      constexpr uint64_t kCurandSeed = 999961;
      hiprand_init(kCurandSeed, random_seed, indices[in_idx], &rng);
    }

    const auto rnd = hiprand_uniform(&rng);
    const auto prob =
        sliced_weights ? sliced_weights[i] : static_cast<weights_t>(1);
    const auto exp_rnd = -__logf(rnd);
    const float_t adjusted_rnd = prob > 0
                                     ? static_cast<float_t>(exp_rnd / prob)
                                     : std::numeric_limits<float_t>::infinity();
    random_arr[i] = adjusted_rnd;
    edge_ids[i] = row_offset;

    i += stride;
  }
}

struct IsPositive {
  template <typename probs_t>
  __host__ __device__ auto operator()(probs_t x) {
    return x > 0;
  }
};

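// Returns min(in_degree[i], fanout of edge type i % num_fanouts), i.e. the
// number of edges that will actually be picked for segment i.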
template <typename indptr_t>
struct MinInDegreeFanout {
  const indptr_t* in_degree;
  const int64_t* fanouts;
  size_t num_fanouts;
  __host__ __device__ auto operator()(int64_t i) {
    return static_cast<indptr_t>(
        min(static_cast<int64_t>(in_degree[i]), fanouts[i % num_fanouts]));
  }
};

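// Maps segment i to a pointer to the beginning of its range, indices + indptr[i].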
template <typename indptr_t, typename indices_t>
struct IteratorFunc {
  indptr_t* indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) { return indices + indptr[i]; }
};

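// Adds a fixed offset to every value written through it; used below to turn
// row-local edge id offsets into global edge ids.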
template <typename indptr_t>
struct AddOffset {
  indptr_t offset;
  template <typename edge_id_t>
  __host__ __device__ indptr_t operator()(edge_id_t x) {
    return x + offset;
  }
};

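// Maps segment i to an output iterator that writes to indices + indptr[i]
// while adding sliced_indptr[i] to every written value.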
template <typename indptr_t, typename indices_t>
struct IteratorFuncAddOffset {
  indptr_t* indptr;
  indptr_t* sliced_indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) {
    return thrust::transform_output_iterator{
        indices + indptr[i], AddOffset<indptr_t>{sliced_indptr[i]}};
  }
};

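// Maps segment i to its end offset, indptr[i] + in_degree[i]; used as the
// end-offsets iterator of the segmented sort over sampled edge ids.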
template <typename indptr_t, typename in_degree_iterator_t>
struct SegmentEndFunc {
  indptr_t* indptr;
  in_degree_iterator_t in_degree;
  __host__ __device__ auto operator()(int64_t i) {
    return indptr[i] + in_degree[i];
  }
};

c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
    torch::Tensor indptr, torch::Tensor indices,
    torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
    bool replace, bool layer, bool return_eids,
    torch::optional<torch::Tensor> type_per_edge,
    torch::optional<torch::Tensor> probs_or_mask) {
  TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!");
  // Assume that indptr, indices, nodes, type_per_edge and probs_or_mask
  // are all resident on the GPU. If not, it is better to first extract them
  // before calling this function.
  auto allocator = cuda::GetAllocator();
  auto num_rows =
      nodes.has_value() ? nodes.value().size(0) : indptr.size(0) - 1;
  auto fanouts_pinned = torch::empty(
      fanouts.size(),
      c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
  auto fanouts_pinned_ptr = fanouts_pinned.data_ptr<int64_t>();
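  // A negative fanout means "pick all neighbors"; map it to the largest
  // int64_t so that MinInDegreeFanout keeps the full in-degree.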
  for (size_t i = 0; i < fanouts.size(); i++) {
    fanouts_pinned_ptr[i] =
        fanouts[i] >= 0 ? fanouts[i] : std::numeric_limits<int64_t>::max();
  }
  // Finally, copy the adjusted fanout values to the device memory.
  auto fanouts_device = allocator.AllocateStorage<int64_t>(fanouts.size());
  CUDA_CALL(hipMemcpyAsync(
      fanouts_device.get(), fanouts_pinned_ptr,
      sizeof(int64_t) * fanouts.size(), hipMemcpyHostToDevice,
      cuda::GetCurrentStream()));
  auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
  auto in_degree = std::get<0>(in_degree_and_sliced_indptr);
  auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr);
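  // The maximum in-degree is needed later to choose the narrowest integer type
  // that can hold row-local edge id offsets.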
  auto max_in_degree = torch::empty(
      1,
      c10::TensorOptions().dtype(in_degree.scalar_type()).pinned_memory(true));
  AT_DISPATCH_INDEX_TYPES(
      indptr.scalar_type(), "SampleNeighborsMaxInDegree", ([&] {
        CUB_CALL(
            DeviceReduce::Max, in_degree.data_ptr<index_t>(),
            max_in_degree.data_ptr<index_t>(), num_rows);
      }));
  // Protect access to max_in_degree with a CUDAEvent
  at::cuda::CUDAEvent max_in_degree_event;
  max_in_degree_event.record();
  torch::optional<int64_t> num_edges;
  torch::Tensor sub_indptr;
  if (!nodes.has_value()) {
    num_edges = indices.size(0);
    sub_indptr = indptr;
  }
  torch::optional<torch::Tensor> sliced_probs_or_mask;
  if (probs_or_mask.has_value()) {
    if (nodes.has_value()) {
      torch::Tensor sliced_probs_or_mask_tensor;
      std::tie(sub_indptr, sliced_probs_or_mask_tensor) = IndexSelectCSCImpl(
          in_degree, sliced_indptr, probs_or_mask.value(), nodes.value(),
          indptr.size(0) - 2, num_edges);
      sliced_probs_or_mask = sliced_probs_or_mask_tensor;
      num_edges = sliced_probs_or_mask_tensor.size(0);
    } else {
      sliced_probs_or_mask = probs_or_mask;
    }
  }
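  // For heterogeneous sampling (one fanout per edge type), split every row
  // into one segment per edge type so that each (row, edge type) pair is
  // sampled with its own fanout.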
  if (fanouts.size() > 1) {
    torch::Tensor sliced_type_per_edge;
    if (nodes.has_value()) {
      std::tie(sub_indptr, sliced_type_per_edge) = IndexSelectCSCImpl(
          in_degree, sliced_indptr, type_per_edge.value(), nodes.value(),
          indptr.size(0) - 2, num_edges);
    } else {
      sliced_type_per_edge = type_per_edge.value();
    }
    std::tie(sub_indptr, in_degree, sliced_indptr) = SliceCSCIndptrHetero(
        sub_indptr, sliced_type_per_edge, sliced_indptr, fanouts.size());
    num_rows = sliced_indptr.size(0);
    num_edges = sliced_type_per_edge.size(0);
  }
  // If sub_indptr was not computed in the two code blocks above:
  if (nodes.has_value() && !probs_or_mask.has_value() && fanouts.size() <= 1) {
    sub_indptr = ExclusiveCumSum(in_degree);
  }
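  // Expand sub_indptr into one row id per edge of the sliced graph (COO row
  // indices); the kernel below consumes these as csr_rows.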
  auto coo_rows = ExpandIndptrImpl(
      sub_indptr, indices.scalar_type(), torch::nullopt, num_edges);
  num_edges = coo_rows.size(0);
  const auto random_seed = RandomEngine::ThreadLocal()->RandInt(
      static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
  auto output_indptr = torch::empty_like(sub_indptr);
  torch::Tensor picked_eids;
  torch::Tensor output_indices;
  torch::optional<torch::Tensor> output_type_per_edge;

  AT_DISPATCH_INDEX_TYPES(
      indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
        using indptr_t = index_t;
        if (probs_or_mask.has_value()) {  // Count nonzero probs into in_degree.
          GRAPHBOLT_DISPATCH_ALL_TYPES(
              probs_or_mask.value().scalar_type(),
              "SampleNeighborsPositiveProbs", ([&] {
                using probs_t = scalar_t;
                auto is_nonzero = thrust::make_transform_iterator(
                    sliced_probs_or_mask.value().data_ptr<probs_t>(),
                    IsPositive{});
                CUB_CALL(
                    DeviceSegmentedReduce::Sum, is_nonzero,
                    in_degree.data_ptr<indptr_t>(), num_rows,
                    sub_indptr.data_ptr<indptr_t>(),
                    sub_indptr.data_ptr<indptr_t>() + 1);
              }));
        }
        thrust::counting_iterator<int64_t> iota(0);
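        // sampled_degree[i] = min(in_degree[i], fanout) is evaluated lazily
        // through a transform iterator and never materialized.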
        auto sampled_degree = thrust::make_transform_iterator(
            iota, MinInDegreeFanout<indptr_t>{
                      in_degree.data_ptr<indptr_t>(), fanouts_device.get(),
                      fanouts.size()});

        // Compute output_indptr.
        CUB_CALL(
            DeviceScan::ExclusiveSum, sampled_degree,
            output_indptr.data_ptr<indptr_t>(), num_rows + 1);

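        // Start an asynchronous device-to-host copy of the total number of
        // sampled edges (the last element of output_indptr).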
        auto num_sampled_edges =
            cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_rows};

        // Find the smallest integer type to store the edge id offsets. We
        // synchronize on the CUDAEvent so that the access to max_in_degree is
        // safe.
        max_in_degree_event.synchronize();
        const int num_bits =
            cuda::NumberOfBits(max_in_degree.data_ptr<indptr_t>()[0]);
        std::array<int, 4> type_bits = {8, 16, 32, 64};
        const auto type_index =
            std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
            type_bits.begin();
        std::array<torch::ScalarType, 5> types = {
            torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
            torch::kLong};
        auto edge_id_dtype = types[type_index];
        AT_DISPATCH_INTEGRAL_TYPES(
            edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
              using edge_id_t = std::make_unsigned_t<scalar_t>;
              TORCH_CHECK(
                  num_bits <= sizeof(edge_id_t) * 8,
                  "Selected edge_id_t must be capable of storing edge_ids.");
              // Using bfloat16 for random numbers works just as reliably as
              // float32 and provides around a 30% speedup.
              using rnd_t = __hip_bfloat16;
              auto randoms =
                  allocator.AllocateStorage<rnd_t>(num_edges.value());
              auto randoms_sorted =
                  allocator.AllocateStorage<rnd_t>(num_edges.value());
              auto edge_id_segments =
                  allocator.AllocateStorage<edge_id_t>(num_edges.value());
              auto sorted_edge_id_segments =
                  allocator.AllocateStorage<edge_id_t>(num_edges.value());
              AT_DISPATCH_INDEX_TYPES(
                  indices.scalar_type(), "SampleNeighborsIndices", ([&] {
                    using indices_t = index_t;
                    auto probs_or_mask_scalar_type = torch::kFloat32;
                    if (probs_or_mask.has_value()) {
                      probs_or_mask_scalar_type =
                          probs_or_mask.value().scalar_type();
                    }
                    GRAPHBOLT_DISPATCH_ALL_TYPES(
                        probs_or_mask_scalar_type, "SampleNeighborsProbs",
                        ([&] {
                          using probs_t = scalar_t;
                          probs_t* sliced_probs_ptr = nullptr;
                          if (sliced_probs_or_mask.has_value()) {
                            sliced_probs_ptr = sliced_probs_or_mask.value()
                                                   .data_ptr<probs_t>();
                          }
                          const indices_t* indices_ptr =
                              layer ? indices.data_ptr<indices_t>() : nullptr;
                          const dim3 block(BLOCK_SIZE);
                          const dim3 grid(
                              (num_edges.value() + BLOCK_SIZE - 1) /
                              BLOCK_SIZE);
                          // Compute row and random number pairs.
                          CUDA_KERNEL_CALL(
                              _ComputeRandoms, grid, block, 0,
                              num_edges.value(),
                              sliced_indptr.data_ptr<indptr_t>(),
                              sub_indptr.data_ptr<indptr_t>(),
                              coo_rows.data_ptr<indices_t>(), sliced_probs_ptr,
                              indices_ptr, random_seed, randoms.get(),
                              edge_id_segments.get());
                        }));
                  }));

              // Sort the random numbers along with the edge ids; after
              // sorting, the first fanout elements of each row give us the
              // sampled edges.
              CUB_CALL(
                  DeviceSegmentedSort::SortPairs, randoms.get(),
                  randoms_sorted.get(), edge_id_segments.get(),
                  sorted_edge_id_segments.get(), num_edges.value(), num_rows,
                  sub_indptr.data_ptr<indptr_t>(),
                  sub_indptr.data_ptr<indptr_t>() + 1);

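              // The static_cast below waits for the asynchronous copy of
              // num_sampled_edges to finish before reading it on the host.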
              picked_eids = torch::empty(
                  static_cast<indptr_t>(num_sampled_edges),
                  sub_indptr.options());

              // The sampled edges need to be sorted only when
              // fanouts.size() == 1, since the multiple-fanout case already
              // comes out sorted.
              if (type_per_edge && fanouts.size() == 1) {
                // Ensuring sort result still ends up in sorted_edge_id_segments
                std::swap(edge_id_segments, sorted_edge_id_segments);
                auto sampled_segment_end_it = thrust::make_transform_iterator(
                    iota, SegmentEndFunc<indptr_t, decltype(sampled_degree)>{
                              sub_indptr.data_ptr<indptr_t>(), sampled_degree});
                CUB_CALL(
                    DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
                    sorted_edge_id_segments.get(), picked_eids.size(0),
                    num_rows, sub_indptr.data_ptr<indptr_t>(),
                    sampled_segment_end_it);
              }

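              // input_buffer_it yields, per row, a pointer to its sorted
              // row-local edge ids; output_buffer_it writes them into
              // picked_eids while adding sliced_indptr[i], turning them into
              // global edge ids.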
              auto input_buffer_it = thrust::make_transform_iterator(
                  iota, IteratorFunc<indptr_t, edge_id_t>{
                            sub_indptr.data_ptr<indptr_t>(),
                            sorted_edge_id_segments.get()});
              auto output_buffer_it = thrust::make_transform_iterator(
                  iota, IteratorFuncAddOffset<indptr_t, indptr_t>{
                            output_indptr.data_ptr<indptr_t>(),
                            sliced_indptr.data_ptr<indptr_t>(),
                            picked_eids.data_ptr<indptr_t>()});
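              // The batched copy takes a 32-bit segment count, hence the
              // chunking by max_copy_at_once below.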
              constexpr int64_t max_copy_at_once =
                  std::numeric_limits<int32_t>::max();

              // Copy the sampled edge ids into picked_eids tensor.
              for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
                CUB_CALL(
                    DeviceCopy::Batched, input_buffer_it + i,
                    output_buffer_it + i, sampled_degree + i,
                    ::min(num_rows - i, max_copy_at_once));
              }
            }));

        output_indices = torch::empty(
            picked_eids.size(0),
            picked_eids.options().dtype(indices.scalar_type()));

        // Compute: output_indices = indices.gather(0, picked_eids);
        AT_DISPATCH_INDEX_TYPES(
            indices.scalar_type(), "SampleNeighborsOutputIndices", ([&] {
              using indices_t = index_t;
              THRUST_CALL(
                  gather, picked_eids.data_ptr<indptr_t>(),
                  picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
                  indices.data_ptr<indices_t>(),
                  output_indices.data_ptr<indices_t>());
            }));

        if (type_per_edge) {
          // output_type_per_edge = type_per_edge.gather(0, picked_eids);
          // The commented out torch equivalent above does not work when
          // type_per_edge is on pinned memory. That is why we reimplement it,
          // similar to the indices gather operation above.
          auto types = type_per_edge.value();
          output_type_per_edge = torch::empty(
              picked_eids.size(0),
              picked_eids.options().dtype(types.scalar_type()));
          AT_DISPATCH_INTEGRAL_TYPES(
              types.scalar_type(), "SampleNeighborsOutputTypePerEdge", ([&] {
                THRUST_CALL(
                    gather, picked_eids.data_ptr<indptr_t>(),
                    picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
                    types.data_ptr<scalar_t>(),
                    output_type_per_edge.value().data_ptr<scalar_t>());
              }));
        }
      }));

  // Convert output_indptr back to the homogeneous form by discarding the
  // intermediate per-edge-type offsets.
  output_indptr =
      output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());
  torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
  if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
  if (!nodes.has_value()) {
    nodes = torch::arange(indptr.size(0) - 1, indices.options());
  }

  return c10::make_intrusive<sampling::FusedSampledSubgraph>(
      output_indptr, output_indices, nodes.value(), torch::nullopt,
      subgraph_reverse_edge_ids, output_type_per_edge);
}
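
// A minimal usage sketch (illustrative only; the tensors below are assumed to
// already be on the GPU, and the variable names are not part of this file):
//
//   auto subgraph = graphbolt::ops::SampleNeighbors(
//       indptr, indices, /*nodes=*/seeds, /*fanouts=*/{10},
//       /*replace=*/false, /*layer=*/false, /*return_eids=*/true,
//       /*type_per_edge=*/torch::nullopt, /*probs_or_mask=*/torch::nullopt);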

}  //  namespace ops
}  //  namespace graphbolt