neighbor_sampler.cu 27.6 KB
Newer Older
1
2
3
4
5
6
7
8
/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/index_select_impl.cu
 * @brief Index select operator implementation on CUDA.
 */
#include <c10/core/ScalarType.h>
#include <curand_kernel.h>
9
#include <graphbolt/continuous_seed.h>
10
11
12
13
14
15
16
17
18
19
#include <graphbolt/cuda_ops.h>
#include <graphbolt/cuda_sampling_ops.h>
#include <thrust/gather.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>

#include <algorithm>
#include <array>
#include <cub/cub.cuh>
20
21
22
#if __CUDA_ARCH__ >= 700
#include <cuda/atomic>
#endif  // __CUDA_ARCH__ >= 700
23
24
25
26
27
#include <limits>
#include <numeric>
#include <type_traits>

#include "../random.h"
28
#include "../utils.h"
29
30
31
32
33
34
35
36
#include "./common.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

constexpr int BLOCK_SIZE = 128;

37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
inline __device__ int64_t AtomicMax(int64_t* const address, const int64_t val) {
  // To match the type of "::atomicCAS", ignore lint warning.
  using Type = unsigned long long int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}

inline __device__ int32_t AtomicMax(int32_t* const address, const int32_t val) {
  // To match the type of "::atomicCAS", ignore lint warning.
  using Type = int;  // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}

/**
 * @brief Performs neighbor sampling and fills the edge_ids array with
 * original edge ids if sliced_indptr is valid. If not, then it fills the edge
 * ids array with numbers upto the node degree.
 */
template <typename indptr_t, typename indices_t>
__global__ void _ComputeRandomsNS(
    const int64_t num_edges, const indptr_t* const sliced_indptr,
    const indptr_t* const sub_indptr, const indptr_t* const output_indptr,
    const indices_t* const csr_rows, const uint64_t random_seed,
    indptr_t* edge_ids) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;

  curandStatePhilox4_32_10_t rng;
  curand_init(random_seed, i, 0, &rng);

  while (i < num_edges) {
    const auto row_position = csr_rows[i];
    const auto row_offset = i - sub_indptr[row_position];
    const auto output_offset = output_indptr[row_position];
    const auto fanout = output_indptr[row_position + 1] - output_offset;
    const auto rnd =
        row_offset < fanout ? row_offset : curand(&rng) % (row_offset + 1);
    if (rnd < fanout) {
      const indptr_t edge_id =
          row_offset + (sliced_indptr ? sliced_indptr[row_position] : 0);
#if __CUDA_ARCH__ >= 700
      ::cuda::atomic_ref<indptr_t, ::cuda::thread_scope_device> a(
          edge_ids[output_offset + rnd]);
      a.fetch_max(edge_id, ::cuda::std::memory_order_relaxed);
#else
      AtomicMax(edge_ids + output_offset + rnd, edge_id);
#endif  // __CUDA_ARCH__
    }

    i += stride;
  }
}

95
96
97
98
99
100
101
102
103
104
105
/**
 * @brief Fills the random_arr with random numbers and the edge_ids array with
 * original edge ids. When random_arr is sorted along with edge_ids, the first
 * fanout elements of each row gives us the sampled edges.
 */
template <
    typename float_t, typename indptr_t, typename indices_t, typename weights_t,
    typename edge_id_t>
__global__ void _ComputeRandoms(
    const int64_t num_edges, const indptr_t* const sliced_indptr,
    const indptr_t* const sub_indptr, const indices_t* const csr_rows,
106
    const weights_t* const sliced_weights, const indices_t* const indices,
107
108
    const continuous_seed random_seed, float_t* random_arr,
    edge_id_t* edge_ids) {
109
110
111
112
113
114
115
116
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = gridDim.x * blockDim.x;
  const auto labor = indices != nullptr;

  while (i < num_edges) {
    const auto row_position = csr_rows[i];
    const auto row_offset = i - sub_indptr[row_position];
    const auto in_idx = sliced_indptr[row_position] + row_offset;
117
    const auto rnd = random_seed.uniform(labor ? indices[in_idx] : i);
118
119
    const auto prob =
        sliced_weights ? sliced_weights[i] : static_cast<weights_t>(1);
120
121
122
123
124
125
126
127
128
129
130
    const auto exp_rnd = -__logf(rnd);
    const float_t adjusted_rnd = prob > 0
                                     ? static_cast<float_t>(exp_rnd / prob)
                                     : std::numeric_limits<float_t>::infinity();
    random_arr[i] = adjusted_rnd;
    edge_ids[i] = row_offset;

    i += stride;
  }
}

131
132
133
134
135
136
137
struct IsPositive {
  template <typename probs_t>
  __host__ __device__ auto operator()(probs_t x) {
    return x > 0;
  }
};

138
139
140
template <typename indptr_t>
struct MinInDegreeFanout {
  const indptr_t* in_degree;
141
142
  const int64_t* fanouts;
  size_t num_fanouts;
143
144
  __host__ __device__ auto operator()(int64_t i) {
    return static_cast<indptr_t>(
145
        min(static_cast<int64_t>(in_degree[i]), fanouts[i % num_fanouts]));
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
  }
};

template <typename indptr_t, typename indices_t>
struct IteratorFunc {
  indptr_t* indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) { return indices + indptr[i]; }
};

template <typename indptr_t>
struct AddOffset {
  indptr_t offset;
  template <typename edge_id_t>
  __host__ __device__ indptr_t operator()(edge_id_t x) {
    return x + offset;
  }
};

template <typename indptr_t, typename indices_t>
struct IteratorFuncAddOffset {
  indptr_t* indptr;
  indptr_t* sliced_indptr;
  indices_t* indices;
  __host__ __device__ auto operator()(int64_t i) {
    return thrust::transform_output_iterator{
        indices + indptr[i], AddOffset<indptr_t>{sliced_indptr[i]}};
  }
};

176
177
178
179
180
181
182
183
184
template <typename indptr_t, typename in_degree_iterator_t>
struct SegmentEndFunc {
  indptr_t* indptr;
  in_degree_iterator_t in_degree;
  __host__ __device__ auto operator()(int64_t i) {
    return indptr[i] + in_degree[i];
  }
};

185
c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
186
    torch::Tensor indptr, torch::Tensor indices,
187
188
189
190
    torch::optional<torch::Tensor> seeds,
    torch::optional<std::vector<int64_t>> seed_offsets,
    const std::vector<int64_t>& fanouts, bool replace, bool layer,
    bool return_eids, torch::optional<torch::Tensor> type_per_edge,
191
    torch::optional<torch::Tensor> probs_or_mask,
192
193
    torch::optional<torch::Dict<std::string, int64_t>> node_type_to_id,
    torch::optional<torch::Dict<std::string, int64_t>> edge_type_to_id,
194
195
    torch::optional<torch::Tensor> random_seed_tensor,
    float seed2_contribution) {
196
197
198
199
  // When seed_offsets.has_value() in the hetero case, we compute the output of
  // sample_neighbors _convert_to_sampled_subgraph in a fused manner so that
  // _convert_to_sampled_subgraph only has to perform slices over the returned
  // indptr and indices tensors to form CSC outputs for each edge type.
200
  TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!");
201
  // Assume that indptr, indices, seeds, type_per_edge and probs_or_mask
202
203
204
  // are all resident on the GPU. If not, it is better to first extract them
  // before calling this function.
  auto allocator = cuda::GetAllocator();
205
  auto num_rows =
206
      seeds.has_value() ? seeds.value().size(0) : indptr.size(0) - 1;
207
208
209
210
211
212
213
214
215
216
217
218
  auto fanouts_pinned = torch::empty(
      fanouts.size(),
      c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
  auto fanouts_pinned_ptr = fanouts_pinned.data_ptr<int64_t>();
  for (size_t i = 0; i < fanouts.size(); i++) {
    fanouts_pinned_ptr[i] =
        fanouts[i] >= 0 ? fanouts[i] : std::numeric_limits<int64_t>::max();
  }
  // Finally, copy the adjusted fanout values to the device memory.
  auto fanouts_device = allocator.AllocateStorage<int64_t>(fanouts.size());
  CUDA_CALL(cudaMemcpyAsync(
      fanouts_device.get(), fanouts_pinned_ptr,
219
220
      sizeof(int64_t) * fanouts.size(), cudaMemcpyHostToDevice,
      cuda::GetCurrentStream()));
221
  auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, seeds);
222
  auto in_degree = std::get<0>(in_degree_and_sliced_indptr);
223
  auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr);
224
225
226
227
228
229
230
231
232
  auto max_in_degree = torch::empty(
      1,
      c10::TensorOptions().dtype(in_degree.scalar_type()).pinned_memory(true));
  AT_DISPATCH_INDEX_TYPES(
      indptr.scalar_type(), "SampleNeighborsMaxInDegree", ([&] {
        CUB_CALL(
            DeviceReduce::Max, in_degree.data_ptr<index_t>(),
            max_in_degree.data_ptr<index_t>(), num_rows);
      }));
233
234
235
236
  // Protect access to max_in_degree with a CUDAEvent
  at::cuda::CUDAEvent max_in_degree_event;
  max_in_degree_event.record();
  torch::optional<int64_t> num_edges;
237
  torch::Tensor sub_indptr;
238
  if (!seeds.has_value()) {
239
240
241
    num_edges = indices.size(0);
    sub_indptr = indptr;
  }
242
243
  torch::optional<torch::Tensor> sliced_probs_or_mask;
  if (probs_or_mask.has_value()) {
244
    if (seeds.has_value()) {
245
246
      torch::Tensor sliced_probs_or_mask_tensor;
      std::tie(sub_indptr, sliced_probs_or_mask_tensor) = IndexSelectCSCImpl(
247
          in_degree, sliced_indptr, probs_or_mask.value(), seeds.value(),
248
249
250
251
252
253
          indptr.size(0) - 2, num_edges);
      sliced_probs_or_mask = sliced_probs_or_mask_tensor;
      num_edges = sliced_probs_or_mask_tensor.size(0);
    } else {
      sliced_probs_or_mask = probs_or_mask;
    }
254
  }
255
256
  if (fanouts.size() > 1) {
    torch::Tensor sliced_type_per_edge;
257
    if (seeds.has_value()) {
258
      std::tie(sub_indptr, sliced_type_per_edge) = IndexSelectCSCImpl(
259
          in_degree, sliced_indptr, type_per_edge.value(), seeds.value(),
260
261
262
263
          indptr.size(0) - 2, num_edges);
    } else {
      sliced_type_per_edge = type_per_edge.value();
    }
264
265
266
    std::tie(sub_indptr, in_degree, sliced_indptr) = SliceCSCIndptrHetero(
        sub_indptr, sliced_type_per_edge, sliced_indptr, fanouts.size());
    num_rows = sliced_indptr.size(0);
267
    num_edges = sliced_type_per_edge.size(0);
268
269
  }
  // If sub_indptr was not computed in the two code blocks above:
270
  if (seeds.has_value() && !probs_or_mask.has_value() && fanouts.size() <= 1) {
271
    sub_indptr = ExclusiveCumSum(in_degree);
272
  }
273
274
275
276
277
278
279
280
  const continuous_seed random_seed = [&] {
    if (random_seed_tensor.has_value()) {
      return continuous_seed(random_seed_tensor.value(), seed2_contribution);
    } else {
      return continuous_seed{RandomEngine::ThreadLocal()->RandInt(
          static_cast<int64_t>(0), std::numeric_limits<int64_t>::max())};
    }
  }();
281
  auto output_indptr = torch::empty_like(sub_indptr);
282
283
284
  torch::Tensor picked_eids;
  torch::Tensor output_indices;

285
  AT_DISPATCH_INDEX_TYPES(
286
      indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
287
        using indptr_t = index_t;
288
289
290
291
292
293
294
295
        if (probs_or_mask.has_value()) {  // Count nonzero probs into in_degree.
          GRAPHBOLT_DISPATCH_ALL_TYPES(
              probs_or_mask.value().scalar_type(),
              "SampleNeighborsPositiveProbs", ([&] {
                using probs_t = scalar_t;
                auto is_nonzero = thrust::make_transform_iterator(
                    sliced_probs_or_mask.value().data_ptr<probs_t>(),
                    IsPositive{});
296
297
                CUB_CALL(
                    DeviceSegmentedReduce::Sum, is_nonzero,
298
299
                    in_degree.data_ptr<indptr_t>(), num_rows,
                    sub_indptr.data_ptr<indptr_t>(),
300
                    sub_indptr.data_ptr<indptr_t>() + 1);
301
302
              }));
        }
303
304
305
        thrust::counting_iterator<int64_t> iota(0);
        auto sampled_degree = thrust::make_transform_iterator(
            iota, MinInDegreeFanout<indptr_t>{
306
307
                      in_degree.data_ptr<indptr_t>(), fanouts_device.get(),
                      fanouts.size()});
308

309
310
311
312
        // Compute output_indptr.
        CUB_CALL(
            DeviceScan::ExclusiveSum, sampled_degree,
            output_indptr.data_ptr<indptr_t>(), num_rows + 1);
313
314
315
316

        auto num_sampled_edges =
            cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_rows};

317
318
319
320
321
322
        // This operation is placed after num_sampled_edges copy is started to
        // hide the latency of copy synchronization later.
        auto coo_rows = ExpandIndptrImpl(
            sub_indptr, indices.scalar_type(), torch::nullopt, num_edges);
        num_edges = coo_rows.size(0);

323
324
        // Find the smallest integer type to store the edge id offsets. We synch
        // the CUDAEvent so that the access is safe.
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
        auto compute_num_bits = [&] {
          max_in_degree_event.synchronize();
          return cuda::NumberOfBits(max_in_degree.data_ptr<indptr_t>()[0]);
        };
        if (layer || probs_or_mask.has_value()) {
          const int num_bits = compute_num_bits();
          std::array<int, 4> type_bits = {8, 16, 32, 64};
          const auto type_index =
              std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
              type_bits.begin();
          std::array<torch::ScalarType, 5> types = {
              torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
              torch::kLong};
          auto edge_id_dtype = types[type_index];
          AT_DISPATCH_INTEGRAL_TYPES(
              edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
                using edge_id_t = std::make_unsigned_t<scalar_t>;
                TORCH_CHECK(
                    num_bits <= sizeof(edge_id_t) * 8,
                    "Selected edge_id_t must be capable of storing edge_ids.");
                // Using bfloat16 for random numbers works just as reliably as
                // float32 and provides around 30% speedup.
                using rnd_t = nv_bfloat16;
                auto randoms =
                    allocator.AllocateStorage<rnd_t>(num_edges.value());
                auto randoms_sorted =
                    allocator.AllocateStorage<rnd_t>(num_edges.value());
                auto edge_id_segments =
                    allocator.AllocateStorage<edge_id_t>(num_edges.value());
                auto sorted_edge_id_segments =
                    allocator.AllocateStorage<edge_id_t>(num_edges.value());
                AT_DISPATCH_INDEX_TYPES(
                    indices.scalar_type(), "SampleNeighborsIndices", ([&] {
                      using indices_t = index_t;
                      auto probs_or_mask_scalar_type = torch::kFloat32;
                      if (probs_or_mask.has_value()) {
                        probs_or_mask_scalar_type =
                            probs_or_mask.value().scalar_type();
                      }
                      GRAPHBOLT_DISPATCH_ALL_TYPES(
                          probs_or_mask_scalar_type, "SampleNeighborsProbs",
                          ([&] {
                            using probs_t = scalar_t;
                            probs_t* sliced_probs_ptr = nullptr;
                            if (sliced_probs_or_mask.has_value()) {
                              sliced_probs_ptr = sliced_probs_or_mask.value()
                                                     .data_ptr<probs_t>();
                            }
                            const indices_t* indices_ptr =
                                layer ? indices.data_ptr<indices_t>() : nullptr;
                            const dim3 block(BLOCK_SIZE);
                            const dim3 grid(
                                (num_edges.value() + BLOCK_SIZE - 1) /
                                BLOCK_SIZE);
                            // Compute row and random number pairs.
                            CUDA_KERNEL_CALL(
                                _ComputeRandoms, grid, block, 0,
                                num_edges.value(),
                                sliced_indptr.data_ptr<indptr_t>(),
                                sub_indptr.data_ptr<indptr_t>(),
                                coo_rows.data_ptr<indices_t>(),
                                sliced_probs_ptr, indices_ptr, random_seed,
                                randoms.get(), edge_id_segments.get());
                          }));
                    }));

                // Sort the random numbers along with edge ids, after
                // sorting the first fanout elements of each row will
                // give us the sampled edges.
394
                CUB_CALL(
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
                    DeviceSegmentedSort::SortPairs, randoms.get(),
                    randoms_sorted.get(), edge_id_segments.get(),
                    sorted_edge_id_segments.get(), num_edges.value(), num_rows,
                    sub_indptr.data_ptr<indptr_t>(),
                    sub_indptr.data_ptr<indptr_t>() + 1);

                picked_eids = torch::empty(
                    static_cast<indptr_t>(num_sampled_edges),
                    sub_indptr.options());

                // Need to sort the sampled edges only when fanouts.size() == 1
                // since multiple fanout sampling case is automatically going to
                // be sorted.
                if (type_per_edge && fanouts.size() == 1) {
                  // Ensuring sort result still ends up in
                  // sorted_edge_id_segments
                  std::swap(edge_id_segments, sorted_edge_id_segments);
                  auto sampled_segment_end_it = thrust::make_transform_iterator(
                      iota,
                      SegmentEndFunc<indptr_t, decltype(sampled_degree)>{
                          sub_indptr.data_ptr<indptr_t>(), sampled_degree});
                  CUB_CALL(
                      DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
                      sorted_edge_id_segments.get(), picked_eids.size(0),
                      num_rows, sub_indptr.data_ptr<indptr_t>(),
                      sampled_segment_end_it);
                }

                auto input_buffer_it = thrust::make_transform_iterator(
                    iota, IteratorFunc<indptr_t, edge_id_t>{
                              sub_indptr.data_ptr<indptr_t>(),
                              sorted_edge_id_segments.get()});
                auto output_buffer_it = thrust::make_transform_iterator(
                    iota, IteratorFuncAddOffset<indptr_t, indptr_t>{
                              output_indptr.data_ptr<indptr_t>(),
                              sliced_indptr.data_ptr<indptr_t>(),
                              picked_eids.data_ptr<indptr_t>()});
                constexpr int64_t max_copy_at_once =
                    std::numeric_limits<int32_t>::max();

                // Copy the sampled edge ids into picked_eids tensor.
                for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
                  CUB_CALL(
                      DeviceCopy::Batched, input_buffer_it + i,
                      output_buffer_it + i, sampled_degree + i,
                      std::min(num_rows - i, max_copy_at_once));
                }
              }));
        } else {  // Non-weighted neighbor sampling.
          picked_eids = torch::zeros(num_edges.value(), sub_indptr.options());
          const auto sort_needed = type_per_edge && fanouts.size() == 1;
          const auto sliced_indptr_ptr =
              sort_needed ? nullptr : sliced_indptr.data_ptr<indptr_t>();

          const dim3 block(BLOCK_SIZE);
          const dim3 grid(
              (std::min(num_edges.value(), static_cast<int64_t>(1 << 20)) +
               BLOCK_SIZE - 1) /
              BLOCK_SIZE);
          AT_DISPATCH_INDEX_TYPES(
              indices.scalar_type(), "SampleNeighborsIndices", ([&] {
                using indices_t = index_t;
                // Compute row and random number pairs.
                CUDA_KERNEL_CALL(
                    _ComputeRandomsNS, grid, block, 0, num_edges.value(),
                    sliced_indptr_ptr, sub_indptr.data_ptr<indptr_t>(),
                    output_indptr.data_ptr<indptr_t>(),
                    coo_rows.data_ptr<indices_t>(), random_seed.get_seed(0),
                    picked_eids.data_ptr<indptr_t>());
              }));

          picked_eids =
              picked_eids.slice(0, 0, static_cast<indptr_t>(num_sampled_edges));

          // Need to sort the sampled edges only when fanouts.size() == 1
          // since multiple fanout sampling case is automatically going to
          // be sorted.
          if (sort_needed) {
            const int num_bits = compute_num_bits();
            std::array<int, 4> type_bits = {8, 15, 31, 63};
            const auto type_index =
                std::lower_bound(type_bits.begin(), type_bits.end(), num_bits) -
                type_bits.begin();
            std::array<torch::ScalarType, 5> types = {
                torch::kByte, torch::kInt16, torch::kInt32, torch::kLong,
                torch::kLong};
            auto edge_id_dtype = types[type_index];
            AT_DISPATCH_INTEGRAL_TYPES(
                edge_id_dtype, "SampleNeighborsEdgeIDs", ([&] {
                  using edge_id_t = scalar_t;
                  TORCH_CHECK(
                      num_bits <= sizeof(edge_id_t) * 8,
                      "Selected edge_id_t must be capable of storing "
                      "edge_ids.");
                  auto picked_offsets = picked_eids.to(edge_id_dtype);
                  auto sorted_offsets = torch::empty_like(picked_offsets);
                  CUB_CALL(
                      DeviceSegmentedSort::SortKeys,
                      picked_offsets.data_ptr<edge_id_t>(),
                      sorted_offsets.data_ptr<edge_id_t>(), picked_eids.size(0),
                      num_rows, output_indptr.data_ptr<indptr_t>(),
                      output_indptr.data_ptr<indptr_t>() + 1);
                  auto edge_id_offsets = ExpandIndptrImpl(
                      output_indptr, picked_eids.scalar_type(), sliced_indptr,
                      picked_eids.size(0));
                  picked_eids = sorted_offsets.to(picked_eids.scalar_type()) +
                                edge_id_offsets;
                }));
          }
        }
505

506
        output_indices = Gather(indices, picked_eids);
507
      }));
508

509
510
511
512
513
514
515
516
517
  torch::optional<torch::Tensor> output_type_per_edge;
  torch::optional<torch::Tensor> edge_offsets;
  if (type_per_edge && seed_offsets) {
    const int64_t num_etypes =
        edge_type_to_id.has_value() ? edge_type_to_id->size() : 1;
    // If we performed homogenous sampling on hetero graph, we have to look at
    // type_per_edge of sampled edges and determine the offsets of different
    // sampled etypes and convert to fused hetero indptr representation.
    if (fanouts.size() == 1) {
518
      output_type_per_edge = Gather(*type_per_edge, picked_eids);
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
      torch::Tensor output_in_degree, sliced_output_indptr;
      sliced_output_indptr =
          output_indptr.slice(0, 0, output_indptr.size(0) - 1);
      std::tie(output_indptr, output_in_degree, sliced_output_indptr) =
          SliceCSCIndptrHetero(
              output_indptr, output_type_per_edge.value(), sliced_output_indptr,
              num_etypes);
      // We use num_rows to hold num_seeds * num_etypes. So, it needs to be
      // updated when sampling with a single fanout value when the graph is
      // heterogenous.
      num_rows = sliced_output_indptr.size(0);
    }
    // Here, we check what are the dst node types for the given seeds so that
    // we can compute the output indptr space later.
    std::vector<int64_t> etype_id_to_dst_ntype_id(num_etypes);
    for (auto& etype_and_id : edge_type_to_id.value()) {
      auto etype = etype_and_id.key();
      auto id = etype_and_id.value();
      auto dst_type = utils::parse_dst_ntype_from_etype(etype);
      etype_id_to_dst_ntype_id[id] = node_type_to_id->at(dst_type);
    }
    // For each edge type, we compute the start and end offsets to index into
    // indptr to form the final output_indptr.
    auto indptr_offsets = torch::empty(
        num_etypes * 2,
        c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
    auto indptr_offsets_ptr = indptr_offsets.data_ptr<int64_t>();
    // We compute the indptr offsets here, right now, output_indptr is of size
    // # seeds * num_etypes + 1. We can simply take slices to get correct output
    // indptr. The final output_indptr is same as current indptr except that
    // some intermediate values are removed to change the node ids space from
    // all of the seed vertices to the node id space of the dst node type of
    // each edge type.
    for (int i = 0; i < num_etypes; i++) {
      indptr_offsets_ptr[2 * i] = num_rows / num_etypes * i +
                                  seed_offsets->at(etype_id_to_dst_ntype_id[i]);
      indptr_offsets_ptr[2 * i + 1] =
          num_rows / num_etypes * i +
          seed_offsets->at(etype_id_to_dst_ntype_id[i] + 1);
    }
    auto permutation = torch::arange(
        0, num_rows * num_etypes, num_etypes, output_indptr.options());
    permutation =
        permutation.remainder(num_rows) + permutation.div(num_rows, "floor");
    // This permutation, when applied sorts the sampled edges with respect to
    // edge types.
    auto [output_in_degree, sliced_output_indptr] =
        SliceCSCIndptr(output_indptr, permutation);
    std::tie(output_indptr, picked_eids) = IndexSelectCSCImpl(
        output_in_degree, sliced_output_indptr, picked_eids, permutation,
        num_rows - 1, picked_eids.size(0));
    edge_offsets = torch::empty(
        num_etypes * 2, c10::TensorOptions()
                            .dtype(output_indptr.scalar_type())
                            .pinned_memory(true));
    at::cuda::CUDAEvent edge_offsets_event;
    AT_DISPATCH_INDEX_TYPES(
        indptr.scalar_type(), "SampleNeighborsEdgeOffsets", ([&] {
          THRUST_CALL(
              gather, indptr_offsets_ptr,
              indptr_offsets_ptr + indptr_offsets.size(0),
              output_indptr.data_ptr<index_t>(),
              edge_offsets->data_ptr<index_t>());
        }));
    edge_offsets_event.record();
    // The output_indices is permuted here.
    std::tie(output_indptr, output_indices) = IndexSelectCSCImpl(
        output_in_degree, sliced_output_indptr, output_indices, permutation,
        num_rows - 1, output_indices.size(0));
    std::vector<torch::Tensor> indptr_list;
    for (int i = 0; i < num_etypes; i++) {
      indptr_list.push_back(output_indptr.slice(
          0, indptr_offsets_ptr[2 * i],
          indptr_offsets_ptr[2 * i + 1] + (i == num_etypes - 1)));
    }
    // We form the final output indptr by concatenating pieces for different
    // edge types.
    output_indptr = torch::cat(indptr_list);
    edge_offsets_event.synchronize();
    // We read the edge_offsets here, they are in pairs but we don't need it to
    // be in pairs. So we remove the duplicate information from it and turn it
    // into a real offsets array.
    AT_DISPATCH_INDEX_TYPES(
        indptr.scalar_type(), "SampleNeighborsEdgeOffsetsCheck", ([&] {
          auto edge_offsets_ptr = edge_offsets->data_ptr<index_t>();
          TORCH_CHECK(edge_offsets_ptr[0] == 0, "edge_offsets is incorrect.");
          for (int i = 1; i < num_etypes; i++) {
            TORCH_CHECK(
                edge_offsets_ptr[2 * i - 1] == edge_offsets_ptr[2 * i],
                "edge_offsets is incorrect.");
          }
          TORCH_CHECK(
              edge_offsets_ptr[2 * num_etypes - 1] == picked_eids.size(0),
              "edge_offsets is incorrect.");
          for (int i = 0; i < num_etypes; i++) {
            edge_offsets_ptr[i + 1] = edge_offsets_ptr[2 * i + 1];
          }
        }));
    edge_offsets = edge_offsets->slice(0, 0, num_etypes + 1);
  } else {
    // Convert output_indptr back to homo by discarding intermediate offsets.
    output_indptr =
        output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());
    if (type_per_edge)
623
      output_type_per_edge = Gather(*type_per_edge, picked_eids);
624
  }
625
626
627
628
629

  torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
  if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);

  return c10::make_intrusive<sampling::FusedSampledSubgraph>(
630
631
      output_indptr, output_indices, seeds, torch::nullopt,
      subgraph_reverse_edge_ids, output_type_per_edge, edge_offsets);
632
633
634
635
}

}  //  namespace ops
}  //  namespace graphbolt