segment_reduce.cuh 9.21 KB
Newer Older
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/segment_reduce.cuh
 * @brief Segment reduce kernel function header.
 */
6
7
#ifndef DGL_ARRAY_CUDA_SEGMENT_REDUCE_CUH_
#define DGL_ARRAY_CUDA_SEGMENT_REDUCE_CUH_
8

9
10
#include <algorithm>
#include <string>
#include <vector>

#include "../../runtime/cuda/cuda_common.h"
#include "./atomic.cuh"
#include "./utils.h"
15
16
17
18

namespace dgl {

using namespace cuda;
19
using namespace runtime;
20
21
22
23

namespace aten {
namespace cuda {

24
/**
 * @brief CUDA kernel of segment reduce.
 * @note Every thread produces one (row, col) element of the result at a time.
 *       Rows are walked starting at blockIdx.x in steps of gridDim.x; columns
 *       are walked starting at blockIdx.y * blockDim.x + threadIdx.x in steps
 *       of gridDim.y * blockDim.x.
 * @param feat Input features, read at feat[i * dim + col].
 * @param offsets Segment boundaries: output row r folds the input rows in
 *        [offsets[r], offsets[r + 1]).
 * @param out Result tensor, written at out[row * dim + col].
 * @param arg Auxiliary output receiving the index value maintained by
 *        ReduceOp::Call (starts at -1); only written when
 *        ReduceOp::require_arg is true.
 * @param n Number of output rows.
 * @param dim Number of elements per row.
 */
template <typename IdType, typename DType, typename ReduceOp>
__global__ void SegmentReduceKernel(
    const DType* feat, const IdType* offsets, DType* out, IdType* arg,
    int64_t n, int64_t dim) {
  using AccType = typename accum_dtype<DType>::type;
  for (int row = blockIdx.x; row < n; row += gridDim.x) {
    for (int col = blockIdx.y * blockDim.x + threadIdx.x; col < dim;
         col += gridDim.y * blockDim.x) {
      // Fold every input row of this segment into a private accumulator.
      AccType acc = ReduceOp::zero();
      IdType picked = -1;
      const IdType seg_end = offsets[row + 1];
      for (IdType i = offsets[row]; i < seg_end; ++i) {
        ReduceOp::Call(&acc, &picked, feat[i * dim + col], i);
      }
      // row * dim is evaluated in 64 bits because dim is int64_t.
      const int64_t out_pos = row * dim + col;
      out[out_pos] = static_cast<DType>(acc);
      if (ReduceOp::require_arg) arg[out_pos] = picked;
    }
  }
}

48
/**
 * @brief CUDA kernel of scatter add: out[idx[row], col] += feat[row, col].
 * @note Rows are walked starting at blockIdx.x in steps of gridDim.x; within
 *       a row, columns are walked starting at
 *       blockIdx.y * blockDim.x + threadIdx.x in steps of
 *       gridDim.y * blockDim.x.
 * @param feat Input features, read at feat[row * dim + col].
 * @param idx Destination row in out for every input row.
 * @param out Output tensor, updated with cuda::AtomicAdd (presumably because
 *        idx may map several input rows to the same output row).
 * @param n Number of input rows.
 * @param dim Number of elements per row.
 */
template <typename IdType, typename DType>
__global__ void ScatterAddKernel(
    const DType* feat, const IdType* idx, DType* out, int64_t n, int64_t dim) {
  // All row/column counters are 64-bit: n and dim are int64_t, and an int
  // counter would wrap once row + gridDim.x exceeds INT_MAX.
  const int64_t col_begin =
      static_cast<int64_t>(blockIdx.y) * blockDim.x + threadIdx.x;
  const int64_t col_step = static_cast<int64_t>(gridDim.y) * blockDim.x;
  for (int64_t row = blockIdx.x; row < n; row += gridDim.x) {
    // Keep the destination row in 64 bits; IdType may be a 64-bit integer and
    // storing it in an int would silently truncate large indices.
    const int64_t write_row = idx[row];
    const DType* src = feat + row * dim;
    DType* dst = out + write_row * dim;
    for (int64_t col = col_begin; col < dim; col += col_step) {
      cuda::AtomicAdd(dst + col, src[col]);
    }
  }
}

66
/**
 * @brief CUDA kernel to update gradients for reduce op max/min.
 * @note Each warp (group of 32 threads) handles one row of feat at a time and
 *       then advances by the number of warps in the grid. Lane l of the warp
 *       handles columns l, l + 32, l + 64, ... The lane id is derived from
 *       threadIdx.x & 31, so blockDim.x is expected to be a multiple of 32.
 * @param feat Input features, read at feat[row * dim + col].
 * @param idx Per-element destination row in out.
 * @param idx_type Per-element type id; an element is only processed when its
 *        entry equals @p type.
 * @param out Output tensor, updated with cuda::AtomicAdd at
 *        out[idx[row, col] * dim + col].
 * @param n Number of rows in feat.
 * @param dim Number of elements per row.
 * @param type The type id selected by this launch.
 */
template <typename IdType, typename DType>
__global__ void UpdateGradMinMaxHeteroKernel(
    const DType* feat, const IdType* idx, const IdType* idx_type, DType* out,
    int64_t n, int64_t dim, int type) {
  const int64_t warp_size = 32;
  const int64_t lane_id = threadIdx.x & 31u;
  // 64-bit arithmetic: blockIdx.x * blockDim.x can exceed 32 bits on large
  // grids.
  const int64_t global_tid =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t total_threads = static_cast<int64_t>(gridDim.x) * blockDim.x;
  // Rows are assigned per warp, so the stride must be counted in warps, not
  // in threads; striding by the thread count skips rows whenever the grid
  // holds fewer warps than there are rows. Rounded up so it is never zero.
  const int64_t num_warps = (total_threads + warp_size - 1) / warp_size;

  for (int64_t row = global_tid / warp_size; row < n; row += num_warps) {
    const int64_t row_base = row * dim;
    for (int64_t col = lane_id; col < dim; col += warp_size) {
      if (type == idx_type[row_base + col]) {
        // Keep the destination row in 64 bits; IdType may be 64-bit.
        const int64_t write_row = idx[row_base + col];
        cuda::AtomicAdd(out + write_row * dim + col, feat[row_base + col]);
      }
    }
  }
}

94
/**
 * @brief CUDA kernel of backward phase in segment min/max.
 * @note Each thread writes one element of the result gradient tensor at a
 *       time, looking up the ArgMin/Max tensor for the destination row:
 *       out[arg[row, col], col] = feat[row, col]. Rows are walked starting at
 *       blockIdx.x in steps of gridDim.x; columns starting at
 *       blockIdx.y * blockDim.x + threadIdx.x in steps of
 *       gridDim.y * blockDim.x.
 * @param feat Input tensor, read at feat[row * dim + col].
 * @param arg ArgMin/Max indices; negative entries are skipped.
 * @param out Output tensor.
 * @param n Number of rows in feat.
 * @param dim Number of elements per row.
 */
template <typename IdType, typename DType>
__global__ void BackwardSegmentCmpKernel(
    const DType* feat, const IdType* arg, DType* out, int64_t n, int64_t dim) {
  // 64-bit counters: n and dim are int64_t, and an int row counter would wrap
  // once row + gridDim.x exceeds INT_MAX.
  const int64_t col_begin =
      static_cast<int64_t>(blockIdx.y) * blockDim.x + threadIdx.x;
  const int64_t col_step = static_cast<int64_t>(gridDim.y) * blockDim.x;
  for (int64_t row = blockIdx.x; row < n; row += gridDim.x) {
    const int64_t row_base = row * dim;
    for (int64_t col = col_begin; col < dim; col += col_step) {
      // Keep the looked-up row in 64 bits; IdType may be a 64-bit integer and
      // an int would silently truncate large indices.
      const int64_t write_row = arg[row_base + col];
      if (write_row >= 0) {
        out[write_row * dim + col] = feat[row_base + col];
      }
    }
  }
}

114
/**
 * @brief CUDA implementation of forward phase of Segment Reduce.
 * @param feat The input tensor.
 * @param offsets The offsets tensor.
 * @param out The output tensor; its shape determines the launch geometry.
 * @param arg An auxiliary tensor storing ArgMax/Min information.
 */
template <typename IdType, typename DType, typename ReduceOp>
void SegmentReduce(NDArray feat, NDArray offsets, NDArray out, NDArray arg) {
  const DType* src = feat.Ptr<DType>();
  const IdType* seg_offsets = offsets.Ptr<IdType>();
  DType* dst = out.Ptr<DType>();
  IdType* dst_arg = arg.Ptr<IdType>();

  cudaStream_t stream = runtime::getCurrentCUDAStream();

  // Number of output rows and flattened size of all trailing dimensions.
  int64_t num_rows = out->shape[0];
  int64_t row_len = 1;
  for (int axis = 1; axis < out->ndim; ++axis) row_len *= out->shape[axis];

  // Grid x is sized from the rows, grid y from the column tiles of one row.
  const int nbx = FindNumBlocks<'x'>(num_rows);
  const int ntx = FindNumThreads(row_len);
  const int nby = FindNumBlocks<'y'>((row_len + ntx - 1) / ntx);
  const dim3 grid(nbx, nby);
  const dim3 block(ntx, 1);
  // TODO(zihao): try cub's DeviceSegmentedReduce and compare the performance.
  CUDA_KERNEL_CALL(
      (SegmentReduceKernel<IdType, DType, ReduceOp>), grid, block, 0, stream,
      src, seg_offsets, dst, dst_arg, num_rows, row_len);
}

145
/**
 * @brief CUDA implementation of Scatter Add (on first dimension).
 * @note math equation: out[idx[i], *] += feat[i, *]
 * @param feat The input tensor; its first dimension gives the row count.
 * @param idx The indices tensor.
 * @param out The output tensor; its trailing dimensions give the row length.
 */
template <typename IdType, typename DType>
void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
  const DType* src = feat.Ptr<DType>();
  const IdType* dst_rows = idx.Ptr<IdType>();
  DType* dst = out.Ptr<DType>();

  cudaStream_t stream = runtime::getCurrentCUDAStream();

  // Rows come from the input, the flattened row length from the output.
  int64_t num_rows = feat->shape[0];
  int64_t row_len = 1;
  for (int axis = 1; axis < out->ndim; ++axis) row_len *= out->shape[axis];

  // Grid x is sized from the rows, grid y from the column tiles of one row.
  const int nbx = FindNumBlocks<'x'>(num_rows);
  const int ntx = FindNumThreads(row_len);
  const int nby = FindNumBlocks<'y'>((row_len + ntx - 1) / ntx);
  const dim3 grid(nbx, nby);
  const dim3 block(ntx, 1);
  CUDA_KERNEL_CALL(
      (ScatterAddKernel<IdType, DType>), grid, block, 0, stream, src,
      dst_rows, dst, num_rows, row_len);
}

174
/**
 * @brief CUDA implementation to update gradients for reduce op max/min
 * @param graph The input heterogeneous graph.
 * @param op The binary operator. Only "copy_lhs" and "copy_rhs" are handled;
 *        any other value makes this function a no-op.
 * @param list_feat List of the input tensors, indexed by destination node
 *        type.
 * @param list_idx  List of the indices tensors, indexed by destination node
 *        type.
 * @param list_idx_types List of the node- or edge-type tensors, indexed by
 *        destination node type.
 * @param list_out List of the output tensors, indexed by source node type for
 *        "copy_lhs" and by edge type for "copy_rhs".
 */
template <typename IdType, typename DType>
void UpdateGradMinMax_hetero(
    const HeteroGraphPtr& graph, const std::string& op,
    const std::vector<NDArray>& list_feat, const std::vector<NDArray>& list_idx,
    const std::vector<NDArray>& list_idx_types,
    std::vector<NDArray>* list_out) {
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  // Evaluate the string comparison once instead of on every edge type.
  const bool is_copy_lhs = (op == "copy_lhs");
  if (!is_copy_lhs && op != "copy_rhs") return;

  // For every destination node type, the source node types already handled.
  std::vector<std::vector<dgl_id_t>> src_dst_ntypes(
      graph->NumVertexTypes(), std::vector<dgl_id_t>());
  for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
    auto pair = graph->meta_graph()->FindEdge(etype);
    const dgl_id_t dst_ntype = pair.first;  // graph is reversed
    const dgl_id_t src_ntype = pair.second;
    auto& seen_src = src_dst_ntypes[dst_ntype];
    const bool already_seen =
        std::find(std::begin(seen_src), std::end(seen_src), src_ntype) !=
        std::end(seen_src);
    // if op is "copy_lhs", relation type with same src and dst node type will
    // be updated once
    if (is_copy_lhs && already_seen) continue;
    seen_src.push_back(src_ntype);

    const DType* feat_data = list_feat[dst_ntype].Ptr<DType>();
    const IdType* idx_data = list_idx[dst_ntype].Ptr<IdType>();
    const IdType* idx_type_data = list_idx_types[dst_ntype].Ptr<IdType>();
    const int type = static_cast<int>(is_copy_lhs ? src_ntype : etype);
    const NDArray& out_arr = (*list_out)[type];
    DType* out_data = out_arr.Ptr<DType>();

    // Sizes stay 64-bit: the kernel takes int64_t, and with int both the
    // shape product and n * th_per_row (n >= 2^26) can overflow.
    int64_t dim = 1;
    for (int i = 1; i < out_arr->ndim; ++i) dim *= out_arr->shape[i];
    const int64_t n = list_feat[dst_ntype]->shape[0];

    // One warp (32 threads) per row, 128 threads per block.
    const int64_t th_per_row = 32;
    const int ntx = 128;
    const int nbx = FindNumBlocks<'x'>((n * th_per_row + ntx - 1) / ntx);
    const dim3 nblks(nbx);
    const dim3 nthrs(ntx);
    CUDA_KERNEL_CALL(
        (UpdateGradMinMaxHeteroKernel<IdType, DType>), nblks, nthrs, 0,
        stream, feat_data, idx_data, idx_type_data, out_data, n, dim, type);
  }
}

227
/**
 * @brief CUDA implementation of backward phase of Segment Reduce with Min/Max
 *        reducer.
 * @note math equation: out[arg[i, k], k] = feat[i, k]
 * @param feat The input tensor; its first dimension gives the row count.
 * @param arg The ArgMin/Max information, used for indexing.
 * @param out The output tensor; its trailing dimensions give the row length.
 */
template <typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
  const DType* src = feat.Ptr<DType>();
  const IdType* picked_rows = arg.Ptr<IdType>();
  DType* dst = out.Ptr<DType>();

  cudaStream_t stream = runtime::getCurrentCUDAStream();

  // Rows come from the input, the flattened row length from the output.
  int64_t num_rows = feat->shape[0];
  int64_t row_len = 1;
  for (int axis = 1; axis < out->ndim; ++axis) row_len *= out->shape[axis];

  // Grid x is sized from the rows, grid y from the column tiles of one row.
  const int nbx = FindNumBlocks<'x'>(num_rows);
  const int ntx = FindNumThreads(row_len);
  const int nby = FindNumBlocks<'y'>((row_len + ntx - 1) / ntx);
  const dim3 grid(nbx, nby);
  const dim3 block(ntx, 1);
  CUDA_KERNEL_CALL(
      (BackwardSegmentCmpKernel<IdType, DType>), grid, block, 0, stream,
      src, picked_rows, dst, num_rows, row_len);
}

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

262
#endif  // DGL_ARRAY_CUDA_SEGMENT_REDUCE_CUH_