"src/vscode:/vscode.git/clone" did not exist on "f7cd6b87e1ee8c7909de760f22f1a6b0c6ae0592"
segment_reduce.cuh 9.35 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/segment_reduce.cuh
 * \brief Segment reduce kernel function header.
 */
#ifndef DGL_ARRAY_SEGMENT_REDUCE_CUH_
#define DGL_ARRAY_SEGMENT_REDUCE_CUH_

#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "./atomic.cuh"

namespace dgl {

using namespace cuda;

namespace aten {
namespace cuda {

/*!
 * \brief CUDA kernel of segment reduce.
 * \note Each thread block is responsible for aggregating one row of the
 *       result tensor; threads within the block cover the feature dimension.
 */
template <typename IdType, typename DType,
          typename ReduceOp>
__global__ void SegmentReduceKernel(
    const DType* feat, const IdType* offsets,
    DType* out, IdType* arg,
    int64_t n, int64_t dim) {
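  // blockIdx.x walks over rows (segments) and blockIdx.y * blockDim.x +
  // threadIdx.x walks over columns of the feature dimension; both loops are
  // grid-strided so a bounded launch covers an arbitrarily large output.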
  for (int row = blockIdx.x; row < n; row += gridDim.x) {
    int col = blockIdx.y * blockDim.x + threadIdx.x;
    while (col < dim) {
      DType local_accum = ReduceOp::zero();
      IdType local_arg = -1;
      for (IdType i = offsets[row]; i < offsets[row + 1]; ++i) {
        ReduceOp::Call(&local_accum, &local_arg, feat[i * dim + col], i);
      }
      out[row * dim + col] = local_accum;
      if (ReduceOp::require_arg)
        arg[row * dim + col] = local_arg;
      col += gridDim.y * blockDim.x;
    }
  }
}
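
// SegmentReduceKernel only assumes that ReduceOp exposes the members used
// above: a zero() identity, a Call(accum, arg, val, idx) update, and a
// require_arg flag. A minimal max-reducer sketch, for illustration only
// (DGL's actual reducer functors are defined elsewhere):
//
//   template <typename IdType, typename DType>
//   struct MaxExample {
//     static constexpr bool require_arg = true;
//     static __device__ __forceinline__ DType zero() {
//       return std::numeric_limits<DType>::lowest();
//     }
//     static __device__ __forceinline__ void Call(
//         DType* accum, IdType* arg, DType val, IdType idx) {
//       if (val > *accum) { *accum = val; *arg = idx; }
//     }
//   };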

/*!
 * \brief CUDA kernel of scatter add.
 * \note Each thread block is responsible for adding one row of the feature
 *       tensor to its target row in the output tensor.
 */
template <typename IdType, typename DType>
__global__ void ScatterAddKernel(
    const DType *feat, const IdType *idx, DType *out,
    int64_t n, int64_t dim) {
  for (int row = blockIdx.x; row < n; row += gridDim.x) {
    const int write_row = idx[row];
    int col = blockIdx.y * blockDim.x + threadIdx.x;
    while (col < dim) {
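      // Several source rows may scatter into the same write_row, so the
      // accumulation has to be atomic.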
      cuda::AtomicAdd(out + write_row * dim + col, feat[row * dim + col]);
      col += gridDim.y * blockDim.x;
    }
  }
}

/*!
 * \brief CUDA kernel to update gradients for the max/min reducer.
 * \note Each warp (group of 32 threads) is responsible for adding one row of
 *       the feature tensor to a target row in the output tensor.
 */
template <typename IdType, typename DType>
__global__ void UpdateGradMinMaxHeteroKernel(
    const DType *feat, const IdType *idx, const IdType *idx_type, DType *out,
    int64_t n, int64_t dim, int type) {
  unsigned int tId = threadIdx.x;
  unsigned int laneId = tId & 31;
  unsigned int gId = blockIdx.x * blockDim.x + threadIdx.x;
  unsigned int warpId = gId >> 5;
  unsigned int warp_size = 32;
  unsigned int row = warpId;

  while (row < n) {
    for (unsigned int col = laneId; col < dim; col += warp_size) {
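      // Only propagate gradient entries whose ArgMin/ArgMax came from this
      // node/edge type.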
      if (type == idx_type[row * dim + col]) {
        const int write_row = idx[row * dim + col];
        cuda::AtomicAdd(out + write_row * dim + col, feat[row * dim + col]);
      }
    }
    // Advance by the total number of warps in the grid (one warp per row).
    row += (blockDim.x * gridDim.x) / warp_size;
  }
}

/*!
 * \brief CUDA kernel of the backward phase of segment min/max.
 * \note Each thread block is responsible for writing one row of the result
 *       gradient tensor by looking up the ArgMin/ArgMax index information.
 */
template <typename IdType, typename DType>
__global__ void BackwardSegmentCmpKernel(
    const DType *feat, const IdType *arg, DType *out,
    int64_t n, int64_t dim) {
  for (int row = blockIdx.x; row < n; row += gridDim.x) {
    int col = blockIdx.y * blockDim.x + threadIdx.x;
    while (col < dim) {
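      // arg stores -1 for empty segments (see SegmentReduceKernel above),
      // so those entries are skipped.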
      int write_row = arg[row * dim + col];
      if (write_row >= 0) {
        out[write_row * dim + col] = feat[row * dim + col];
      }
      col += gridDim.y * blockDim.x;
    }
  }
}

/*!
 * \brief CUDA implementation of forward phase of Segment Reduce.
 * \param feat The input tensor.
 * \param offsets The offsets tensor.
 * \param out The output tensor.
 * \param arg An auxiliary tensor storing the ArgMax/Min information.
 */
template <typename IdType, typename DType, typename ReduceOp>
void SegmentReduce(
    NDArray feat,
    NDArray offsets,
    NDArray out,
    NDArray arg) {
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* offsets_data = offsets.Ptr<IdType>();
  DType* out_data = out.Ptr<DType>();
  IdType* arg_data = arg.Ptr<IdType>();

  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  int64_t n = out->shape[0];
  int64_t dim = 1;
  for (int i = 1; i < out->ndim; ++i)
    dim *= out->shape[i];

  const int nbx = FindNumBlocks<'x'>(n);
  const int ntx = FindNumThreads(dim);
  const int nby = FindNumBlocks<'y'>((dim + ntx - 1) / ntx);
  const int nty = 1;
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  // TODO(zihao): try cub's DeviceSegmentedReduce and compare the performance.
  CUDA_KERNEL_CALL((SegmentReduceKernel<IdType, DType, ReduceOp>),
      nblks, nthrs, 0, thr_entry->stream,
      feat_data, offsets_data, out_data, arg_data,
      n, dim);
}
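
// Note: when ReduceOp::require_arg is false, SegmentReduceKernel leaves arg
// untouched, so the auxiliary tensor is only meaningful for comparison
// (min/max) reducers.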

/*!
 * \brief CUDA implementation of Scatter Add (on first dimension).
 * \note math equation: out[idx[i], *] += feat[i, *]
 * \param feat The input tensor.
 * \param idx The indices tensor.
 * \param out The output tensor.
 */
template <typename IdType, typename DType>
void ScatterAdd(
    NDArray feat,
    NDArray idx,
    NDArray out) {
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* idx_data = idx.Ptr<IdType>();
  DType *out_data = out.Ptr<DType>();
  
  auto *thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  int64_t n = feat->shape[0];
  int64_t dim = 1;
  for (int i = 1; i < out->ndim; ++i)
    dim *= out->shape[i];

  const int nbx = FindNumBlocks<'x'>(n);
  const int ntx = FindNumThreads(dim);
  const int nby = FindNumBlocks<'y'>((dim + ntx - 1) / ntx);
  const int nty = 1;
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  CUDA_KERNEL_CALL((ScatterAddKernel<IdType, DType>),
                   nblks, nthrs, 0, thr_entry->stream,
                   feat_data, idx_data, out_data,
                   n, dim);
}
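
// ScatterAdd accumulates into whatever `out` already contains; callers are
// expected to pass an appropriately initialized (typically zeroed) output.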

/*!
 * \brief CUDA implementation to update gradients for the max/min reducer.
 * \param graph The input heterogeneous graph.
 * \param op The binary operator, could be `copy_lhs` or `copy_rhs`.
 * \param list_feat List of the input tensors.
 * \param list_idx  List of the indices tensors.
 * \param list_idx_etype List of the node- or edge-type tensors.
 * \param list_out List of the output tensors.
 */
template <typename IdType, typename DType>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& graph,
                const std::string& op,
                const std::vector<NDArray>& list_feat,
                const std::vector<NDArray>& list_idx,
                const std::vector<NDArray>& list_idx_types,
                std::vector<NDArray>* list_out) {
  if (op == "copy_lhs" || op == "copy_rhs") {
    std::vector<std::vector<dgl_id_t>> src_dst_ntypes(
        graph->NumVertexTypes(), std::vector<dgl_id_t>());
    for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
      auto pair = graph->meta_graph()->FindEdge(etype);
      const dgl_id_t dst_ntype = pair.first;  // graph is reversed
      const dgl_id_t src_ntype = pair.second;
      auto same_src_dst_ntype = std::find(std::begin(src_dst_ntypes[dst_ntype]),
        std::end(src_dst_ntypes[dst_ntype]), src_ntype);
      // if op is "copy_lhs", relation type with same src and dst node type will be updated once
      if (op == "copy_lhs" && same_src_dst_ntype != std::end(src_dst_ntypes[dst_ntype]))
        continue;
      src_dst_ntypes[dst_ntype].push_back(src_ntype);
      const DType* feat_data = list_feat[dst_ntype].Ptr<DType>();
      const IdType* idx_data = list_idx[dst_ntype].Ptr<IdType>();
      const IdType* idx_type_data = list_idx_types[dst_ntype].Ptr<IdType>();
      int type = (op == "copy_lhs") ? src_ntype : etype;
      DType* out_data = (*list_out)[type].Ptr<DType>();
      int dim = 1;
      for (int i = 1; i < (*list_out)[type]->ndim; ++i)
        dim *= (*list_out)[type]->shape[i];
      int n = list_feat[dst_ntype]->shape[0];
      auto *thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
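      // Launch one warp (th_per_row = 32 threads) per row of the feature
      // tensor.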
      const int th_per_row = 32;
      const int ntx = 128;
      const int nbx = FindNumBlocks<'x'>((n * th_per_row + ntx - 1) / ntx);
      const dim3 nblks(nbx);
      const dim3 nthrs(ntx);
      CUDA_KERNEL_CALL((UpdateGradMinMaxHeteroKernel<IdType, DType>),
                       nblks, nthrs, 0, thr_entry->stream,
                       feat_data, idx_data, idx_type_data,
                       out_data, n, dim, type);
    }
  }
}

/*!
 * \brief CUDA implementation of the backward phase of Segment Reduce with
 *        Min/Max reducer.
 * \note math equation: out[arg[i, k], k] = feat[i, k]
 * \param feat The input tensor.
 * \param arg The ArgMin/Max information, used for indexing.
 * \param out The output tensor.
 */
template <typename IdType, typename DType>
void BackwardSegmentCmp(
    NDArray feat,
    NDArray arg,
    NDArray out) {
  const DType* feat_data = feat.Ptr<DType>();
  const IdType* arg_data = arg.Ptr<IdType>();
  DType *out_data = out.Ptr<DType>();

  auto *thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  int64_t n = feat->shape[0];
  int64_t dim = 1;
  for (int i = 1; i < out->ndim; ++i)
    dim *= out->shape[i];

  const int nbx = FindNumBlocks<'x'>(n);
  const int ntx = FindNumThreads(dim);
  const int nby = FindNumBlocks<'y'>((dim + ntx - 1) / ntx);
  const int nty = 1;
  const dim3 nblks(nbx, nby);
  const dim3 nthrs(ntx, nty);
  CUDA_KERNEL_CALL((BackwardSegmentCmpKernel<IdType, DType>),
                   nblks, nthrs, 0, thr_entry->stream,
                   feat_data, arg_data, out_data,
                   n, dim);
}
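
// BackwardSegmentCmp routes each incoming gradient row to the position that
// was recorded in arg by the forward SegmentReduce with a min/max reducer.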

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_SEGMENT_REDUCE_CUH_