// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/segment_reduce.cuh
 * @brief Segment reduce kernel function header.
 */
#ifndef DGL_ARRAY_CUDA_SEGMENT_REDUCE_CUH_
#define DGL_ARRAY_CUDA_SEGMENT_REDUCE_CUH_

#include <string>
#include <vector>

#include "../../runtime/cuda/cuda_common.h"
#include "atomic.cuh"
#include "utils.h"

namespace dgl {

using namespace cuda;
using namespace runtime;

namespace aten {
namespace cuda {

/**
 * @brief CUDA kernel of segment reduce.
 * @note Each thread block is responsible for aggregation on a row
 *       in the result tensor.
 */
template <typename IdType, typename DType, typename ReduceOp>
__global__ void SegmentReduceKernel(
    const DType* feat, const IdType* offsets, DType* out, IdType* arg,
    int64_t n, int64_t dim) {
  for (int row = blockIdx.x; row < n; row += gridDim.x) {
    int col = blockIdx.y * blockDim.x + threadIdx.x;
    while (col < dim) {
      typename accum_dtype<DType>::type local_accum = ReduceOp::zero();
      IdType local_arg = -1;
      for (IdType i = offsets[row]; i < offsets[row + 1]; ++i) {
        ReduceOp::Call(&local_accum, &local_arg, feat[i * dim + col], i);
      }
      out[row * dim + col] = static_cast<DType>(local_accum);
      if (ReduceOp::require_arg) arg[row * dim + col] = local_arg;
      col += gridDim.y * blockDim.x;
    }
  }
}
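
/*
 * Illustration (not part of the upstream file): `offsets` holds CSR-style
 * segment boundaries, e.g. offsets = {0, 2, 5} makes out[0] the reduction of
 * feat rows [0, 2) and out[1] the reduction of feat rows [2, 5).  The
 * ReduceOp parameter only needs the interface used by the kernel above:
 * a static zero(), a static Call(accum, arg, value, index), and a
 * require_arg flag.  A minimal sum-reducer sketch under that assumption
 * (DGL's real reducer functors are defined elsewhere and may differ):
 *
 *   template <typename IdType, typename DType>
 *   struct SumReducerSketch {
 *     static constexpr bool require_arg = false;
 *     static __device__ __forceinline__ DType zero() { return 0; }
 *     static __device__ __forceinline__ void Call(
 *         typename accum_dtype<DType>::type* acc, IdType* arg, DType val,
 *         IdType idx) {
 *       *acc += val;  // arg stays -1 because require_arg is false
 *     }
 *   };
 */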

/**
 * @brief CUDA kernel of scatter add.
 * @note Each thread block is responsible for adding a row of the feature
 *       tensor to a target row of the output tensor.
 */
template <typename IdType, typename DType>
__global__ void ScatterAddKernel(
    const DType* feat, const IdType* idx, DType* out, int64_t n, int64_t dim) {
  for (int row = blockIdx.x; row < n; row += gridDim.x) {
    const int write_row = idx[row];
    int col = blockIdx.y * blockDim.x + threadIdx.x;
    while (col < dim) {
      cuda::AtomicAdd(out + write_row * dim + col, feat[row * dim + col]);
      col += gridDim.y * blockDim.x;
    }
  }
}
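
/*
 * Illustration (not part of the upstream file): with idx = {0, 2, 0}, feat
 * rows 0 and 2 both accumulate into out row 0, which is why the kernel above
 * uses cuda::AtomicAdd.  A serial host-side sketch of the same update:
 *
 *   for (int64_t row = 0; row < n; ++row)
 *     for (int64_t col = 0; col < dim; ++col)
 *       out[idx[row] * dim + col] += feat[row * dim + col];
 */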

/**
 * @brief CUDA kernel to update gradients for reduce op max/min.
 * @note Each WARP (group of 32 threads) is responsible for adding a row of
 *       the feature tensor to a target row of the output tensor.
 */

template <typename IdType, typename DType>
__global__ void UpdateGradMinMaxHeteroKernel(
    const DType* feat, const IdType* idx, const IdType* idx_type, DType* out,
    int64_t n, int64_t dim, int type) {
  unsigned int tId = threadIdx.x;
  unsigned int laneId = tId & 31;
  unsigned int gId = blockIdx.x * blockDim.x + threadIdx.x;
  unsigned int warpId = gId >> 5;
  unsigned int warp_size = 32;
  unsigned int row = warpId;

  while (row < n) {
    for (unsigned int col = laneId; col < dim; col += warp_size) {
      if (type == idx_type[row * dim + col]) {
        const int write_row = idx[row * dim + col];
        cuda::AtomicAdd(out + write_row * dim + col, feat[row * dim + col]);
      }
    }
    row += blockDim.x * gridDim.x;
  }
}
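
/*
 * Illustration (not part of the upstream file): gId >> 5 groups every 32
 * consecutive threads into a warp, so each warp owns one row at a time and
 * its 32 lanes stride over the columns; only entries whose idx_type matches
 * `type` are scattered.  A serial host-side sketch of the same update:
 *
 *   for (int64_t row = 0; row < n; ++row)
 *     for (int64_t col = 0; col < dim; ++col)
 *       if (idx_type[row * dim + col] == type)
 *         out[idx[row * dim + col] * dim + col] += feat[row * dim + col];
 */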

/**
 * @brief CUDA kernel of backward phase in segment min/max.
 * @note Each thread block is responsible for writing a row of the result
 *       gradient tensor by looking up the ArgMin/Max index information.
 */
template <typename IdType, typename DType>
__global__ void BackwardSegmentCmpKernel(
    const DType* feat, const IdType* arg, DType* out, int64_t n, int64_t dim) {
  for (int row = blockIdx.x; row < n; row += gridDim.x) {
    int col = blockIdx.y * blockDim.x + threadIdx.x;
    while (col < dim) {
      int write_row = arg[row * dim + col];
      if (write_row >= 0) {
        out[write_row * dim + col] = feat[row * dim + col];
      }
      col += gridDim.y * blockDim.x;
    }
  }
}
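
/*
 * Illustration (not part of the upstream file): arg[row * dim + col] is the
 * feat row picked by the forward ArgMin/Max for segment `row` and column
 * `col`, or -1 for an empty segment.  For example, arg[0 * dim + 3] == 7
 * routes the incoming gradient feat[0 * dim + 3] to out[7 * dim + 3].
 * A serial host-side sketch of the same scatter:
 *
 *   for (int64_t row = 0; row < n; ++row)
 *     for (int64_t col = 0; col < dim; ++col) {
 *       const IdType w = arg[row * dim + col];
 *       if (w >= 0) out[w * dim + col] = feat[row * dim + col];
 *     }
 */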

/**
 * @brief CUDA implementation of forward phase of Segment Reduce.
 * @param feat The input tensor.
 * @param offsets The offsets tensor.
 * @param out The output tensor.
 * @param arg An auxiliary tensor storing ArgMax/Min information.
 */
template <typename IdType, typename DType, typename ReduceOp>
void SegmentReduce(NDArray feat, NDArray offsets, NDArray out, NDArray arg) {
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* offsets_data = offsets.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
  IdType* arg_data = arg.Ptr<IdType>();

  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  int64_t n = out->shape[0];
  int64_t dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];

  const int nbx = FindNumBlocks<'x'>(n);
  const int ntx = FindNumThreads(dim);
  const int nby = FindNumBlocks<'y'>((dim + ntx - 1) / ntx);
  const int nty = 1;
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  // TODO(zihao): try cub's DeviceSegmentedReduce and compare the performance.
  CUDA_KERNEL_CALL(
      (SegmentReduceKernel<IdType, DType, ReduceOp>), nblks, nthrs, 0, stream,
      feat_data, offsets_data, out_data, arg_data, n, dim);
}
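
/*
 * Launch-shape note (illustrative): the grid is two dimensional; blockIdx.x
 * strides over the n output rows while (blockIdx.y, threadIdx.x) tiles the
 * dim columns.  A hypothetical instantiation with the reducer sketch shown
 * after SegmentReduceKernel (the real reducer functors and the dtype/idtype
 * dispatch live in the corresponding .cu file):
 *
 *   SegmentReduce<int64_t, float, SumReducerSketch<int64_t, float>>(
 *       feat, offsets, out, arg);
 */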

/**
 * @brief CUDA implementation of Scatter Add (on first dimension).
 * @note math equation: out[idx[i], *] += feat[i, *]
 * @param feat The input tensor.
 * @param idx The indices tensor.
 * @param out The output tensor.
 */
template <typename IdType, typename DType>
void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* idx_data = idx.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();

  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  int64_t n = feat->shape[0];
  int64_t dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];

  const int nbx = FindNumBlocks<'x'>(n);
  const int ntx = FindNumThreads(dim);
  const int nby = FindNumBlocks<'y'>((dim + ntx - 1) / ntx);
  const int nty = 1;
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  CUDA_KERNEL_CALL(
      (ScatterAddKernel<IdType, DType>), nblks, nthrs, 0, stream, feat_data,
      idx_data, out_data, n, dim);
}
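
/*
 * Shape note (illustrative): idx has length n = feat->shape[0], out needs at
 * least max(idx) + 1 rows, and the trailing feature dimensions of feat and
 * out must match.  A hypothetical instantiation:
 *
 *   ScatterAdd<int64_t, float>(feat, idx, out);
 */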

/**
 * @brief CUDA implementation to update gradients for reduce op max/min.
 * @param graph The input heterogeneous graph.
 * @param op The binary operator, could be `copy_u` or `copy_e`.
 * @param list_feat List of the input tensors.
 * @param list_idx  List of the indices tensors.
 * @param list_idx_etype List of the node- or edge-type tensors.
 * @param list_out List of the output tensors.
 */
template <typename IdType, typename DType>
void UpdateGradMinMax_hetero(
    const HeteroGraphPtr& graph, const std::string& op,
    const std::vector<NDArray>& list_feat, const std::vector<NDArray>& list_idx,
    const std::vector<NDArray>& list_idx_types,
    std::vector<NDArray>* list_out) {
  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  if (op == "copy_lhs" || op == "copy_rhs") {
    std::vector<std::vector<dgl_id_t>> src_dst_ntypes(
        graph->NumVertexTypes(), std::vector<dgl_id_t>());
    for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
      auto pair = graph->meta_graph()->FindEdge(etype);
      const dgl_id_t dst_ntype = pair.first;  // graph is reversed
      const dgl_id_t src_ntype = pair.second;
      auto same_src_dst_ntype = std::find(
          std::begin(src_dst_ntypes[dst_ntype]),
          std::end(src_dst_ntypes[dst_ntype]), src_ntype);
      // If op is "copy_lhs", relation types with the same src and dst node
      // type are updated only once.
      if (op == "copy_lhs" &&
          same_src_dst_ntype != std::end(src_dst_ntypes[dst_ntype]))
        continue;
      src_dst_ntypes[dst_ntype].push_back(src_ntype);
      const DType* feat_data = list_feat[dst_ntype].Ptr<DType>();
      const IdType* idx_data = list_idx[dst_ntype].Ptr<IdType>();
      const IdType* idx_type_data = list_idx_types[dst_ntype].Ptr<IdType>();
      int type = (op == "copy_lhs") ? src_ntype : etype;
      DType* out_data = (*list_out)[type].Ptr<DType>();
      int dim = 1;
      for (int i = 1; i < (*list_out)[type]->ndim; ++i)
        dim *= (*list_out)[type]->shape[i];
      int n = list_feat[dst_ntype]->shape[0];
      const int th_per_row = 32;
      const int ntx = 128;
      const int nbx = FindNumBlocks<'x'>((n * th_per_row + ntx - 1) / ntx);
      const dim3 nblks(nbx);
      const dim3 nthrs(ntx);
      CUDA_KERNEL_CALL(
          (UpdateGradMinMaxHeteroKernel<IdType, DType>), nblks, nthrs, 0,
          stream, feat_data, idx_data, idx_type_data, out_data, n, dim, type);
    }
  }
}
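
/*
 * Worked example (illustrative): in a reversed heterograph where two edge
 * types share the same (src, dst) node-type pair, op == "copy_lhs" scatters
 * the gradient into (*list_out)[src_ntype] only once; the second relation is
 * skipped by the src_dst_ntypes bookkeeping above.  With op == "copy_rhs"
 * every edge type is processed and the output is indexed by etype instead.
 */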

/**
 * @brief CUDA implementation of backward phase of Segment Reduce with Min/Max
 *        reducer.
 * @note math equation: out[arg[i, k], k] = feat[i, k]
 * @param feat The input tensor.
 * @param arg The ArgMin/Max information, used for indexing.
 * @param out The output tensor.
 */
template <typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* arg_data = arg.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();

  hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
  int64_t n = feat->shape[0];
  int64_t dim = 1;
  for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];

  const int nbx = FindNumBlocks<'x'>(n);
  const int ntx = FindNumThreads(dim);
  const int nby = FindNumBlocks<'y'>((dim + ntx - 1) / ntx);
  const int nty = 1;
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  CUDA_KERNEL_CALL(
      (BackwardSegmentCmpKernel<IdType, DType>), nblks, nthrs, 0, stream,
      feat_data, arg_data, out_data, n, dim);
}
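
/*
 * A hypothetical use (illustrative; the argument names are not from this
 * file): given that the forward pass filled `arg` with ArgMax indices,
 *
 *   BackwardSegmentCmp<int64_t, float>(grad_out, arg, grad_feat);
 *
 * writes each segment's incoming gradient to the feat position that won the
 * comparison; positions whose arg entry is -1 (empty segments) are skipped.
 */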

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_SEGMENT_REDUCE_CUH_