/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/functor.cuh
 * @brief Binary-operator and reduction functors used as template arguments
 *        by the CUDA kernels.
 */
#ifndef DGL_ARRAY_CUDA_FUNCTOR_CUH_
#define DGL_ARRAY_CUDA_FUNCTOR_CUH_

#include <cmath>
#include <limits>

#include "./atomic.cuh"
#include "./fp16.cuh"
#include "bf16.cuh"
namespace dgl {
namespace aten {
namespace cuda {

/////////////////////////// CUDA binary operators //////////////////////////////
namespace binary {
template <typename DType>
struct Add {
  // Compile-time flags, presumably consulted by the kernels that instantiate
  // this functor: Add reads both operands and produces one value per element
  // (no reduction over the last dimension).
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;

  /**
   * @brief Element-wise sum of the two operands.
   * @param lhs Left operand; only lhs[0] is read.
   * @param rhs Right operand; only rhs[0] is read.
   * @param len Ignored; kept so every binary functor shares one signature.
   * @return lhs[0] + rhs[0].
   */
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t /* len */ = 1) {
    const DType a = *lhs;
    const DType b = *rhs;
    return a + b;
  }
};

// Namespace-scope definitions of the static constexpr flags; needed before
// C++17 whenever a flag is odr-used (e.g. bound to a const reference).
template <typename DType>
constexpr bool Add<DType>::use_lhs;
template <typename DType>
constexpr bool Add<DType>::use_rhs;
template <typename DType>
constexpr bool Add<DType>::reduce_last_dim;

template <typename DType>
struct Sub {
  // Compile-time flags, presumably consulted by the calling kernels:
  // both operands are read and the result is element-wise.
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;

  /**
   * @brief Element-wise difference lhs - rhs.
   * @param lhs Minuend; only lhs[0] is read.
   * @param rhs Subtrahend; only rhs[0] is read.
   * @param len Ignored; kept for a signature shared by all binary functors.
   * @return lhs[0] - rhs[0].
   */
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t /* len */ = 1) {
    return *lhs - *rhs;
  }
};

// Namespace-scope definitions of the flags (required pre-C++17 if odr-used).
template <typename DType>
constexpr bool Sub<DType>::use_lhs;
template <typename DType>
constexpr bool Sub<DType>::use_rhs;
template <typename DType>
constexpr bool Sub<DType>::reduce_last_dim;

template <typename DType>
struct Mul {
  // Compile-time flags, presumably consulted by the calling kernels:
  // both operands are read and the result is element-wise.
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;

  /**
   * @brief Element-wise product of the two operands.
   * @param lhs Left factor; only lhs[0] is read.
   * @param rhs Right factor; only rhs[0] is read.
   * @param len Ignored; kept for a signature shared by all binary functors.
   * @return lhs[0] * rhs[0].
   */
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t /* len */ = 1) {
    const DType x = *lhs;
    const DType y = *rhs;
    return x * y;
  }
};

// Namespace-scope definitions of the flags (required pre-C++17 if odr-used).
template <typename DType>
constexpr bool Mul<DType>::use_lhs;
template <typename DType>
constexpr bool Mul<DType>::use_rhs;
template <typename DType>
constexpr bool Mul<DType>::reduce_last_dim;

template <typename DType>
struct Div {
  // Compile-time flags, presumably consulted by the calling kernels:
  // both operands are read and the result is element-wise.
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;

  /**
   * @brief Element-wise quotient lhs / rhs (no guard against rhs == 0;
   *        the result follows DType's division semantics).
   * @param lhs Dividend; only lhs[0] is read.
   * @param rhs Divisor; only rhs[0] is read.
   * @param len Ignored; kept for a signature shared by all binary functors.
   * @return lhs[0] / rhs[0].
   */
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t /* len */ = 1) {
    return *lhs / *rhs;
  }
};

// Namespace-scope definitions of the flags (required pre-C++17 if odr-used).
template <typename DType>
constexpr bool Div<DType>::use_lhs;
template <typename DType>
constexpr bool Div<DType>::use_rhs;
template <typename DType>
constexpr bool Div<DType>::reduce_last_dim;

template <typename DType>
struct CopyLhs {
  // Only the left operand is consumed; rhs is never dereferenced.
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = false;
  static constexpr bool reduce_last_dim = false;

  /**
   * @brief Pass the left operand through unchanged.
   * @param lhs Source; only lhs[0] is read.
   * @param rhs Unused (never dereferenced, so callers can presumably pass
   *            a null pointer — confirm at call sites).
   * @param len Unused.
   * @return lhs[0].
   */
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType * /* rhs */, int64_t /* len */ = 1) {
    return *lhs;
  }
};

// Namespace-scope definitions of the flags (required pre-C++17 if odr-used).
template <typename DType>
constexpr bool CopyLhs<DType>::use_lhs;
template <typename DType>
constexpr bool CopyLhs<DType>::use_rhs;
template <typename DType>
constexpr bool CopyLhs<DType>::reduce_last_dim;

template <typename DType>
struct CopyRhs {
  // Only the right operand is consumed; lhs is never dereferenced.
  static constexpr bool use_lhs = false;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;

  /**
   * @brief Pass the right operand through unchanged.
   * @param lhs Unused (never dereferenced, so callers can presumably pass
   *            a null pointer — confirm at call sites).
   * @param rhs Source; only rhs[0] is read.
   * @param len Unused.
   * @return rhs[0].
   */
  static __device__ __forceinline__ DType
  Call(const DType * /* lhs */, const DType *rhs, int64_t /* len */ = 1) {
    return *rhs;
  }
};

// Namespace-scope definitions of the flags (required pre-C++17 if odr-used).
template <typename DType>
constexpr bool CopyRhs<DType>::use_lhs;
template <typename DType>
constexpr bool CopyRhs<DType>::use_rhs;
template <typename DType>
constexpr bool CopyRhs<DType>::reduce_last_dim;

template <typename DType>
struct Dot {
  // Both operands are read and, unlike the element-wise functors, the op
  // collapses `len` consecutive elements into one value.
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = true;

  /**
   * @brief Dot product of two length-`len` vectors.
   * @param lhs First vector; elements [0, len) are read.
   * @param rhs Second vector; elements [0, len) are read.
   * @param len Number of elements; a value <= 0 yields 0.
   * @return Sum over i of lhs[i] * rhs[i].
   *
   * NOTE(review): the accumulator has type DType, so for __half/bfloat16
   * the sum is carried in reduced precision.
   */
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
    DType acc = static_cast<DType>(0.0f);
    int64_t k = 0;
    while (k < len) {
      acc += lhs[k] * rhs[k];
      ++k;
    }
    return acc;
  }
};

// Namespace-scope definitions of the flags (required pre-C++17 if odr-used).
template <typename DType>
constexpr bool Dot<DType>::use_lhs;
template <typename DType>
constexpr bool Dot<DType>::use_rhs;
template <typename DType>
constexpr bool Dot<DType>::reduce_last_dim;

}  // end of namespace binary

/////////////////////////// CUDA reduce operators //////////////////////////////
namespace reduce {
/**
 * @brief Generic sum reducer.
 *
 * @tparam Idx    Index type of the (unused) argument buffers.
 * @tparam DType  Value type being accumulated.
 * @tparam atomic false: plain `*out_buf += val`; true: cuda::AtomicAdd.
 */
template <typename Idx, typename DType, bool atomic>
struct _Sum {
  /** @brief Additive identity, i.e. the value an accumulator starts from. */
  static constexpr __host__ __device__ __forceinline__ DType zero() {
    // Cast from an integer literal instead of returning the double `0.`,
    // avoiding an implicit double -> DType conversion.
    return static_cast<DType>(0);
  }

  // Sum does not track which input produced the result.
  static constexpr bool require_arg = false;

  /**
   * @brief Accumulate @p val into *out_buf.
   *
   * The arg buffers and ids exist only so the signature matches the Max/Min
   * reducers; they are ignored here.
   */
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,
      Idx eid) {
    if (!atomic) {
      *out_buf += val;
    } else {
      cuda::AtomicAdd(out_buf, val);
    }
  }

  /** @brief Single-arg-buffer variant of Call; @p arg_buf / @p id ignored. */
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_buf, DType val, Idx id) {
    if (!atomic) {
      *out_buf += val;
    } else {
      cuda::AtomicAdd(out_buf, val);
    }
  }

  /** @brief No-op: a sum has no arg indices to record. */
  static __device__ __forceinline__ void CallArg(
      Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,
      Idx uid, Idx eid) {}
};

// Namespace-scope definition of require_arg, matching what the binary
// functors provide for their flags; needed before C++17 if it is odr-used.
template <typename Idx, typename DType, bool atomic>
constexpr bool _Sum<Idx, DType, atomic>::require_arg;

/**
 * @brief Sum reducer. For generic value types this is exactly _Sum; the
 *        __half and __nv_bfloat16 specializations supply their own zero()
 *        and float-accumulating Call overloads.
 * @tparam atomic Defaults to false (non-atomic accumulation).
 */
template <typename Idx, typename DType, bool atomic = false>
struct Sum : public _Sum<Idx, DType, atomic> {};

/**
 * @brief Sum reducer specialised for fp16 inputs.
 *
 * Can accumulate either into a __half buffer (inherited _Sum behaviour) or
 * into a float buffer for better precision.
 */
template <typename Idx, bool atomic>
struct Sum<Idx, __half, atomic> : _Sum<Idx, __half, atomic> {
  /** @brief fp16 additive identity. */
  static constexpr __host__ __device__ __forceinline__ __half zero() {
    return __float2half_rn(0.0f);
  }

  // Re-expose the inherited __half-buffer Call overloads; without this the
  // float-buffer overloads declared below would hide them.
  using _Sum<Idx, __half, atomic>::Call;

  // Sometimes the reduction has to be carried out in float for better
  // precision: widen the __half value and delegate to the float reducer,
  // which writes into a float output buffer.
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __half val, Idx uid, Idx eid) {
    const float wide_val = static_cast<float>(val);
    _Sum<Idx, float, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, wide_val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __half val, Idx id) {
    const float wide_val = static_cast<float>(val);
    _Sum<Idx, float, atomic>::Call(out_buf, arg_buf, wide_val, id);
  }
};

#if BF16_ENABLED
/**
 * @brief Sum reducer specialised for bfloat16 inputs.
 *
 * Can accumulate either into a __nv_bfloat16 buffer (inherited _Sum
 * behaviour) or into a float buffer for better precision.
 */
template <typename Idx, bool atomic>
struct Sum<Idx, __nv_bfloat16, atomic> : _Sum<Idx, __nv_bfloat16, atomic> {
  /** @brief bf16 additive identity. */
  static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
    return __float2bfloat16_rn(0.0f);
  }

  // Re-expose the inherited bf16-buffer Call overloads; without this the
  // float-buffer overloads declared below would hide them.
  using _Sum<Idx, __nv_bfloat16, atomic>::Call;

  // Sometimes the reduction has to be carried out in float for better
  // precision: widen the bf16 value and delegate to the float reducer,
  // which writes into a float output buffer.
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __nv_bfloat16 val, Idx uid, Idx eid) {
    const float wide_val = static_cast<float>(val);
    _Sum<Idx, float, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, wide_val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {
    const float wide_val = static_cast<float>(val);
    _Sum<Idx, float, atomic>::Call(out_buf, arg_buf, wide_val, id);
  }
};
#endif  // BF16_ENABLED

/**
 * @brief Generic max reducer that also reports which input won.
 *
 * @tparam Idx    Index type of the argument (uid/eid) buffers.
 * @tparam DType  Value type being reduced.
 * @tparam atomic false: compare-and-store together with the winning ids;
 *                true: cuda::AtomicMax on the value only, with the ids
 *                written separately through CallArg.
 */
template <typename Idx, typename DType, bool atomic>
struct _Max {
  /**
   * @brief Identity of max, i.e. the value an accumulator starts from.
   *
   * -infinity for types that have one. numeric_limits<T>::infinity() is 0
   * for types without an infinity (e.g. integers). Using it there would make
   * 0 the start value and lose all-negative inputs, so lowest() is used.
   */
  static constexpr __host__ __device__ __forceinline__ DType zero() {
    return std::numeric_limits<DType>::has_infinity
               ? -std::numeric_limits<DType>::infinity()
               : std::numeric_limits<DType>::lowest();
  }

  // Max records which input produced the result.
  static constexpr bool require_arg = true;

  /**
   * @brief Fold @p val into *out_buf.
   *
   * Non-atomic: if val is strictly greater than *out_buf, store it and write
   * uid/eid through arg_u_buf/arg_e_buf. Both pointers are dereferenced
   * without a null check. Ties keep the earlier ids, and a NaN val never
   * wins. Atomic: only cuda::AtomicMax is applied; no ids are written here.
   */
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,
      Idx eid) {
    if (!atomic) {
      if (*out_buf < val) {
        *out_buf = val;
        *arg_u_buf = uid;
        *arg_e_buf = eid;
      }
    } else {
      cuda::AtomicMax(out_buf, val);
    }
  }

  /** @brief Single-arg-buffer variant of Call (records @p id on a win). */
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_buf, DType val, Idx id) {
    if (!atomic) {
      if (*out_buf < val) {
        *out_buf = val;
        *arg_buf = id;
      }
    } else {
      cuda::AtomicMax(out_buf, val);
    }
  }

  /**
   * @brief Atomic mode only: when @p val equals @p val_ref, write uid/eid at
   *        index @p fid of the arg buffers, skipping null buffers. No-op when
   *        atomic is false, because Call already recorded the ids.
   *
   * NOTE(review): val_ref is presumably the final reduced value read back
   * after the atomic pass; confirm at call sites.
   */
  static __device__ __forceinline__ void CallArg(
      Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,
      Idx uid, Idx eid) {
    if (atomic) {
      if (val == val_ref) {
        if (arg_u_buf) arg_u_buf[fid] = uid;
        if (arg_e_buf) arg_e_buf[fid] = eid;
      }
    }
  }
};

// Namespace-scope definition of require_arg, matching what the binary
// functors provide for their flags; needed before C++17 if it is odr-used.
template <typename Idx, typename DType, bool atomic>
constexpr bool _Max<Idx, DType, atomic>::require_arg;

/**
 * @brief Max reducer. For generic value types this is exactly _Max; the
 *        __half and __nv_bfloat16 specializations supply their own zero()
 *        and float-comparing Call overloads.
 * @tparam atomic Defaults to false (non-atomic reduction).
 */
template <typename Idx, typename DType, bool atomic = false>
struct Max : public _Max<Idx, DType, atomic> {};

/**
 * @brief Max reducer specialised for fp16 inputs.
 *
 * Can reduce either into a __half buffer (inherited _Max behaviour) or into
 * a float buffer for better precision.
 */
template <typename Idx, bool atomic>
struct Max<Idx, __half, atomic> : _Max<Idx, __half, atomic> {
  /**
   * @brief Start value for an fp16 max: -65504, the lowest finite fp16.
   *
   * NOTE(review): the bf16 specialization starts from -inf instead; the
   * reason for this asymmetry is not visible in this file.
   */
  static constexpr __host__ __device__ __forceinline__ __half zero() {
    return __float2half_rn(-65504.0f);
  }

  // Re-expose the inherited __half-buffer Call overloads; without this the
  // float-buffer overloads declared below would hide them.
  using _Max<Idx, __half, atomic>::Call;

  // Sometimes the reduction has to be carried out in float for better
  // precision: widen the __half value and delegate to the float reducer,
  // which writes into a float output buffer.
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __half val, Idx uid, Idx eid) {
    const float wide_val = static_cast<float>(val);
    _Max<Idx, float, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, wide_val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __half val, Idx id) {
    const float wide_val = static_cast<float>(val);
    _Max<Idx, float, atomic>::Call(out_buf, arg_buf, wide_val, id);
  }
};

#if BF16_ENABLED
/**
 * @brief Max reducer specialised for bfloat16 inputs.
 *
 * Can reduce either into a __nv_bfloat16 buffer (inherited _Max behaviour)
 * or into a float buffer for better precision.
 */
template <typename Idx, bool atomic>
struct Max<Idx, __nv_bfloat16, atomic> : _Max<Idx, __nv_bfloat16, atomic> {
  /** @brief Start value for a bf16 max: -infinity rounded to bfloat16. */
  static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
    return __float2bfloat16_rn(-std::numeric_limits<float>::infinity());
  }

  // Re-expose the inherited bf16-buffer Call overloads; without this the
  // float-buffer overloads declared below would hide them.
  using _Max<Idx, __nv_bfloat16, atomic>::Call;

  // Sometimes the reduction has to be carried out in float for better
  // precision: widen the bf16 value and delegate to the float reducer,
  // which writes into a float output buffer.
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __nv_bfloat16 val, Idx uid, Idx eid) {
    const float wide_val = static_cast<float>(val);
    _Max<Idx, float, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, wide_val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {
    const float wide_val = static_cast<float>(val);
    _Max<Idx, float, atomic>::Call(out_buf, arg_buf, wide_val, id);
  }
};
#endif  // BF16_ENABLED

/**
 * @brief Generic min reducer that also reports which input won.
 *
 * @tparam Idx    Index type of the argument (uid/eid) buffers.
 * @tparam DType  Value type being reduced.
 * @tparam atomic false: compare-and-store together with the winning ids;
 *                true: cuda::AtomicMin on the value only, with the ids
 *                written separately through CallArg.
 */
template <typename Idx, typename DType, bool atomic>
struct _Min {
  /**
   * @brief Identity of min, i.e. the value an accumulator starts from.
   *
   * +infinity for types that have one. numeric_limits<T>::infinity() is 0
   * for types without an infinity (e.g. integers). Using it there would make
   * 0 the start value and lose all-positive inputs, so max() is used.
   */
  static constexpr __host__ __device__ __forceinline__ DType zero() {
    return std::numeric_limits<DType>::has_infinity
               ? std::numeric_limits<DType>::infinity()
               : (std::numeric_limits<DType>::max)();
  }

  // Min records which input produced the result.
  static constexpr bool require_arg = true;

  /**
   * @brief Fold @p val into *out_buf.
   *
   * Non-atomic: if val is strictly smaller than *out_buf, store it and write
   * uid/eid through arg_u_buf/arg_e_buf. Both pointers are dereferenced
   * without a null check. Ties keep the earlier ids, and a NaN val never
   * wins. Atomic: only cuda::AtomicMin is applied; no ids are written here.
   */
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,
      Idx eid) {
    if (!atomic) {
      if (*out_buf > val) {
        *out_buf = val;
        *arg_u_buf = uid;
        *arg_e_buf = eid;
      }
    } else {
      cuda::AtomicMin(out_buf, val);
    }
  }

  /** @brief Single-arg-buffer variant of Call (records @p id on a win). */
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_buf, DType val, Idx id) {
    if (!atomic) {
      if (*out_buf > val) {
        *out_buf = val;
        *arg_buf = id;
      }
    } else {
      cuda::AtomicMin(out_buf, val);
    }
  }

  /**
   * @brief Atomic mode only: when @p val equals @p val_ref, write uid/eid at
   *        index @p fid of the arg buffers, skipping null buffers. No-op when
   *        atomic is false, because Call already recorded the ids.
   *
   * NOTE(review): val_ref is presumably the final reduced value read back
   * after the atomic pass; confirm at call sites.
   */
  static __device__ __forceinline__ void CallArg(
      Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,
      Idx uid, Idx eid) {
    if (atomic) {
      if (val == val_ref) {
        if (arg_u_buf) arg_u_buf[fid] = uid;
        if (arg_e_buf) arg_e_buf[fid] = eid;
      }
    }
  }
};

// Namespace-scope definition of require_arg, matching what the binary
// functors provide for their flags; needed before C++17 if it is odr-used.
template <typename Idx, typename DType, bool atomic>
constexpr bool _Min<Idx, DType, atomic>::require_arg;

/**
 * @brief Min reducer. For generic value types this is exactly _Min; the
 *        __half and __nv_bfloat16 specializations supply their own zero()
 *        and float-comparing Call overloads.
 * @tparam atomic Defaults to false (non-atomic reduction).
 */
template <typename Idx, typename DType, bool atomic = false>
struct Min : public _Min<Idx, DType, atomic> {};

/**
 * @brief Min reducer specialised for fp16 inputs.
 *
 * Can reduce either into a __half buffer (inherited _Min behaviour) or into
 * a float buffer for better precision.
 */
template <typename Idx, bool atomic>
struct Min<Idx, __half, atomic> : _Min<Idx, __half, atomic> {
  /**
   * @brief Start value for an fp16 min: 65504, the largest finite fp16.
   *
   * NOTE(review): the bf16 specialization starts from +inf instead; the
   * reason for this asymmetry is not visible in this file.
   */
  static constexpr __host__ __device__ __forceinline__ __half zero() {
    return __float2half_rn(65504.0f);
  }

  // Re-expose the inherited __half-buffer Call overloads; without this the
  // float-buffer overloads declared below would hide them.
  using _Min<Idx, __half, atomic>::Call;

  // Sometimes the reduction has to be carried out in float for better
  // precision: widen the __half value and delegate to the float reducer,
  // which writes into a float output buffer.
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __half val, Idx uid, Idx eid) {
    const float wide_val = static_cast<float>(val);
    _Min<Idx, float, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, wide_val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __half val, Idx id) {
    const float wide_val = static_cast<float>(val);
    _Min<Idx, float, atomic>::Call(out_buf, arg_buf, wide_val, id);
  }
};

#if BF16_ENABLED
/**
 * @brief Min reducer specialised for bfloat16 inputs.
 *
 * Can reduce either into a __nv_bfloat16 buffer (inherited _Min behaviour)
 * or into a float buffer for better precision.
 */
template <typename Idx, bool atomic>
struct Min<Idx, __nv_bfloat16, atomic> : _Min<Idx, __nv_bfloat16, atomic> {
  /** @brief Start value for a bf16 min: +infinity rounded to bfloat16. */
  static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
    return __float2bfloat16_rn(std::numeric_limits<float>::infinity());
  }

  // Re-expose the inherited bf16-buffer Call overloads; without this the
  // float-buffer overloads declared below would hide them.
  using _Min<Idx, __nv_bfloat16, atomic>::Call;

  // Sometimes the reduction has to be carried out in float for better
  // precision: widen the bf16 value and delegate to the float reducer,
  // which writes into a float output buffer.
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __nv_bfloat16 val, Idx uid, Idx eid) {
    const float wide_val = static_cast<float>(val);
    _Min<Idx, float, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, wide_val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {
    const float wide_val = static_cast<float>(val);
    _Min<Idx, float, atomic>::Call(out_buf, arg_buf, wide_val, id);
  }
};
#endif  // BF16_ENABLED

}  // namespace reduce

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_FUNCTOR_CUH_