/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_

namespace transformer_engine {
// Parameter tag for activation functions that take no extra arguments.
// Such activations accept a `const Empty&` as their second parameter, giving
// them the same (value, param) call shape as the activations that take a
// ClampedSwiGLUParam.
struct Empty {
};

struct ClampedSwiGLUParam {
  float limit;
  float alpha = 1.702f;  // Default value for QuickGELU
};

// tanh approximation of GELU:
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// evaluated in float as x * (0.5 + 0.5 * tanh(x * (a + b * x^2))), where
//   a = 0.79788456 ~ sqrt(2/pi)
//   b = 0.03567741 ~ 0.044715 * sqrt(2/pi)
// The float result is converted to OType on return.
template <typename OType, typename IType>
__device__ inline OType gelu(const IType val, const Empty&) {
  const float x = val;
  const float inner = x * (0.79788456F + 0.03567741F * x * x);
  return x * (0.5F + 0.5F * tanhf(inner));
}

// Derivative of the tanh-approximated GELU with respect to its input.
// With t = tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2)):
//   d/dx = 0.5 * x * (1 - t^2) * (sqrt(2/pi) + 3 * 0.044715 * sqrt(2/pi) * x^2)
//        + 0.5 * (1 + t)
// where 0.79788456 ~ sqrt(2/pi) and 0.1070322243 ~ 3 * 0.044715 * sqrt(2/pi).
template <typename OType, typename IType>
__device__ inline OType dgelu(const IType val, const Empty&) {
  const float x = val;
  const float t = tanhf(0.79788456f * x * (1.f + 0.044715f * x * x));
  const float sech_sq = 1.f - t * t;  // d/du tanh(u) = 1 - tanh(u)^2
  const float slope = 0.79788456f + 0.1070322243f * x * x;
  const float term_tanh = 0.5f * x * (sech_sq * slope);
  const float term_base = 0.5f * (1.f + t);
  return term_tanh + term_base;
}

// Logistic sigmoid 1 / (1 + exp(-x)), computed in float using expf.
// The float result is converted to OType on return.
template <typename OType, typename IType>
__device__ inline OType sigmoid(const IType val, const Empty&) {
  const float x = val;
  const float denom = 1.f + expf(-x);
  return 1.f / denom;
}

// Float-only logistic sigmoid built on the fast-math intrinsic __expf and the
// round-to-nearest reciprocal __frcp_rn. Because __expf is an approximate
// intrinsic, this is less accurate than sigmoid() in exchange for speed.
__device__ inline float sigmoidf(const float x) {
  const float denom = 1.0f + __expf(-x);
  return __frcp_rn(denom);
}

// Derivative of the logistic sigmoid: s * (1 - s), where s = sigmoid(x).
template <typename OType, typename IType>
__device__ inline OType dsigmoid(const IType val, const Empty& e) {
  const float x = val;
  const float s = sigmoid<float, float>(x, e);
  const float complement = 1.f - s;
  return s * complement;
}

// Sigmoid-gated linear unit with a configurable slope: x * sigmoid(alpha * x).
// alpha = 1.702 yields QuickGELU (as used by qgelu); alpha = 1 yields SiLU.
template <typename OType, typename IType>
__device__ inline OType qgelu_with_alpha(const IType val, const float alpha) {
  const float x = val;
  const float gate = sigmoid<float, float>(alpha * x, Empty{});
  return x * gate;
}

// QuickGELU: x * sigmoid(1.702 * x).
template <typename OType, typename IType>
__device__ inline OType qgelu(const IType val, const Empty& e) {
  constexpr float kQuickGeluAlpha = 1.702f;
  return qgelu_with_alpha<OType, IType>(val, kQuickGeluAlpha);
}

// Derivative of x * sigmoid(alpha * x) with respect to x:
//   alpha * x * sigmoid'(alpha * x) + sigmoid(alpha * x)
template <typename OType, typename IType>
__device__ inline OType dqgelu_with_alpha(const IType val, const float alpha) {
  const float x = val;
  const float scaled = alpha * x;
  const Empty no_params{};
  return scaled * dsigmoid<float, float>(scaled, no_params) +
         sigmoid<float, float>(scaled, no_params);
}

// Derivative of QuickGELU (the alpha = 1.702 case of dqgelu_with_alpha).
template <typename OType, typename IType>
__device__ inline OType dqgelu(const IType val, const Empty& e) {
  constexpr float kQuickGeluAlpha = 1.702f;
  return dqgelu_with_alpha<OType, IType>(val, kQuickGeluAlpha);
}

// SiLU (swish): x * sigmoid(x), computed in float.
template <typename OType, typename IType>
__device__ inline OType silu(const IType val, const Empty& e) {
  const float x = val;
  const float gate = sigmoid<float, float>(x, e);
  return x * gate;
}

// Clamped gated activation:
//   x' = min(p.limit, x)            (upper clamp only, no lower bound)
//   out = x' * sigmoid(p.alpha * x')
// Despite the name, the sigmoid slope is p.alpha rather than 1.
template <typename OType, typename IType>
__device__ inline OType clamped_silu(const IType val, const ClampedSwiGLUParam& p) {
  const float x = static_cast<float>(val);
  const float clamped = min(p.limit, x);
  return qgelu_with_alpha<OType, float>(clamped, p.alpha);
}

// Derivative of SiLU: x * sigmoid'(x) + sigmoid(x).
template <typename OType, typename IType>
__device__ inline OType dsilu(const IType val, const Empty& e) {
  const float x = val;
  const float ds = dsigmoid<float, float>(x, e);
  const float s = sigmoid<float, float>(x, e);
  return x * ds + s;
}

// Derivative of clamped_silu with respect to its (unclamped) input.
// The upper clamp min(x, p.limit) passes gradient for x <= p.limit (the
// boundary included) and blocks it above. Inside that range the gradient is
// d/dx [x * sigmoid(p.alpha * x)] evaluated at the clamped value. NaN inputs
// fail the comparison and produce 0.
// The inner derivative is kept in float and narrowed to OType only once, on
// return, instead of round-tripping through OType.
template <typename OType, typename IType>
__device__ inline OType clamped_dsilu(const IType val, const ClampedSwiGLUParam& p) {
  const float x = static_cast<float>(val);
  const bool passes_grad = x <= p.limit;
  const float clamped = min(x, p.limit);
  const float grad = dqgelu_with_alpha<float, float>(clamped, p.alpha);
  return passes_grad ? grad : 0.0f;
}

// ReLU: max(x, 0). fmaxf returns the non-NaN operand, so NaN inputs map to 0.
template <typename OType, typename IType>
__device__ inline OType relu(IType value, const Empty&) {
  constexpr float kFloor = 0.f;
  const float out = fmaxf(value, kFloor);
  return out;
}

// Derivative of ReLU: 1 for strictly positive inputs, 0 otherwise (including
// zero and NaN).
template <typename OType, typename IType>
__device__ inline OType drelu(IType value, const Empty&) {
  if (value > 0.f) {
    return 1.f;
  }
  return 0.f;
}

// Squared ReLU: x^2 for x > 0, else 0 (NaN inputs also give 0).
// The input is converted to float before squaring, as every other activation
// in this file does. Squaring directly in IType could run in reduced precision
// and overflow for 16-bit float types; TODO confirm which ITypes callers pass.
template <typename OType, typename IType>
__device__ inline OType srelu(IType value, const Empty&) {
  const float x = value;
  return x > 0.f ? x * x : 0.f;
}

// Derivative of squared ReLU: 2x for positive inputs, 0 otherwise. fmaxf
// clamps the negative side to zero and maps NaN to 0.
template <typename OType, typename IType>
__device__ inline OType dsrelu(IType value, const Empty&) {
  const auto twice = 2.f * value;
  return fmaxf(twice, 0.f);
}

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_