// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2020 by Contributors
 * @file array/cuda/functor.cuh
 * @brief Functors for templates on CUDA.
 */
#ifndef DGL_ARRAY_CUDA_FUNCTOR_CUH_
#define DGL_ARRAY_CUDA_FUNCTOR_CUH_

#include <cmath>
#include <limits>

#include "atomic.cuh"
#include "fp16.cuh"
#include "bf16.cuh"

namespace dgl {
namespace aten {
namespace cuda {

/////////////////////////// CUDA binary operators //////////////////////////////
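// Each functor exposes compile-time flags consumed by kernel templates:
// use_lhs/use_rhs tell the caller which operand arrays to gather, and
// reduce_last_dim marks ops (Dot) whose Call reduces over the last feature
// dimension, in which case len is that dimension's size; all other ops
// read only lhs[0]/rhs[0] and ignore len.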
namespace binary {
template <typename DType>
struct Add {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0] + rhs[0];
  }
};
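// Namespace-scope definitions: before C++17, ODR-used static constexpr
// members still require an out-of-class definition. The same pattern
// repeats for each functor below.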
template <typename DType>
constexpr bool Add<DType>::use_lhs;
template <typename DType>
constexpr bool Add<DType>::use_rhs;
template <typename DType>
constexpr bool Add<DType>::reduce_last_dim;

template <typename DType>
struct Sub {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0] - rhs[0];
  }
};
template <typename DType>
constexpr bool Sub<DType>::use_lhs;
template <typename DType>
constexpr bool Sub<DType>::use_rhs;
template <typename DType>
constexpr bool Sub<DType>::reduce_last_dim;

template <typename DType>
struct Mul {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0] * rhs[0];
  }
};
template <typename DType>
constexpr bool Mul<DType>::use_lhs;
template <typename DType>
constexpr bool Mul<DType>::use_rhs;
template <typename DType>
constexpr bool Mul<DType>::reduce_last_dim;

template <typename DType>
struct Div {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0] / rhs[0];
  }
};
template <typename DType>
constexpr bool Div<DType>::use_lhs;
template <typename DType>
constexpr bool Div<DType>::use_rhs;
template <typename DType>
constexpr bool Div<DType>::reduce_last_dim;

template <typename DType>
struct CopyLhs {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = false;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0];
  }
};
template <typename DType>
constexpr bool CopyLhs<DType>::use_lhs;
template <typename DType>
constexpr bool CopyLhs<DType>::use_rhs;
template <typename DType>
constexpr bool CopyLhs<DType>::reduce_last_dim;

template <typename DType>
struct CopyRhs {
  static constexpr bool use_lhs = false;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
    return rhs[0];
  }
};
template <typename DType>
constexpr bool CopyRhs<DType>::use_lhs;
template <typename DType>
constexpr bool CopyRhs<DType>::use_rhs;
template <typename DType>
constexpr bool CopyRhs<DType>::reduce_last_dim;

template <typename DType>
struct Dot {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = true;
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
    DType rst = static_cast<DType>(0.0f);
    for (int64_t i = 0; i < len; ++i) {
      rst += lhs[i] * rhs[i];
    }
    return rst;
  }
};
template <typename DType>
constexpr bool Dot<DType>::use_lhs;
template <typename DType>
constexpr bool Dot<DType>::use_rhs;
template <typename DType>
constexpr bool Dot<DType>::reduce_last_dim;

}  // end of namespace binary
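
// Illustrative sketch (not part of this header): a hypothetical kernel
// templated on one of the binary functors above; the kernel name is an
// assumption for illustration only.
//
//   template <typename DType, typename Op>
//   __global__ void BinaryApplyKernel(
//       const DType* lhs, const DType* rhs, DType* out,
//       int64_t n, int64_t len) {
//     int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
//     if (i < n) {
//       // Dot-style ops consume `len` consecutive inputs per output.
//       const int64_t stride = Op::reduce_last_dim ? len : 1;
//       out[i] = Op::Call(lhs + i * stride, rhs + i * stride, len);
//     }
//   }
//
// e.g. BinaryApplyKernel<float, binary::Add<float>>
//          <<<blocks, threads>>>(lhs, rhs, out, n, 1);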

/////////////////////////// CUDA reduce operators //////////////////////////////
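// Every reducer provides: zero(), the identity element of the reduction
// (not necessarily literal zero; e.g. -infinity for Max); require_arg,
// whether argmin/argmax indices must be tracked; Call overloads for the
// non-atomic case (one writer per output slot) and the atomic case
// (multiple concurrent writers); and CallArg for a second arg-tracking
// pass (see Max/Min below).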
namespace reduce {
template <typename Idx, typename DType, bool atomic>
struct _Sum {
  static constexpr __host__ __device__ __forceinline__ DType zero() {
    return 0.;
  }
  static constexpr bool require_arg = false;
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,
      Idx eid) {
    if (!atomic) {
      *out_buf += val;
    } else {
      cuda::AtomicAdd(out_buf, val);
    }
  }
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_buf, DType val, Idx id) {
    if (!atomic) {
      *out_buf += val;
    } else {
      cuda::AtomicAdd(out_buf, val);
    }
  }
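  // Sum needs no argmin/argmax tracking, so CallArg is a no-op.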
  static __device__ __forceinline__ void CallArg(
      Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,
      Idx uid, Idx eid) {}
};

template <typename Idx, typename DType, bool atomic = false>
struct Sum : _Sum<Idx, DType, atomic> {};

template <typename Idx, bool atomic>
struct Sum<Idx, __half, atomic> : _Sum<Idx, __half, atomic> {
  static constexpr __host__ __device__ __forceinline__ __half zero() {
    return __float2half_rn(0.);
  }
  static __device__ __forceinline__ void Call(
      __half *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __half val, Idx uid, Idx eid) {
    _Sum<Idx, __half, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      __half *out_buf, Idx *arg_buf, __half val, Idx id) {
    _Sum<Idx, __half, atomic>::Call(out_buf, arg_buf, val, id);
  }
  // sometimes we have to use float in reduction for better precision
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __half val, Idx uid, Idx eid) {
    _Sum<Idx, float, atomic>::Call(out_buf, arg_u_buf, arg_e_buf,
        static_cast<float>(val), uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __half val, Idx id) {
    _Sum<Idx, float, atomic>::Call(out_buf, arg_buf,
        static_cast<float>(val), id);
  }
};

#if BF16_ENABLED
template <typename Idx, bool atomic>
struct Sum<Idx, __hip_bfloat16, atomic> : _Sum<Idx, __hip_bfloat16, atomic> {
  static constexpr __host__ __device__ __forceinline__ __hip_bfloat16 zero() {
    return __float2bfloat16(0.);
  }
  static __device__ __forceinline__ void Call(
      __hip_bfloat16 *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __hip_bfloat16 val, Idx uid, Idx eid) {
    _Sum<Idx, __hip_bfloat16, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      __hip_bfloat16 *out_buf, Idx *arg_buf, __hip_bfloat16 val, Idx id) {
    _Sum<Idx, __hip_bfloat16, atomic>::Call(out_buf, arg_buf, val, id);
  }
  // sometimes we have to use float in reduction for better precision
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __hip_bfloat16 val, Idx uid, Idx eid) {
    _Sum<Idx, float, atomic>::Call(out_buf, arg_u_buf, arg_e_buf,
        static_cast<float>(val), uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __hip_bfloat16 val, Idx id) {
    _Sum<Idx, float, atomic>::Call(out_buf, arg_buf,
        static_cast<float>(val), id);
  }
};
#endif  // BF16_ENABLED

template <typename Idx, typename DType, bool atomic>
struct _Max {
  static constexpr __host__ __device__ __forceinline__ DType zero() {
    return -std::numeric_limits<DType>::infinity();
  }
  static constexpr bool require_arg = true;
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,
      Idx eid) {
    if (!atomic) {
      if (*out_buf < val) {
        *out_buf = val;
        *arg_u_buf = uid;
        *arg_e_buf = eid;
      }
    } else {
      cuda::AtomicMax(out_buf, val);
    }
  }
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_buf, DType val, Idx id) {
    if (!atomic) {
      if (*out_buf < val) {
        *out_buf = val;
        *arg_buf = id;
      }
    } else {
      cuda::AtomicMax(out_buf, val);
    }
  }
  static __device__ __forceinline__ void CallArg(
      Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,
      Idx uid, Idx eid) {
    if (atomic) {
      if (val == val_ref) {
        if (arg_u_buf) arg_u_buf[fid] = uid;
        if (arg_e_buf) arg_e_buf[fid] = eid;
      }
    }
  }
};
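// With atomic=true, Call only folds the value in via AtomicMax; the arg
// buffers are filled afterwards in a second pass: CallArg re-reads each
// candidate's value and, where it equals the final result (val_ref),
// writes its uid/eid at feature index fid. Ties are resolved by whichever
// thread writes last, i.e. nondeterministically. Min below uses the same
// scheme.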

template <typename Idx, typename DType, bool atomic = false>
struct Max : _Max<Idx, DType, atomic> {};

template <typename Idx, bool atomic>
struct Max<Idx, __half, atomic> : _Max<Idx, __half, atomic> {
  static constexpr __host__ __device__ __forceinline__ __half zero() {
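    // fp16 has no constexpr infinity; -65504 is its lowest finite value.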
    return __float2half_rn(-6.550400e+04f);
  }
  static __device__ __forceinline__ void Call(
      __half *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __half val, Idx uid, Idx eid) {
    _Max<Idx, __half, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      __half *out_buf, Idx *arg_buf, __half val, Idx id) {
    _Max<Idx, __half, atomic>::Call(out_buf, arg_buf, val, id);
  }
  // sometimes we have to use float in reduction for better precision
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __half val, Idx uid, Idx eid) {
    _Max<Idx, float, atomic>::Call(out_buf, arg_u_buf, arg_e_buf,
        static_cast<float>(val), uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __half val, Idx id) {
    _Max<Idx, float, atomic>::Call(out_buf, arg_buf,
        static_cast<float>(val), id);
  }
};

#if BF16_ENABLED
template <typename Idx, bool atomic>
struct Max<Idx, __hip_bfloat16, atomic> : _Max<Idx, __hip_bfloat16, atomic> {
  static constexpr __host__ __device__ __forceinline__ __hip_bfloat16 zero() {
    return __float2bfloat16(-std::numeric_limits<float>::infinity());
  }
  static __device__ __forceinline__ void Call(
      __hip_bfloat16 *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __hip_bfloat16 val, Idx uid, Idx eid) {
    _Max<Idx, __hip_bfloat16, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      __hip_bfloat16 *out_buf, Idx *arg_buf, __hip_bfloat16 val, Idx id) {
    _Max<Idx, __hip_bfloat16, atomic>::Call(out_buf, arg_buf, val, id);
  }
  // sometimes we have to use float in reduction for better precision
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __hip_bfloat16 val, Idx uid, Idx eid) {
    _Max<Idx, float, atomic>::Call(out_buf, arg_u_buf, arg_e_buf,
        static_cast<float>(val), uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __hip_bfloat16 val, Idx id) {
    _Max<Idx, float, atomic>::Call(out_buf, arg_buf,
        static_cast<float>(val), id);
  }
};
#endif  // BF16_ENABLED

template <typename Idx, typename DType, bool atomic>
struct _Min {
  static constexpr __host__ __device__ __forceinline__ DType zero() {
    return std::numeric_limits<DType>::infinity();
  }
  static constexpr bool require_arg = true;
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,
      Idx eid) {
    if (!atomic) {
      if (*out_buf > val) {
        *out_buf = val;
        *arg_u_buf = uid;
        *arg_e_buf = eid;
      }
    } else {
      cuda::AtomicMin(out_buf, val);
    }
  }
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_buf, DType val, Idx id) {
    if (!atomic) {
      if (*out_buf > val) {
        *out_buf = val;
        *arg_buf = id;
      }
    } else {
      cuda::AtomicMin(out_buf, val);
    }
  }
  static __device__ __forceinline__ void CallArg(
      Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,
      Idx uid, Idx eid) {
    if (atomic) {
      if (val == val_ref) {
        if (arg_u_buf) arg_u_buf[fid] = uid;
        if (arg_e_buf) arg_e_buf[fid] = eid;
      }
    }
  }
};

template <typename Idx, typename DType, bool atomic = false>
struct Min : _Min<Idx, DType, atomic> {};

template <typename Idx, bool atomic>
struct Min<Idx, __half, atomic> : _Min<Idx, __half, atomic> {
  static constexpr __host__ __device__ __forceinline__ __half zero() {
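    // 65504 is the largest finite fp16 value, standing in for +infinity.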
    return __float2half_rn(6.550400e+04f);
  }
  static __device__ __forceinline__ void Call(
      __half *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __half val, Idx uid, Idx eid) {
    _Min<Idx, __half, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      __half *out_buf, Idx *arg_buf, __half val, Idx id) {
    _Min<Idx, __half, atomic>::Call(out_buf, arg_buf, val, id);
  }
  // sometimes we have to use float in reduction for better precision
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __half val, Idx uid, Idx eid) {
    _Min<Idx, float, atomic>::Call(out_buf, arg_u_buf, arg_e_buf,
        static_cast<float>(val), uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __half val, Idx id) {
    _Min<Idx, float, atomic>::Call(out_buf, arg_buf,
        static_cast<float>(val), id);
  }
};

#if BF16_ENABLED
template <typename Idx, bool atomic>
struct Min<Idx, __hip_bfloat16, atomic> : _Min<Idx, __hip_bfloat16, atomic> {
  static constexpr __host__ __device__ __forceinline__ __hip_bfloat16 zero() {
    return __float2bfloat16(std::numeric_limits<float>::infinity());
  }
  static __device__ __forceinline__ void Call(
      __hip_bfloat16 *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __hip_bfloat16 val, Idx uid, Idx eid) {
    _Min<Idx, __hip_bfloat16, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      __hip_bfloat16 *out_buf, Idx *arg_buf, __hip_bfloat16 val, Idx id) {
    _Min<Idx, __hip_bfloat16, atomic>::Call(out_buf, arg_buf, val, id);
  }
  // sometimes we have to use float in reduction for better precision
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __hip_bfloat16 val, Idx uid, Idx eid) {
    _Min<Idx, float, atomic>::Call(out_buf, arg_u_buf, arg_e_buf,
        static_cast<float>(val), uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __hip_bfloat16 val, Idx id) {
    _Min<Idx, float, atomic>::Call(out_buf, arg_buf,
        static_cast<float>(val), id);
  }
};
#endif  // BF16_ENABLED
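
// Illustrative sketch (not part of this header; kernel names are
// assumptions): reducing n values into out[0] with an atomic Max, then
// recovering the winning index with a second CallArg pass. out[0] must be
// initialized to Red::zero() before the first launch.
//
//   template <typename Idx, typename DType>
//   __global__ void AtomicMaxKernel(const DType* vals, DType* out, int64_t n) {
//     using Red = Max<Idx, DType, /*atomic=*/true>;
//     int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
//     if (i < n) Red::Call(out, /*arg_buf=*/nullptr, vals[i], (Idx)i);
//   }
//
//   template <typename Idx, typename DType>
//   __global__ void AtomicArgKernel(
//       const DType* vals, const DType* out, Idx* arg, int64_t n) {
//     using Red = Max<Idx, DType, true>;
//     int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
//     if (i < n)
//       Red::CallArg((Idx)0, arg, /*arg_e_buf=*/nullptr, vals[i], out[0],
//                    /*uid=*/(Idx)i, /*eid=*/(Idx)0);
//   }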

}  // namespace reduce

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_FUNCTOR_CUH_