/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/sampling_utils.cu
 * @brief Sampling utility function implementations on CUDA.
 */
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>

#include <cub/cub.cuh>

#include "./common.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

// For each index i, looks up node rows[i] in the CSC indptr array and
// records both where its edges begin and how many it has:
//   inrow_indptr[i] = indptr[rows[i]]
//   in_degree[i]    = indptr[rows[i] + 1] - indptr[rows[i]]
template <typename indptr_t, typename nodes_t>
struct SliceFunc {
  const nodes_t* rows;       // node ids to slice
  const indptr_t* indptr;    // CSC indptr of the full graph
  indptr_t* in_degree;       // output: degree of each sliced node
  indptr_t* inrow_indptr;    // output: edge-start offset of each sliced node
  __host__ __device__ auto operator()(int64_t i) {
    const auto row = rows[i];
    // Each of the two indptr reads happens exactly once per element.
    const auto begin = indptr[row];
    const auto end = indptr[row + 1];
    inrow_indptr[i] = begin;
    in_degree[i] = end - begin;
  }
};

// Returns (indptr[nodes + 1] - indptr[nodes], indptr[nodes])
std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
    torch::Tensor indptr, torch::optional<torch::Tensor> nodes_optional) {
  if (!nodes_optional.has_value()) {
    // No node subset requested: slice over all nodes. The degrees are the
    // adjacent differences of indptr, computed in one CUB pass.
    const int64_t num_rows = indptr.size(0) - 1;
    auto sliced_indptr = indptr.slice(0, 0, num_rows);
    // Over-allocate by one so that after dropping the leading element of
    // SubtractLeftCopy's output below, num_rows + 1 entries remain.
    auto in_degree = torch::empty(
        num_rows + 2, indptr.options().dtype(indptr.scalar_type()));
    AT_DISPATCH_INTEGRAL_TYPES(
        indptr.scalar_type(), "IndexSelectCSCIndptr", ([&] {
          using indptr_t = scalar_t;
          CUB_CALL(
              DeviceAdjacentDifference::SubtractLeftCopy,
              indptr.data_ptr<indptr_t>(), in_degree.data_ptr<indptr_t>(),
              num_rows + 1, cub::Difference{});
        }));
    // SubtractLeftCopy copies its first input element through unchanged
    // rather than producing a difference; discard it.
    in_degree = in_degree.slice(0, 1);
    return {in_degree, sliced_indptr};
  }
  auto nodes = nodes_optional.value();
  const int64_t num_rows = nodes.size(0);
  auto sliced_indptr =
      torch::empty(num_rows, nodes.options().dtype(indptr.scalar_type()));
  // One extra slot beyond num_rows; the functor fills only the first
  // num_rows entries.
  auto in_degree = torch::empty(
      num_rows + 1, nodes.options().dtype(indptr.scalar_type()));
  thrust::counting_iterator<int64_t> first(0);
  AT_DISPATCH_INTEGRAL_TYPES(
      indptr.scalar_type(), "IndexSelectCSCIndptr", ([&] {
        using indptr_t = scalar_t;
        AT_DISPATCH_INDEX_TYPES(
            nodes.scalar_type(), "IndexSelectCSCNodes", ([&] {
              using nodes_t = index_t;
              // Gather indptr[nodes] and the per-node degrees in a single
              // pass; indptr is read only once per node in case it lives in
              // pinned host memory where access is slow.
              THRUST_CALL(
                  for_each, first, first + num_rows,
                  SliceFunc<indptr_t, nodes_t>{
                      nodes.data_ptr<nodes_t>(), indptr.data_ptr<indptr_t>(),
                      in_degree.data_ptr<indptr_t>(),
                      sliced_indptr.data_ptr<indptr_t>()});
            }));
      }));
  return {in_degree, sliced_indptr};
}

// For each (node, edge-type) pair, finds where edges of that type start by
// a lower-bound search over the node's edge-type array (per-node etypes are
// assumed sorted -- required for the lower-bound search), refining the
// homogeneous indptr arrays to per-etype granularity.
template <typename indptr_t, typename etype_t>
struct EdgeTypeSearch {
  const indptr_t* sub_indptr;     // indptr over the sampled nodes
  const indptr_t* sliced_indptr;  // edge-start offsets in the original graph
  const etype_t* etypes;          // per-edge type ids
  int64_t num_fanouts;            // number of edge types being split out
  int64_t num_rows;               // total (node, etype) pairs
  indptr_t* new_sub_indptr;       // output: size num_rows + 1
  indptr_t* new_sliced_indptr;    // output: size num_rows
  __host__ __device__ auto operator()(int64_t i) {
    const auto node = i / num_fanouts;
    const etype_t target_type = i % num_fanouts;
    const auto begin = sub_indptr[node];
    const auto end = sub_indptr[node + 1];
    // Offset of the first edge whose type is >= target_type.
    auto offset = cuda::LowerBound(etypes + begin, end - begin, target_type);
    new_sub_indptr[i] = begin + offset;
    new_sliced_indptr[i] = sliced_indptr[node] + offset;
    // The thread handling the final pair also writes the closing sentinel.
    if (i == num_rows - 1) new_sub_indptr[num_rows] = end;
  }
};

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> SliceCSCIndptrHetero(
    torch::Tensor sub_indptr, torch::Tensor etypes, torch::Tensor sliced_indptr,
    int64_t num_fanouts) {
  // Each homogeneous row expands into num_fanouts heterogeneous rows.
  auto num_rows = (sub_indptr.size(0) - 1) * num_fanouts;
  auto new_sub_indptr = torch::empty(num_rows + 1, sub_indptr.options());
  // One extra slot so the slice(0, 1) below still leaves num_rows + 1.
  auto new_indegree = torch::empty(num_rows + 2, sub_indptr.options());
  auto new_sliced_indptr = torch::empty(num_rows, sliced_indptr.options());
  thrust::counting_iterator<int64_t> first(0);
  AT_DISPATCH_INTEGRAL_TYPES(
      sub_indptr.scalar_type(), "SliceCSCIndptrHeteroIndptr", ([&] {
        using indptr_t = scalar_t;
        AT_DISPATCH_INTEGRAL_TYPES(
            etypes.scalar_type(), "SliceCSCIndptrHeteroTypePerEdge", ([&] {
              using etype_t = scalar_t;
              // Locate each (node, etype) boundary via per-thread search.
              THRUST_CALL(
                  for_each, first, first + num_rows,
                  EdgeTypeSearch<indptr_t, etype_t>{
                      sub_indptr.data_ptr<indptr_t>(),
                      sliced_indptr.data_ptr<indptr_t>(),
                      etypes.data_ptr<etype_t>(), num_fanouts, num_rows,
                      new_sub_indptr.data_ptr<indptr_t>(),
                      new_sliced_indptr.data_ptr<indptr_t>()});
            }));
        // Degrees are the adjacent differences of the refined indptr.
        CUB_CALL(
            DeviceAdjacentDifference::SubtractLeftCopy,
            new_sub_indptr.data_ptr<indptr_t>(),
            new_indegree.data_ptr<indptr_t>(), num_rows + 1, cub::Difference{});
      }));
  // Discard the first element of the SubtractLeftCopy result and ensure that
  // new_indegree tensor has size num_rows + 1 so that its ExclusiveCumSum is
  // directly equivalent to new_sub_indptr.
  // Equivalent to new_indegree = new_indegree[1:] in Python.
  new_indegree = new_indegree.slice(0, 1);
  return {new_sub_indptr, new_indegree, new_sliced_indptr};
}

}  //  namespace ops
}  //  namespace graphbolt