/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/neighbor_sampler.cu
 * @brief Neighbor sampling operator implementation on CUDA.
 */
#include <c10/core/ScalarType.h>
#include <curand_kernel.h>
#include <graphbolt/continuous_seed.h>
#include <graphbolt/cuda_ops.h>
#include <graphbolt/cuda_sampling_ops.h>
#include <thrust/gather.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>

#include <algorithm>
#include <array>
#include <cub/cub.cuh>
#include <limits>
#include <numeric>
#include <type_traits>

#include "../random.h"
#include "./common.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

constexpr int BLOCK_SIZE = 128;

/**
 * @brief Fills the random_arr with random keys and the edge_ids array with
 * the edge id offsets inside each row. When random_arr is sorted along with
 * edge_ids, the first fanout elements of each row give us the sampled edges.
 */
template <
    typename float_t, typename indptr_t, typename indices_t, typename weights_t,
    typename edge_id_t>
__global__ void _ComputeRandoms(
    const int64_t num_edges, const indptr_t* const sliced_indptr,
    const indptr_t* const sub_indptr, const indices_t* const csr_rows,
    const weights_t* const sliced_weights, const indices_t* const indices,
    const continuous_seed random_seed, float_t* random_arr,
    edge_id_t* edge_ids) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;
  const auto labor = indices != nullptr;

  while (i < num_edges) {
    const auto row_position = csr_rows[i];
    const auto row_offset = i - sub_indptr[row_position];
    const auto in_idx = sliced_indptr[row_position] + row_offset;
    const auto rnd = random_seed.uniform(labor ? indices[in_idx] : i);
    const auto prob =
        sliced_weights ? sliced_weights[i] : static_cast<weights_t>(1);
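    // Key each edge with -log(u) / prob, an Exponential(prob) draw. Sorting
    // these keys in ascending order per row and keeping the first fanout
    // entries yields a weighted sample without replacement; edges with a
    // nonpositive probability get an infinite key and are never picked.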
    const auto exp_rnd = -__logf(rnd);
    const float_t adjusted_rnd = prob > 0
                                     ? static_cast<float_t>(exp_rnd / prob)
                                     : std::numeric_limits<float_t>::infinity();
    random_arr[i] = adjusted_rnd;
    edge_ids[i] = row_offset;

    i += stride;
  }
}

struct IsPositive {
  template <typename probs_t>
  __host__ __device__ auto operator()(probs_t x) {
    return x > 0;
  }
};

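/**
 * @brief Computes the number of edges to sample from each row as the minimum
 * of the row's in-degree and its fanout, fanouts[i % num_fanouts].
 */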
template <typename indptr_t>
struct MinInDegreeFanout {
  const indptr_t* in_degree;
  const int64_t* fanouts;
  size_t num_fanouts;
  __host__ __device__ auto operator()(int64_t i) {
    return static_cast<indptr_t>(
        min(static_cast<int64_t>(in_degree[i]), fanouts[i % num_fanouts]));
  }
};

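/**
 * @brief For row i, returns a pointer to the beginning of its segment,
 * indices + indptr[i]; used to address segmented buffers row by row.
 */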
template <typename indptr_t, typename indices_t>
struct IteratorFunc {
  indptr_t* indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) { return indices + indptr[i]; }
};

template <typename indptr_t>
struct AddOffset {
  indptr_t offset;
  template <typename edge_id_t>
  __host__ __device__ indptr_t operator()(edge_id_t x) {
    return x + offset;
  }
};

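/**
 * @brief For row i, returns an output iterator that writes row-local edge id
 * offsets shifted by sliced_indptr[i], converting them to global edge ids.
 */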
template <typename indptr_t, typename indices_t>
struct IteratorFuncAddOffset {
  indptr_t* indptr;
  indptr_t* sliced_indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) {
    return thrust::transform_output_iterator{
        indices + indptr[i], AddOffset<indptr_t>{sliced_indptr[i]}};
  }
};

template <typename indptr_t, typename in_degree_iterator_t>
struct SegmentEndFunc {
  indptr_t* indptr;
  in_degree_iterator_t in_degree;
  __host__ __device__ auto operator()(int64_t i) {
    return indptr[i] + in_degree[i];
  }
};

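/**
 * @brief Samples up to fanout neighbors for each seed node (one fanout per
 * edge type when multiple fanouts are given) from a graph in CSC format,
 * without replacement, optionally weighted by probs_or_mask.
 */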
c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
    torch::Tensor indptr, torch::Tensor indices,
    torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
    bool replace, bool layer, bool return_eids,
    torch::optional<torch::Tensor> type_per_edge,
    torch::optional<torch::Tensor> probs_or_mask) {
  TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!");
  // Assume that indptr, indices, nodes, type_per_edge and probs_or_mask
  // are all resident on the GPU. If not, it is better to first extract them
  // before calling this function.
  auto allocator = cuda::GetAllocator();
  auto num_rows =
      nodes.has_value() ? nodes.value().size(0) : indptr.size(0) - 1;
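
  // Stage the fanout values in pinned memory so they can be copied to the
  // device asynchronously; a negative fanout means "sample all neighbors" and
  // is mapped to the maximum representable value.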
  auto fanouts_pinned = torch::empty(
      fanouts.size(),
      c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
  auto fanouts_pinned_ptr = fanouts_pinned.data_ptr<int64_t>();
  for (size_t i = 0; i < fanouts.size(); i++) {
    fanouts_pinned_ptr[i] =
        fanouts[i] >= 0 ? fanouts[i] : std::numeric_limits<int64_t>::max();
  }
  // Finally, copy the adjusted fanout values to the device memory.
  auto fanouts_device = allocator.AllocateStorage<int64_t>(fanouts.size());
  CUDA_CALL(cudaMemcpyAsync(
      fanouts_device.get(), fanouts_pinned_ptr,
      sizeof(int64_t) * fanouts.size(), cudaMemcpyHostToDevice,
      cuda::GetCurrentStream()));
  auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
  auto in_degree = std::get<0>(in_degree_and_sliced_indptr);
  auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr);
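
  // Compute the maximum in-degree on the GPU into pinned host memory; it is
  // read later to pick the narrowest integer type for the edge id offsets.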
  auto max_in_degree = torch::empty(
      1,
      c10::TensorOptions().dtype(in_degree.scalar_type()).pinned_memory(true));
  AT_DISPATCH_INDEX_TYPES(
      indptr.scalar_type(), "SampleNeighborsMaxInDegree", ([&] {
        CUB_CALL(
            DeviceReduce::Max, in_degree.data_ptr<index_t>(),
            max_in_degree.data_ptr<index_t>(), num_rows);
      }));
  // Protect access to max_in_degree with a CUDAEvent
  at::cuda::CUDAEvent max_in_degree_event;
  max_in_degree_event.record();
  torch::optional<int64_t> num_edges;
  torch::Tensor sub_indptr;
  if (!nodes.has_value()) {
    num_edges = indices.size(0);
    sub_indptr = indptr;
  }
  torch::optional<torch::Tensor> sliced_probs_or_mask;
  if (probs_or_mask.has_value()) {
    if (nodes.has_value()) {
      torch::Tensor sliced_probs_or_mask_tensor;
      std::tie(sub_indptr, sliced_probs_or_mask_tensor) = IndexSelectCSCImpl(
          in_degree, sliced_indptr, probs_or_mask.value(), nodes.value(),
          indptr.size(0) - 2, num_edges);
      sliced_probs_or_mask = sliced_probs_or_mask_tensor;
      num_edges = sliced_probs_or_mask_tensor.size(0);
    } else {
      sliced_probs_or_mask = probs_or_mask;
    }
  }
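
  // Multiple fanouts imply heterogeneous sampling: each row is split into one
  // sub-row per edge type so that a separate fanout can be applied per type.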
  if (fanouts.size() > 1) {
    torch::Tensor sliced_type_per_edge;
    if (nodes.has_value()) {
      std::tie(sub_indptr, sliced_type_per_edge) = IndexSelectCSCImpl(
          in_degree, sliced_indptr, type_per_edge.value(), nodes.value(),
          indptr.size(0) - 2, num_edges);
    } else {
      sliced_type_per_edge = type_per_edge.value();
    }
    std::tie(sub_indptr, in_degree, sliced_indptr) = SliceCSCIndptrHetero(
        sub_indptr, sliced_type_per_edge, sliced_indptr, fanouts.size());
    num_rows = sliced_indptr.size(0);
    num_edges = sliced_type_per_edge.size(0);
  }
  // If sub_indptr was not computed in the two code blocks above:
  if (nodes.has_value() && !probs_or_mask.has_value() && fanouts.size() <= 1) {
    sub_indptr = ExclusiveCumSum(in_degree);
  }
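
  // Expand sub_indptr into per-edge row indices (COO format) so that each
  // edge can look up its row inside the sampling kernel.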
  auto coo_rows = ExpandIndptrImpl(
      sub_indptr, indices.scalar_type(), torch::nullopt, num_edges);
  num_edges = coo_rows.size(0);
  const continuous_seed random_seed(RandomEngine::ThreadLocal()->RandInt(
      static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()));
  auto output_indptr = torch::empty_like(sub_indptr);
  torch::Tensor picked_eids;
  torch::Tensor output_indices;
  torch::optional<torch::Tensor> output_type_per_edge;

  AT_DISPATCH_INDEX_TYPES(
      indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
        using indptr_t = index_t;
        if (probs_or_mask.has_value()) {  // Count nonzero probs into in_degree.
          GRAPHBOLT_DISPATCH_ALL_TYPES(
              probs_or_mask.value().scalar_type(),
              "SampleNeighborsPositiveProbs", ([&] {
                using probs_t = scalar_t;
                auto is_nonzero = thrust::make_transform_iterator(
                    sliced_probs_or_mask.value().data_ptr<probs_t>(),
                    IsPositive{});
                CUB_CALL(
                    DeviceSegmentedReduce::Sum, is_nonzero,
                    in_degree.data_ptr<indptr_t>(), num_rows,
                    sub_indptr.data_ptr<indptr_t>(),
                    sub_indptr.data_ptr<indptr_t>() + 1);
              }));
        }

        thrust::counting_iterator<int64_t> iota(0);
        auto sampled_degree = thrust::make_transform_iterator(
            iota, MinInDegreeFanout<indptr_t>{
                      in_degree.data_ptr<indptr_t>(), fanouts_device.get(),
                      fanouts.size()});

        // Compute output_indptr.
        CUB_CALL(
            DeviceScan::ExclusiveSum, sampled_degree,
            output_indptr.data_ptr<indptr_t>(), num_rows + 1);

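        // Start copying the total number of sampled edges, the last entry of
        // output_indptr, back to the host; the value is read when sizing
        // picked_eids below.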
        auto num_sampled_edges =
            cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_rows};

        // Find the smallest integer type to store the edge id offsets. We
        // synchronize on the CUDAEvent so that the access is safe.
        max_in_degree_event.synchronize();
        const int num_bits =
            cuda::NumberOfBits(max_in_degree.data_ptr<indptr_t>()[0]);
        std::array<int, 4> type_bits = {8, 16, 32, 64};
        const auto type_index =
            std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
            type_bits.begin();
        std::array<torch::ScalarType, 5> types = {
            torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
            torch::kLong};
        auto edge_id_dtype = types[type_index];
        AT_DISPATCH_INTEGRAL_TYPES(
            edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
              using edge_id_t = std::make_unsigned_t<scalar_t>;
              TORCH_CHECK(
                  num_bits <= sizeof(edge_id_t) * 8,
                  "Selected edge_id_t must be capable of storing edge_ids.");
              // Using bfloat16 for random numbers works just as reliably as
              // float32 and provides around a 30% speedup.
              using rnd_t = nv_bfloat16;
              auto randoms =
                  allocator.AllocateStorage<rnd_t>(num_edges.value());
              auto randoms_sorted =
                  allocator.AllocateStorage<rnd_t>(num_edges.value());
              auto edge_id_segments =
                  allocator.AllocateStorage<edge_id_t>(num_edges.value());
              auto sorted_edge_id_segments =
                  allocator.AllocateStorage<edge_id_t>(num_edges.value());
              AT_DISPATCH_INDEX_TYPES(
                  indices.scalar_type(), "SampleNeighborsIndices", ([&] {
                    using indices_t = index_t;
                    auto probs_or_mask_scalar_type = torch::kFloat32;
                    if (probs_or_mask.has_value()) {
                      probs_or_mask_scalar_type =
                          probs_or_mask.value().scalar_type();
                    }
                    GRAPHBOLT_DISPATCH_ALL_TYPES(
                        probs_or_mask_scalar_type, "SampleNeighborsProbs",
                        ([&] {
                          using probs_t = scalar_t;
                          probs_t* sliced_probs_ptr = nullptr;
                          if (sliced_probs_or_mask.has_value()) {
                            sliced_probs_ptr = sliced_probs_or_mask.value()
                                                   .data_ptr<probs_t>();
                          }
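                          // For layer (LABOR) sampling, the random key is
                          // derived from the neighbor id so that the same
                          // neighbor draws the same random value across
                          // different seed nodes.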
                          const indices_t* indices_ptr =
                              layer ? indices.data_ptr<indices_t>() : nullptr;
                          const dim3 block(BLOCK_SIZE);
                          const dim3 grid(
                              (num_edges.value() + BLOCK_SIZE - 1) /
                              BLOCK_SIZE);
                          // Compute row and random number pairs.
                          CUDA_KERNEL_CALL(
                              _ComputeRandoms, grid, block, 0,
                              num_edges.value(),
                              sliced_indptr.data_ptr<indptr_t>(),
                              sub_indptr.data_ptr<indptr_t>(),
                              coo_rows.data_ptr<indices_t>(), sliced_probs_ptr,
                              indices_ptr, random_seed, randoms.get(),
                              edge_id_segments.get());
                        }));
                  }));

              // Sort the random numbers along with the edge ids; after
              // sorting, the first fanout elements of each row give us the
              // sampled edges.
              CUB_CALL(
                  DeviceSegmentedSort::SortPairs, randoms.get(),
                  randoms_sorted.get(), edge_id_segments.get(),
                  sorted_edge_id_segments.get(), num_edges.value(), num_rows,
                  sub_indptr.data_ptr<indptr_t>(),
                  sub_indptr.data_ptr<indptr_t>() + 1);

              picked_eids = torch::empty(
                  static_cast<indptr_t>(num_sampled_edges),
                  sub_indptr.options());

              // The sampled edges need to be sorted only when
              // fanouts.size() == 1, since the multiple-fanout sampling case
              // automatically produces a sorted result.
              if (type_per_edge && fanouts.size() == 1) {
                // Ensuring sort result still ends up in sorted_edge_id_segments
                std::swap(edge_id_segments, sorted_edge_id_segments);
                auto sampled_segment_end_it = thrust::make_transform_iterator(
                    iota, SegmentEndFunc<indptr_t, decltype(sampled_degree)>{
                              sub_indptr.data_ptr<indptr_t>(), sampled_degree});
                CUB_CALL(
                    DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
                    sorted_edge_id_segments.get(), picked_eids.size(0),
                    num_rows, sub_indptr.data_ptr<indptr_t>(),
                    sampled_segment_end_it);
              }

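              // Build per-row iterators for the batched copy below: row i
              // reads its first sampled_degree[i] sorted edge id offsets and
              // writes them into picked_eids, shifted by sliced_indptr[i] to
              // turn them into global edge ids.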
              auto input_buffer_it = thrust::make_transform_iterator(
                  iota, IteratorFunc<indptr_t, edge_id_t>{
                            sub_indptr.data_ptr<indptr_t>(),
                            sorted_edge_id_segments.get()});
              auto output_buffer_it = thrust::make_transform_iterator(
                  iota, IteratorFuncAddOffset<indptr_t, indptr_t>{
                            output_indptr.data_ptr<indptr_t>(),
                            sliced_indptr.data_ptr<indptr_t>(),
                            picked_eids.data_ptr<indptr_t>()});
              constexpr int64_t max_copy_at_once =
                  std::numeric_limits<int32_t>::max();

              // Copy the sampled edge ids into picked_eids tensor.
              for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
                CUB_CALL(
                    DeviceCopy::Batched, input_buffer_it + i,
                    output_buffer_it + i, sampled_degree + i,
                    std::min(num_rows - i, max_copy_at_once));
              }
            }));

        output_indices = torch::empty(
            picked_eids.size(0),
            picked_eids.options().dtype(indices.scalar_type()));

        // Compute: output_indices = indices.gather(0, picked_eids);
        AT_DISPATCH_INDEX_TYPES(
            indices.scalar_type(), "SampleNeighborsOutputIndices", ([&] {
              using indices_t = index_t;
              THRUST_CALL(
                  gather, picked_eids.data_ptr<indptr_t>(),
                  picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
                  indices.data_ptr<indices_t>(),
                  output_indices.data_ptr<indices_t>());
            }));

        if (type_per_edge) {
          // output_type_per_edge = type_per_edge.gather(0, picked_eids);
          // The commented out torch equivalent above does not work when
          // type_per_edge is on pinned memory. That is why, we have to
          // reimplement it, similar to the indices gather operation above.
          auto types = type_per_edge.value();
          output_type_per_edge = torch::empty(
              picked_eids.size(0),
              picked_eids.options().dtype(types.scalar_type()));
          AT_DISPATCH_INTEGRAL_TYPES(
              types.scalar_type(), "SampleNeighborsOutputTypePerEdge", ([&] {
                THRUST_CALL(
                    gather, picked_eids.data_ptr<indptr_t>(),
                    picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
                    types.data_ptr<scalar_t>(),
                    output_type_per_edge.value().data_ptr<scalar_t>());
              }));
        }
      }));

  // Convert output_indptr back to the homogeneous form by discarding the
  // intermediate per-edge-type offsets.
  output_indptr =
      output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());
  torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
  if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
  if (!nodes.has_value()) {
    nodes = torch::arange(indptr.size(0) - 1, indices.options());
  }

  return c10::make_intrusive<sampling::FusedSampledSubgraph>(
      output_indptr, output_indices, nodes.value(), torch::nullopt,
      subgraph_reverse_edge_ids, output_type_per_edge);
}

}  //  namespace ops
}  //  namespace graphbolt