neighbor_sampler.cu
/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/neighbor_sampler.cu
 * @brief Neighbor sampling operator implementation on CUDA.
 */
#include <c10/core/ScalarType.h>
#include <curand_kernel.h>
#include <graphbolt/cuda_ops.h>
#include <graphbolt/cuda_sampling_ops.h>
#include <thrust/gather.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>

#include <algorithm>
#include <array>
#include <cub/cub.cuh>
#include <limits>
#include <numeric>
#include <type_traits>

#include "../random.h"
#include "./common.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

constexpr int BLOCK_SIZE = 128;

/**
 * @brief Fills the random_arr with random numbers and the edge_ids array with
 * the edge id offsets inside each row. When random_arr is sorted along with
 * edge_ids, the first fanout elements of each row give us the sampled edges.
 */
template <
    typename float_t, typename indptr_t, typename indices_t, typename weights_t,
    typename edge_id_t>
__global__ void _ComputeRandoms(
    const int64_t num_edges, const indptr_t* const sliced_indptr,
    const indptr_t* const sub_indptr, const indices_t* const csr_rows,
    const weights_t* const sliced_weights, const indices_t* const indices,
    const uint64_t random_seed, float_t* random_arr, edge_id_t* edge_ids) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;
  curandStatePhilox4_32_10_t rng;
  const auto labor = indices != nullptr;
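  // `labor` is true when layer (LABOR) sampling is requested. In that case the
  // RNG is seeded below from the neighbor's global vertex id, so a given
  // neighbor receives the same random number for every seed vertex in the
  // batch; otherwise each thread draws from its own independent Philox
  // subsequence.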

  if (!labor) {
    curand_init(random_seed, i, 0, &rng);
  }

  while (i < num_edges) {
    const auto row_position = csr_rows[i];
    const auto row_offset = i - sub_indptr[row_position];
    const auto in_idx = sliced_indptr[row_position] + row_offset;

    if (labor) {
      constexpr uint64_t kCurandSeed = 999961;
      curand_init(kCurandSeed, random_seed, indices[in_idx], &rng);
    }

    const auto rnd = curand_uniform(&rng);
    const auto prob =
        sliced_weights ? sliced_weights[i] : static_cast<weights_t>(1);
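    // -log(u) / prob is an exponential random variable with rate prob, so the
    // fanout smallest values in each row correspond to a weighted sample drawn
    // without replacement (an exponential race). Edges with non-positive
    // probability get +infinity and are never picked.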
    const auto exp_rnd = -__logf(rnd);
    const float_t adjusted_rnd = prob > 0
                                     ? static_cast<float_t>(exp_rnd / prob)
                                     : std::numeric_limits<float_t>::infinity();
    random_arr[i] = adjusted_rnd;
    edge_ids[i] = row_offset;

    i += stride;
  }
}

struct IsPositive {
  template <typename probs_t>
  __host__ __device__ auto operator()(probs_t x) {
    return x > 0;
  }
};

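/**
 * @brief Clamps the requested fanout by the available in-degree of each row.
 * In the heterogeneous case, row i uses fanouts[i % num_fanouts], i.e. the
 * fanout of the edge type that the row corresponds to.
 */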
template <typename indptr_t>
struct MinInDegreeFanout {
  const indptr_t* in_degree;
  const int64_t* fanouts;
  size_t num_fanouts;
  __host__ __device__ auto operator()(int64_t i) {
    return static_cast<indptr_t>(
        min(static_cast<int64_t>(in_degree[i]), fanouts[i % num_fanouts]));
  }
};

template <typename indptr_t, typename indices_t>
struct IteratorFunc {
  indptr_t* indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) { return indices + indptr[i]; }
};

template <typename indptr_t>
struct AddOffset {
  indptr_t offset;
  template <typename edge_id_t>
  __host__ __device__ indptr_t operator()(edge_id_t x) {
    return x + offset;
  }
};
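/**
 * @brief For each row i, returns an output iterator starting at
 * indices + indptr[i] that adds sliced_indptr[i] to every value written
 * through it, converting row-local edge id offsets into global edge ids on the
 * fly.
 */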

template <typename indptr_t, typename indices_t>
struct IteratorFuncAddOffset {
  indptr_t* indptr;
  indptr_t* sliced_indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) {
    return thrust::transform_output_iterator{
        indices + indptr[i], AddOffset<indptr_t>{sliced_indptr[i]}};
  }
};

template <typename indptr_t, typename in_degree_iterator_t>
struct SegmentEndFunc {
  indptr_t* indptr;
  in_degree_iterator_t in_degree;
  __host__ __device__ auto operator()(int64_t i) {
    return indptr[i] + in_degree[i];
  }
};

c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
    const std::vector<int64_t>& fanouts, bool replace, bool layer,
    bool return_eids, torch::optional<torch::Tensor> type_per_edge,
    torch::optional<torch::Tensor> probs_or_mask) {
  TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!");
  // Assume that indptr, indices, nodes, type_per_edge and probs_or_mask
  // are all resident on the GPU. If not, it is better to first extract them
  // before calling this function.
  auto allocator = cuda::GetAllocator();
  auto num_rows = nodes.size(0);
  auto fanouts_pinned = torch::empty(
      fanouts.size(),
      c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
  auto fanouts_pinned_ptr = fanouts_pinned.data_ptr<int64_t>();
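  // A negative fanout means "sample all neighbors"; map it to the maximum
  // int64_t value so that the min() with the in-degree below keeps the full
  // neighborhood.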
  for (size_t i = 0; i < fanouts.size(); i++) {
    fanouts_pinned_ptr[i] =
        fanouts[i] >= 0 ? fanouts[i] : std::numeric_limits<int64_t>::max();
  }
  // Finally, copy the adjusted fanout values to the device memory.
  auto fanouts_device = allocator.AllocateStorage<int64_t>(fanouts.size());
  CUDA_CALL(cudaMemcpyAsync(
      fanouts_device.get(), fanouts_pinned_ptr,
      sizeof(int64_t) * fanouts.size(), cudaMemcpyHostToDevice,
      cuda::GetCurrentStream()));
  auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
  auto in_degree = std::get<0>(in_degree_and_sliced_indptr);
  auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr);
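  // in_degree[i] is the degree of nodes[i] and sliced_indptr[i] the offset of
  // its first edge inside indices, i.e. indptr[nodes[i]].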
  auto max_in_degree = torch::empty(
      1,
      c10::TensorOptions().dtype(in_degree.scalar_type()).pinned_memory(true));
  AT_DISPATCH_INDEX_TYPES(
      indptr.scalar_type(), "SampleNeighborsMaxInDegree", ([&] {
        CUB_CALL(
            DeviceReduce::Max, in_degree.data_ptr<index_t>(),
            max_in_degree.data_ptr<index_t>(), num_rows);
      }));
  torch::optional<int64_t> num_edges_;
  torch::Tensor sub_indptr;
  torch::optional<torch::Tensor> sliced_probs_or_mask;
  if (probs_or_mask.has_value()) {
    torch::Tensor sliced_probs_or_mask_tensor;
    std::tie(sub_indptr, sliced_probs_or_mask_tensor) = IndexSelectCSCImpl(
        in_degree, sliced_indptr, probs_or_mask.value(), nodes,
        indptr.size(0) - 2, num_edges_);
    sliced_probs_or_mask = sliced_probs_or_mask_tensor;
    num_edges_ = sliced_probs_or_mask_tensor.size(0);
  }
  if (fanouts.size() > 1) {
    torch::Tensor sliced_type_per_edge;
    std::tie(sub_indptr, sliced_type_per_edge) = IndexSelectCSCImpl(
        in_degree, sliced_indptr, type_per_edge.value(), nodes,
        indptr.size(0) - 2, num_edges_);
    std::tie(sub_indptr, in_degree, sliced_indptr) = SliceCSCIndptrHetero(
        sub_indptr, sliced_type_per_edge, sliced_indptr, fanouts.size());
    num_rows = sliced_indptr.size(0);
    num_edges_ = sliced_type_per_edge.size(0);
  }
  // If sub_indptr was not computed in the two code blocks above:
  if (!probs_or_mask.has_value() && fanouts.size() <= 1) {
    sub_indptr = ExclusiveCumSum(in_degree);
  }
  auto coo_rows = ExpandIndptrImpl(
      sub_indptr, indices.scalar_type(), torch::nullopt, num_edges_);
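  // coo_rows[e] holds the local row (position inside nodes) that candidate
  // edge e belongs to, i.e. the COO row ids matching sub_indptr.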
  const auto num_edges = coo_rows.size(0);
  const auto random_seed = RandomEngine::ThreadLocal()->RandInt(
      static_cast<int64_t>(0), std::numeric_limits<int64_t>::max());
  auto output_indptr = torch::empty_like(sub_indptr);
  torch::Tensor picked_eids;
  torch::Tensor output_indices;
  torch::optional<torch::Tensor> output_type_per_edge;

  AT_DISPATCH_INDEX_TYPES(
      indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
        using indptr_t = index_t;
        if (probs_or_mask.has_value()) {  // Count nonzero probs into in_degree.
          GRAPHBOLT_DISPATCH_ALL_TYPES(
              probs_or_mask.value().scalar_type(),
              "SampleNeighborsPositiveProbs", ([&] {
                using probs_t = scalar_t;
                auto is_nonzero = thrust::make_transform_iterator(
                    sliced_probs_or_mask.value().data_ptr<probs_t>(),
                    IsPositive{});
                CUB_CALL(
                    DeviceSegmentedReduce::Sum, is_nonzero,
                    in_degree.data_ptr<indptr_t>(), num_rows,
                    sub_indptr.data_ptr<indptr_t>(),
                    sub_indptr.data_ptr<indptr_t>() + 1);
              }));
        }
        thrust::counting_iterator<int64_t> iota(0);
        auto sampled_degree = thrust::make_transform_iterator(
            iota, MinInDegreeFanout<indptr_t>{
                      in_degree.data_ptr<indptr_t>(), fanouts_device.get(),
                      fanouts.size()});

        // Compute output_indptr.
        CUB_CALL(
            DeviceScan::ExclusiveSum, sampled_degree,
            output_indptr.data_ptr<indptr_t>(), num_rows + 1);

        auto num_sampled_edges =
            cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_rows};
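        // The scalar above is the total number of sampled edges;
        // cuda::CopyScalar presumably copies it to the host asynchronously,
        // and it is only read below when allocating picked_eids.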

        // Find the smallest integer type to store the edge id offsets.
        // ExpandIndptr or IndexSelectCSCImpl synchronized internally, so it is
        // safe to read max_in_degree on the host now.
        const int num_bits =
            cuda::NumberOfBits(max_in_degree.data_ptr<indptr_t>()[0]);
        std::array<int, 4> type_bits = {8, 16, 32, 64};
        const auto type_index =
            std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
            type_bits.begin();
        std::array<torch::ScalarType, 5> types = {
            torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
            torch::kLong};
        auto edge_id_dtype = types[type_index];
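        // For example, max_in_degree <= 255 gives num_bits <= 8, so
        // torch::kByte is selected and the row-local edge id offsets are
        // stored as unsigned 8-bit values.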
        AT_DISPATCH_INTEGRAL_TYPES(
            edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
              using edge_id_t = std::make_unsigned_t<scalar_t>;
              TORCH_CHECK(
                  num_bits <= sizeof(edge_id_t) * 8,
                  "Selected edge_id_t must be capable of storing edge_ids.");
              // Using bfloat16 for random numbers works just as reliably as
              // float32 and provides around a 30% speedup.
              using rnd_t = nv_bfloat16;
              auto randoms = allocator.AllocateStorage<rnd_t>(num_edges);
              auto randoms_sorted = allocator.AllocateStorage<rnd_t>(num_edges);
              auto edge_id_segments =
                  allocator.AllocateStorage<edge_id_t>(num_edges);
              auto sorted_edge_id_segments =
                  allocator.AllocateStorage<edge_id_t>(num_edges);
              AT_DISPATCH_INDEX_TYPES(
                  indices.scalar_type(), "SampleNeighborsIndices", ([&] {
                    using indices_t = index_t;
                    auto probs_or_mask_scalar_type = torch::kFloat32;
                    if (probs_or_mask.has_value()) {
                      probs_or_mask_scalar_type =
                          probs_or_mask.value().scalar_type();
                    }
                    GRAPHBOLT_DISPATCH_ALL_TYPES(
                        probs_or_mask_scalar_type, "SampleNeighborsProbs",
                        ([&] {
                          using probs_t = scalar_t;
                          probs_t* sliced_probs_ptr = nullptr;
                          if (sliced_probs_or_mask.has_value()) {
                            sliced_probs_ptr = sliced_probs_or_mask.value()
                                                   .data_ptr<probs_t>();
                          }
                          const indices_t* indices_ptr =
                              layer ? indices.data_ptr<indices_t>() : nullptr;
                          const dim3 block(BLOCK_SIZE);
                          const dim3 grid(
                              (num_edges + BLOCK_SIZE - 1) / BLOCK_SIZE);
                          // Compute row and random number pairs.
                          CUDA_KERNEL_CALL(
                              _ComputeRandoms, grid, block, 0, num_edges,
                              sliced_indptr.data_ptr<indptr_t>(),
                              sub_indptr.data_ptr<indptr_t>(),
                              coo_rows.data_ptr<indices_t>(), sliced_probs_ptr,
                              indices_ptr, random_seed, randoms.get(),
                              edge_id_segments.get());
                        }));
                  }));

              // Sort the random numbers along with the edge ids; after
              // sorting, the first fanout elements of each row give us the
              // sampled edges.
              CUB_CALL(
                  DeviceSegmentedSort::SortPairs, randoms.get(),
                  randoms_sorted.get(), edge_id_segments.get(),
                  sorted_edge_id_segments.get(), num_edges, num_rows,
                  sub_indptr.data_ptr<indptr_t>(),
                  sub_indptr.data_ptr<indptr_t>() + 1);

              picked_eids = torch::empty(
                  static_cast<indptr_t>(num_sampled_edges),
                  nodes.options().dtype(indptr.scalar_type()));

              // The sampled edges need to be sorted only when
              // fanouts.size() == 1, since the output of the multiple-fanout
              // sampling case is automatically sorted.
              if (type_per_edge && fanouts.size() == 1) {
                // Ensure the sort result still ends up in sorted_edge_id_segments.
                std::swap(edge_id_segments, sorted_edge_id_segments);
                auto sampled_segment_end_it = thrust::make_transform_iterator(
                    iota, SegmentEndFunc<indptr_t, decltype(sampled_degree)>{
                              sub_indptr.data_ptr<indptr_t>(), sampled_degree});
                CUB_CALL(
                    DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
                    sorted_edge_id_segments.get(), picked_eids.size(0),
                    num_rows, sub_indptr.data_ptr<indptr_t>(),
                    sampled_segment_end_it);
              }

              auto input_buffer_it = thrust::make_transform_iterator(
                  iota, IteratorFunc<indptr_t, edge_id_t>{
                            sub_indptr.data_ptr<indptr_t>(),
                            sorted_edge_id_segments.get()});
              auto output_buffer_it = thrust::make_transform_iterator(
                  iota, IteratorFuncAddOffset<indptr_t, indptr_t>{
                            output_indptr.data_ptr<indptr_t>(),
                            sliced_indptr.data_ptr<indptr_t>(),
                            picked_eids.data_ptr<indptr_t>()});
              constexpr int64_t max_copy_at_once =
                  std::numeric_limits<int32_t>::max();
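              // The copy below is chunked, presumably because
              // cub::DeviceCopy::Batched takes the number of segments as a
              // 32-bit integer.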

              // Copy the sampled edge ids into picked_eids tensor.
              for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
                CUB_CALL(
                    DeviceCopy::Batched, input_buffer_it + i,
                    output_buffer_it + i, sampled_degree + i,
                    std::min(num_rows - i, max_copy_at_once));
              }
            }));

        output_indices = torch::empty(
            picked_eids.size(0),
            picked_eids.options().dtype(indices.scalar_type()));

        // Compute: output_indices = indices.gather(0, picked_eids);
        AT_DISPATCH_INDEX_TYPES(
            indices.scalar_type(), "SampleNeighborsOutputIndices", ([&] {
              using indices_t = index_t;
              THRUST_CALL(
                  gather, picked_eids.data_ptr<indptr_t>(),
                  picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
                  indices.data_ptr<indices_t>(),
                  output_indices.data_ptr<indices_t>());
            }));

        if (type_per_edge) {
          // output_type_per_edge = type_per_edge.gather(0, picked_eids);
          // The commented out torch equivalent above does not work when
          // type_per_edge is on pinned memory. That is why we have to
          // reimplement it here, similar to the indices gather operation above.
          auto types = type_per_edge.value();
          output_type_per_edge = torch::empty(
              picked_eids.size(0),
              picked_eids.options().dtype(types.scalar_type()));
          AT_DISPATCH_INTEGRAL_TYPES(
              types.scalar_type(), "SampleNeighborsOutputTypePerEdge", ([&] {
                THRUST_CALL(
                    gather, picked_eids.data_ptr<indptr_t>(),
                    picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
                    types.data_ptr<scalar_t>(),
                    output_type_per_edge.value().data_ptr<scalar_t>());
              }));
        }
      }));

  // Convert output_indptr back to the homogeneous form by discarding the
  // intermediate per-edge-type offsets.
  output_indptr =
      output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());
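  // In the heterogeneous case, output_indptr had one entry per (seed node,
  // edge type) pair plus a final total; keeping every fanouts.size()-th
  // element recovers one offset per node, which is valid because each node's
  // per-type segments are stored contiguously.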
  torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
  if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);

  return c10::make_intrusive<sampling::FusedSampledSubgraph>(
      output_indptr, output_indices, nodes, torch::nullopt,
      subgraph_reverse_edge_ids, output_type_per_edge);
}

}  //  namespace ops
}  //  namespace graphbolt