parrots_cuda_helper.hpp 4.25 KB
Newer Older
1
2
3
4
5
6
7
#ifndef PARROTS_CUDA_HELPER
#define PARROTS_CUDA_HELPER

#include <cuda.h>
#include <float.h>

#include <parrots/darray/darraymath.hpp>
8
#include <parrots/darray/mathfunctions.hpp>
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include <parrots/extension.hpp>
#include <parrots/foundation/darrayutil.hpp>
#include <parrots/foundation/exceptions.hpp>
#include <parrots/foundation/float16.hpp>
#include <parrots/foundation/mathfunction.hpp>

#include "common_cuda_helper.hpp"
#include "parrots_cudawarpfunction.cuh"

// Makes the names of the parrots namespace usable unqualified in this header.
// NOTE(review): a using-directive in a header also applies to every file that
// includes it — confirm this is intended.
using namespace parrots;
// Short alias for the half-precision type `float16` (presumably declared by
// parrots/foundation/float16.hpp — confirm).
using phalf = float16;

// Accesses the `y` member of the half-precision value `x`.
// The argument is parenthesized so that an expression argument such as `*ptr`
// has the member selected from the whole expression rather than from its last
// operand.
// NOTE(review): what `y` holds is defined by the float16 type, which is not
// visible in this header — confirm against its declaration.
#define __PHALF(x) ((x).y)

// Evaluates the CUDA runtime expression `exp` exactly once. If the result is
// not cudaSuccess, the error string is printed to stderr and the process is
// terminated with exit(-1).
// The argument is parenthesized and the temporary has a macro-private name, so
// a caller expression that mentions its own variable called `err` is no longer
// captured by the macro's local.
#define PARROTS_CUDA_CHECK(exp)                         \
  do {                                                  \
    const cudaError_t parrots_err_ = (exp);             \
    if (parrots_err_ != cudaSuccess) {                  \
      fprintf(stderr, "cudaCheckError() failed : %s\n", \
              cudaGetErrorString(parrots_err_));        \
      exit(-1);                                         \
    }                                                   \
  } while (0)

// Emits one `case` of a type-dispatch switch: when the switch value equals
// `prim_type`, `scalar_t` is aliased to `type` and the callable given in the
// remaining arguments is invoked, with its result returned from the enclosing
// function or lambda. The callable's body may refer to `scalar_t`.
#define PARROTS_PRIVATE_CASE_TYPE(prim_type, type, ...) \
  case prim_type: {                                     \
    using scalar_t = type;                              \
    return __VA_ARGS__();                               \
  }

// Invokes the callable passed in `...` with `scalar_t` bound to the C++ type
// matching the element type `TYPE`: double for Prim::Float64 and float for
// Prim::Float32. Any other value reaches PARROTS_NOTSUPPORTED.
// The expansion is an immediately-invoked lambda that captures by reference,
// so the macro can be used as an expression.
// NOTE(review): PARROTS_NOTSUPPORTED is defined outside this header;
// presumably it raises — confirm, since otherwise the default path falls off
// the end of the lambda without returning.
#define PARROTS_DISPATCH_FLOATING_TYPES(TYPE, ...)                  \
  [&] {                                                             \
    const auto& the_type = TYPE;                                    \
    switch (the_type) {                                             \
      PARROTS_PRIVATE_CASE_TYPE(Prim::Float64, double, __VA_ARGS__) \
      PARROTS_PRIVATE_CASE_TYPE(Prim::Float32, float, __VA_ARGS__)  \
      default:                                                      \
        PARROTS_NOTSUPPORTED;                                       \
    }                                                               \
  }()

// Invokes the callable passed in `...` with `scalar_t` bound to the C++ type
// matching the element type `TYPE`: double for Prim::Float64, float for
// Prim::Float32 and float16 for Prim::Float16. Any other value reaches
// PARROTS_NOTSUPPORTED.
// The expansion is an immediately-invoked lambda that captures by reference,
// so the macro can be used as an expression.
// NOTE(review): PARROTS_NOTSUPPORTED is defined outside this header;
// presumably it raises — confirm, since otherwise the default path falls off
// the end of the lambda without returning.
#define PARROTS_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, ...)          \
  [&] {                                                              \
    const auto& the_type = TYPE;                                     \
    switch (the_type) {                                              \
      PARROTS_PRIVATE_CASE_TYPE(Prim::Float64, double, __VA_ARGS__)  \
      PARROTS_PRIVATE_CASE_TYPE(Prim::Float32, float, __VA_ARGS__)   \
      PARROTS_PRIVATE_CASE_TYPE(Prim::Float16, float16, __VA_ARGS__) \
      default:                                                       \
        PARROTS_NOTSUPPORTED;                                        \
    }                                                                \
  }()

/** atomicAdd **/
limm's avatar
limm committed
63
// A native atomicAdd(double*, double) is provided only for devices of compute
// capability 6.0 and higher, so the fallback guarded here must be compiled
// only for older device targets. The target architecture is reported by
// __CUDA_ARCH__; __CUDACC__ only signals that the CUDA compiler is in use and
// carries no architecture number, so comparing it against 600 does not select
// by device generation.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111

// Atomically adds `val` to the double stored at `address` by retrying a 64-bit
// compare-and-swap, and returns the value that was stored there before the
// addition. Adding zero is short-circuited: the current value is read and
// returned without issuing any atomic operation.
static __inline__ __device__ double atomicAdd(double* address, double val) {
  unsigned long long int* const word =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int observed = *word;
  if (val == 0.0) {
    return __longlong_as_double(observed);
  }
  unsigned long long int expected;
  do {
    expected = observed;
    const double updated = val + __longlong_as_double(expected);
    // atomicCAS returns the word it found at `word`; if that differs from what
    // was expected, the swap did not happen and the loop retries with the
    // freshly observed bits.
    observed = atomicCAS(word, expected, __double_as_longlong(updated));
  } while (expected != observed);
  return __longlong_as_double(observed);
}

#endif

// Atomically adds `val` to the 16-bit float at `address` and returns the value
// that was stored there before the addition.
//
// The update is carried out with a 32-bit atomicCAS on the 4-byte-aligned word
// containing the target. Bit 1 of the address selects whether the target is
// taken from the upper or the lower 16 bits of that word; the other 16 bits
// are written back unchanged.
static __inline__ __device__ float16 atomicAdd(float16* address, float16 val) {
  // Loop-invariant: true when the target is the upper 16 bits of the word.
  const bool upper_half = ((size_t)address & 2) != 0;
  unsigned int* aligned =
      (unsigned int*)((size_t)address - ((size_t)address & 2));
  unsigned int old = *aligned;
  unsigned int assumed;
  unsigned short old_as_us;
  do {
    assumed = old;
    old_as_us = (unsigned short)(upper_half ? old >> 16 : old & 0xffff);

#if __CUDACC_VER_MAJOR__ >= 9
    // `x` is used here as the raw 16-bit storage of float16.
    // NOTE(review): confirm against the float16 declaration.
    float16 tmp;
    tmp.x = old_as_us;
    float16 sum = tmp + val;
    unsigned short sum_as_us = sum.x;
#else
    unsigned short sum_as_us =
        __float2half_rn(__half2float(old_as_us) + (float)(val));
#endif

    // Widen to unsigned int before shifting: an unsigned short operand would
    // be promoted to signed int, and shifting a value with bit 15 set left by
    // 16 would overflow into the sign bit.
    const unsigned int widened_sum = sum_as_us;
    unsigned int sum_as_ui = upper_half
                                 ? (widened_sum << 16) | (old & 0xffff)
                                 : (old & 0xffff0000) | widened_sum;
    old = atomicCAS(aligned, assumed, sum_as_ui);
  } while (assumed != old);
  // The loop exits when the swap succeeded (old == assumed), so old_as_us
  // still holds the 16 bits that were replaced.
  return *reinterpret_cast<float16*>(&old_as_us);
}
#endif  // PARROTS_CUDA_HELPER