/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/neighbor_sampler.cu
 * @brief Neighbor sampling operator implementation on CUDA.
 */
#include <c10/core/ScalarType.h>
#include <curand_kernel.h>
#include <graphbolt/cuda_ops.h>
#include <graphbolt/cuda_sampling_ops.h>
#include <thrust/gather.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>

#include <algorithm>
#include <array>
#include <cub/cub.cuh>
#include <limits>
#include <numeric>
#include <type_traits>

#include "../random.h"
#include "./common.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

constexpr int BLOCK_SIZE = 128;

/**
 * @brief Fills the random_arr with random numbers and the edge_ids array with
 * original edge ids. When random_arr is sorted along with edge_ids, the first
 * fanout elements of each row give us the sampled edges.
 */
template <
    typename float_t, typename indptr_t, typename indices_t, typename weights_t,
    typename edge_id_t>
__global__ void _ComputeRandoms(
    const int64_t num_edges, const indptr_t* const sliced_indptr,
    const indptr_t* const sub_indptr, const indices_t* const csr_rows,
    const weights_t* const sliced_weights, const indices_t* const indices,
    const uint64_t random_seed, float_t* random_arr, edge_id_t* edge_ids) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;
  curandStatePhilox4_32_10_t rng;
  const auto labor = indices != nullptr;
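  // In LABOR sampling (layer = true), `indices` is passed so the RNG can be
  // re-seeded inside the loop with the neighbor vertex id; a neighbor shared
  // by several seed vertices then draws the same random number. Otherwise,
  // each thread initializes its own Philox subsequence once, up front.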

  if (!labor) {
    curand_init(random_seed, i, 0, &rng);
  }

  while (i < num_edges) {
    const auto row_position = csr_rows[i];
    const auto row_offset = i - sub_indptr[row_position];
    const auto in_idx = sliced_indptr[row_position] + row_offset;

    if (labor) {
      constexpr uint64_t kCurandSeed = 999961;
      curand_init(kCurandSeed, random_seed, indices[in_idx], &rng);
    }

    const auto rnd = curand_uniform(&rng);
    const auto prob =
        sliced_weights ? sliced_weights[i] : static_cast<weights_t>(1);
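    // Exponential-race trick: -log(u) / w is an Exponential(w) variate, so
    // after an ascending sort the first `fanout` entries of each row form a
    // weighted sample without replacement; zero-weight edges get +inf and are
    // never picked.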
    const auto exp_rnd = -__logf(rnd);
    const float_t adjusted_rnd = prob > 0
                                     ? static_cast<float_t>(exp_rnd / prob)
                                     : std::numeric_limits<float_t>::infinity();
    random_arr[i] = adjusted_rnd;
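    // Only the within-row offset is stored as the edge id; the absolute edge
    // id is reconstructed later by adding sliced_indptr[row] (see AddOffset),
    // so edge ids fit into the narrowest possible integer type.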
    edge_ids[i] = row_offset;

    i += stride;
  }
}

struct IsPositive {
  template <typename probs_t>
  __host__ __device__ auto operator()(probs_t x) {
    return x > 0;
  }
};

template <typename indptr_t>
struct MinInDegreeFanout {
  const indptr_t* in_degree;
  const int64_t* fanouts;
  size_t num_fanouts;
  __host__ __device__ auto operator()(int64_t i) {
    return static_cast<indptr_t>(
        min(static_cast<int64_t>(in_degree[i]), fanouts[i % num_fanouts]));
  }
};

template <typename indptr_t, typename indices_t>
struct IteratorFunc {
  indptr_t* indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) { return indices + indptr[i]; }
};

template <typename indptr_t>
struct AddOffset {
  indptr_t offset;
  template <typename edge_id_t>
  __host__ __device__ indptr_t operator()(edge_id_t x) {
    return x + offset;
  }
};

template <typename indptr_t, typename indices_t>
struct IteratorFuncAddOffset {
  indptr_t* indptr;
  indptr_t* sliced_indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) {
    return thrust::transform_output_iterator{
        indices + indptr[i], AddOffset<indptr_t>{sliced_indptr[i]}};
  }
};

template <typename indptr_t, typename in_degree_iterator_t>
struct SegmentEndFunc {
  indptr_t* indptr;
  in_degree_iterator_t in_degree;
  __host__ __device__ auto operator()(int64_t i) {
    return indptr[i] + in_degree[i];
  }
};

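// Illustrative usage sketch (tensor names are placeholders, not part of this
// file): uniformly sample up to 10 in-neighbors per seed vertex, without
// replacement and without LABOR:
//   auto subgraph = SampleNeighbors(
//       indptr, indices, seeds, /*fanouts=*/{10}, /*replace=*/false,
//       /*layer=*/false, /*return_eids=*/true,
//       /*type_per_edge=*/torch::nullopt, /*probs_or_mask=*/torch::nullopt);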
c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
    torch::Tensor indptr, torch::Tensor indices,
    torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
    bool replace, bool layer, bool return_eids,
    torch::optional<torch::Tensor> type_per_edge,
    torch::optional<torch::Tensor> probs_or_mask) {
  TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!");
  // Assume that indptr, indices, nodes, type_per_edge and probs_or_mask
  // are all resident on the GPU. If not, it is better to first extract them
  // before calling this function.
  auto allocator = cuda::GetAllocator();
  auto num_rows =
      nodes.has_value() ? nodes.value().size(0) : indptr.size(0) - 1;
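  // A negative fanout means "take all neighbors"; it is mapped to the maximum
  // int64_t so that min(in_degree, fanout) below reduces to in_degree. The
  // values are staged in pinned host memory for an asynchronous H2D copy.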
  auto fanouts_pinned = torch::empty(
      fanouts.size(),
      c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
  auto fanouts_pinned_ptr = fanouts_pinned.data_ptr<int64_t>();
  for (size_t i = 0; i < fanouts.size(); i++) {
    fanouts_pinned_ptr[i] =
        fanouts[i] >= 0 ? fanouts[i] : std::numeric_limits<int64_t>::max();
  }
  // Finally, copy the adjusted fanout values to the device memory.
  auto fanouts_device = allocator.AllocateStorage<int64_t>(fanouts.size());
  CUDA_CALL(cudaMemcpyAsync(
      fanouts_device.get(), fanouts_pinned_ptr,
      sizeof(int64_t) * fanouts.size(), cudaMemcpyHostToDevice,
      cuda::GetCurrentStream()));
  auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
  auto in_degree = std::get<0>(in_degree_and_sliced_indptr);
  auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr);
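  // The maximum in-degree determines the narrowest integer type that can hold
  // the row-local edge id offsets produced by _ComputeRandoms below.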
  auto max_in_degree = torch::empty(
      1,
      c10::TensorOptions().dtype(in_degree.scalar_type()).pinned_memory(true));
  AT_DISPATCH_INDEX_TYPES(
      indptr.scalar_type(), "SampleNeighborsMaxInDegree", ([&] {
        CUB_CALL(
            DeviceReduce::Max, in_degree.data_ptr<index_t>(),
            max_in_degree.data_ptr<index_t>(), num_rows);
      }));
  // Protect access to max_in_degree with a CUDAEvent
  at::cuda::CUDAEvent max_in_degree_event;
  max_in_degree_event.record();
  torch::optional<int64_t> num_edges;
  torch::Tensor sub_indptr;
  if (!nodes.has_value()) {
    num_edges = indices.size(0);
    sub_indptr = indptr;
  }
  torch::optional<torch::Tensor> sliced_probs_or_mask;
  if (probs_or_mask.has_value()) {
    if (nodes.has_value()) {
      torch::Tensor sliced_probs_or_mask_tensor;
      std::tie(sub_indptr, sliced_probs_or_mask_tensor) = IndexSelectCSCImpl(
          in_degree, sliced_indptr, probs_or_mask.value(), nodes.value(),
          indptr.size(0) - 2, num_edges);
      sliced_probs_or_mask = sliced_probs_or_mask_tensor;
      num_edges = sliced_probs_or_mask_tensor.size(0);
    } else {
      sliced_probs_or_mask = probs_or_mask;
    }
  }
  if (fanouts.size() > 1) {
    torch::Tensor sliced_type_per_edge;
    if (nodes.has_value()) {
      std::tie(sub_indptr, sliced_type_per_edge) = IndexSelectCSCImpl(
          in_degree, sliced_indptr, type_per_edge.value(), nodes.value(),
          indptr.size(0) - 2, num_edges);
    } else {
      sliced_type_per_edge = type_per_edge.value();
    }
    std::tie(sub_indptr, in_degree, sliced_indptr) = SliceCSCIndptrHetero(
        sub_indptr, sliced_type_per_edge, sliced_indptr, fanouts.size());
    num_rows = sliced_indptr.size(0);
    num_edges = sliced_type_per_edge.size(0);
  }
  // If sub_indptr was not computed in the two code blocks above:
  if (nodes.has_value() && !probs_or_mask.has_value() && fanouts.size() <= 1) {
    sub_indptr = ExclusiveCumSum(in_degree);
  }
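  // coo_rows[i] is the local row (a (node, edge type) pair in the hetero case)
  // that edge slot i belongs to; together with sub_indptr it lets the kernel
  // recover each edge's offset within its row.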
  auto coo_rows = ExpandIndptrImpl(
      sub_indptr, indices.scalar_type(), torch::nullopt, num_edges);
  num_edges = coo_rows.size(0);
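  // A single seed per call: in LABOR mode, the per-edge random number is a
  // function of (random_seed, neighbor id), so a neighbor shared by multiple
  // seed vertices draws the same value within this call.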
  const auto random_seed = RandomEngine::ThreadLocal()->RandInt(
      static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
  auto output_indptr = torch::empty_like(sub_indptr);
  torch::Tensor picked_eids;
  torch::Tensor output_indices;
  torch::optional<torch::Tensor> output_type_per_edge;

  AT_DISPATCH_INDEX_TYPES(
      indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
        using indptr_t = index_t;
        if (probs_or_mask.has_value()) {  // Count nonzero probs into in_degree.
          GRAPHBOLT_DISPATCH_ALL_TYPES(
              probs_or_mask.value().scalar_type(),
              "SampleNeighborsPositiveProbs", ([&] {
                using probs_t = scalar_t;
                auto is_nonzero = thrust::make_transform_iterator(
                    sliced_probs_or_mask.value().data_ptr<probs_t>(),
                    IsPositive{});
                CUB_CALL(
                    DeviceSegmentedReduce::Sum, is_nonzero,
                    in_degree.data_ptr<indptr_t>(), num_rows,
                    sub_indptr.data_ptr<indptr_t>(),
                    sub_indptr.data_ptr<indptr_t>() + 1);
              }));
        }
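        // sampled_degree[i] = min(number of sampleable edges in row i, fanout
        // of row i's edge type); it drives both the output_indptr prefix sum
        // and the batched copy at the end.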
        thrust::counting_iterator<int64_t> iota(0);
        auto sampled_degree = thrust::make_transform_iterator(
            iota, MinInDegreeFanout<indptr_t>{
                      in_degree.data_ptr<indptr_t>(), fanouts_device.get(),
                      fanouts.size()});

        // Compute output_indptr.
        CUB_CALL(
            DeviceScan::ExclusiveSum, sampled_degree,
            output_indptr.data_ptr<indptr_t>(), num_rows + 1);

        auto num_sampled_edges =
            cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_rows};
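        // CopyScalar copies the device-side edge total (the last element of
        // output_indptr) so it can be read on the host when allocating
        // picked_eids below.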

        // Find the smallest integer type to store the edge id offsets. We
        // synchronize on the CUDAEvent so that the host read below is safe.
        max_in_degree_event.synchronize();
        const int num_bits =
            cuda::NumberOfBits(max_in_degree.data_ptr<indptr_t>()[0]);
        std::array<int, 4> type_bits = {8, 16, 32, 64};
        const auto type_index =
            std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
            type_bits.begin();
        std::array<torch::ScalarType, 5> types = {
            torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
            torch::kLong};
        auto edge_id_dtype = types[type_index];
        AT_DISPATCH_INTEGRAL_TYPES(
            edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
              using edge_id_t = std::make_unsigned_t<scalar_t>;
              TORCH_CHECK(
                  num_bits <= sizeof(edge_id_t) * 8,
                  "Selected edge_id_t must be capable of storing edge_ids.");
              // Using bfloat16 for random numbers works just as reliably as
              // float32 and provides around a 30% speedup.
              using rnd_t = nv_bfloat16;
              auto randoms =
                  allocator.AllocateStorage<rnd_t>(num_edges.value());
              auto randoms_sorted =
                  allocator.AllocateStorage<rnd_t>(num_edges.value());
              auto edge_id_segments =
                  allocator.AllocateStorage<edge_id_t>(num_edges.value());
              auto sorted_edge_id_segments =
                  allocator.AllocateStorage<edge_id_t>(num_edges.value());
              AT_DISPATCH_INDEX_TYPES(
                  indices.scalar_type(), "SampleNeighborsIndices", ([&] {
                    using indices_t = index_t;
                    auto probs_or_mask_scalar_type = torch::kFloat32;
                    if (probs_or_mask.has_value()) {
                      probs_or_mask_scalar_type =
                          probs_or_mask.value().scalar_type();
                    }
                    GRAPHBOLT_DISPATCH_ALL_TYPES(
                        probs_or_mask_scalar_type, "SampleNeighborsProbs",
                        ([&] {
                          using probs_t = scalar_t;
                          probs_t* sliced_probs_ptr = nullptr;
                          if (sliced_probs_or_mask.has_value()) {
                            sliced_probs_ptr = sliced_probs_or_mask.value()
                                                   .data_ptr<probs_t>();
                          }
                          const indices_t* indices_ptr =
                              layer ? indices.data_ptr<indices_t>() : nullptr;
                          const dim3 block(BLOCK_SIZE);
                          const dim3 grid(
                              (num_edges.value() + BLOCK_SIZE - 1) /
                              BLOCK_SIZE);
                          // Compute row and random number pairs.
                          CUDA_KERNEL_CALL(
                              _ComputeRandoms, grid, block, 0,
                              num_edges.value(),
                              sliced_indptr.data_ptr<indptr_t>(),
                              sub_indptr.data_ptr<indptr_t>(),
                              coo_rows.data_ptr<indices_t>(), sliced_probs_ptr,
                              indices_ptr, random_seed, randoms.get(),
                              edge_id_segments.get());
                        }));
                  }));

              // Sort the random numbers along with the edge ids; after
              // sorting, the first fanout elements of each row give us the
              // sampled edges.
              CUB_CALL(
                  DeviceSegmentedSort::SortPairs, randoms.get(),
                  randoms_sorted.get(), edge_id_segments.get(),
                  sorted_edge_id_segments.get(), num_edges.value(), num_rows,
                  sub_indptr.data_ptr<indptr_t>(),
                  sub_indptr.data_ptr<indptr_t>() + 1);

              picked_eids = torch::empty(
                  static_cast<indptr_t>(num_sampled_edges),
                  sub_indptr.options());

              // The sampled edge ids need an extra sort only when
              // fanouts.size() == 1, since the multiple-fanout case already
              // comes out sorted.
              if (type_per_edge && fanouts.size() == 1) {
                // Ensure the sort result still ends up in sorted_edge_id_segments.
                std::swap(edge_id_segments, sorted_edge_id_segments);
                auto sampled_segment_end_it = thrust::make_transform_iterator(
                    iota, SegmentEndFunc<indptr_t, decltype(sampled_degree)>{
                              sub_indptr.data_ptr<indptr_t>(), sampled_degree});
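                // Only the first sampled_degree[r] ids of each segment (the
                // picked ones) are sorted, so the batched copy below emits
                // them in ascending order per row.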
                CUB_CALL(
                    DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
                    sorted_edge_id_segments.get(), picked_eids.size(0),
                    num_rows, sub_indptr.data_ptr<indptr_t>(),
                    sampled_segment_end_it);
              }

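              // input_buffer_it[r] points at row r's sorted edge-id segment;
              // output_buffer_it[r] writes into picked_eids starting at
              // output_indptr[r], adding sliced_indptr[r] to turn row-local
              // offsets back into global edge ids.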
              auto input_buffer_it = thrust::make_transform_iterator(
                  iota, IteratorFunc<indptr_t, edge_id_t>{
                            sub_indptr.data_ptr<indptr_t>(),
                            sorted_edge_id_segments.get()});
              auto output_buffer_it = thrust::make_transform_iterator(
                  iota, IteratorFuncAddOffset<indptr_t, indptr_t>{
                            output_indptr.data_ptr<indptr_t>(),
                            sliced_indptr.data_ptr<indptr_t>(),
                            picked_eids.data_ptr<indptr_t>()});
              constexpr int64_t max_copy_at_once =
                  std::numeric_limits<int32_t>::max();

              // Copy the sampled edge ids into picked_eids tensor.
              for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
                CUB_CALL(
                    DeviceCopy::Batched, input_buffer_it + i,
                    output_buffer_it + i, sampled_degree + i,
                    std::min(num_rows - i, max_copy_at_once));
              }
            }));

        output_indices = torch::empty(
            picked_eids.size(0),
            picked_eids.options().dtype(indices.scalar_type()));

        // Compute: output_indices = indices.gather(0, picked_eids);
        AT_DISPATCH_INDEX_TYPES(
            indices.scalar_type(), "SampleNeighborsOutputIndices", ([&] {
              using indices_t = index_t;
              THRUST_CALL(
                  gather, picked_eids.data_ptr<indptr_t>(),
                  picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
                  indices.data_ptr<indices_t>(),
                  output_indices.data_ptr<indices_t>());
            }));

        if (type_per_edge) {
          // output_type_per_edge = type_per_edge.gather(0, picked_eids);
          // The commented-out torch equivalent above does not work when
          // type_per_edge is on pinned memory. That is why we reimplement it,
          // similar to the indices gather operation above.
          auto types = type_per_edge.value();
          output_type_per_edge = torch::empty(
              picked_eids.size(0),
              picked_eids.options().dtype(types.scalar_type()));
          AT_DISPATCH_INTEGRAL_TYPES(
              types.scalar_type(), "SampleNeighborsOutputTypePerEdge", ([&] {
                THRUST_CALL(
                    gather, picked_eids.data_ptr<indptr_t>(),
                    picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
                    types.data_ptr<scalar_t>(),
                    output_type_per_edge.value().data_ptr<scalar_t>());
              }));
        }
      }));

  // Convert output_indptr back to homo by discarding intermediate offsets.
  output_indptr =
      output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());
  torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
  if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
  if (!nodes.has_value()) {
    nodes = torch::arange(indptr.size(0) - 1, indices.options());
  }

  return c10::make_intrusive<sampling::FusedSampledSubgraph>(
      output_indptr, output_indices, nodes.value(), torch::nullopt,
      subgraph_reverse_edge_ids, output_type_per_edge);
}

}  //  namespace ops
}  //  namespace graphbolt