/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_

namespace transformer_engine {
11
12

// Stateless placeholder type. The activation helpers in this header that take
// no extra parameters accept a `const Empty&` as their second argument.
struct Empty {};
13

14
15
16
17
18
// Parameters consumed by clamped_silu / clamped_dsilu.
struct ClampedSwiGLUParam {
  // Upper bound applied to the input (via min) before the activation is evaluated.
  // No default is provided; callers must set it.
  float limit;
  // Scale applied to the sigmoid argument. Default value for QuickGELU.
  float alpha{1.702f};
};

19
// Tanh-based GELU. The input is promoted to float, the result is computed in
// float and converted to OType on return. The Empty argument is unused.
template <typename OType, typename IType>
__device__ inline OType gelu(const IType val, const Empty&) {
  const float x = val;
  // 0.79788456 ~= sqrt(2/pi); 0.03567741 ~= 0.044715 * sqrt(2/pi).
  const float tanh_arg = x * (0.79788456F + 0.03567741F * x * x);
  return x * (0.5F + 0.5F * tanhf(tanh_arg));
}

// Derivative of the tanh-based GELU with respect to its input, computed in
// float and converted to OType on return. The Empty argument is unused.
template <typename OType, typename IType>
__device__ inline OType dgelu(const IType val, const Empty&) {
  const float x = val;
  const float t = tanhf(0.79788456f * x * (1.f + 0.044715f * x * x));
  // d/dx tanh(u) = (1 - tanh(u)^2) * du/dx
  const float one_minus_t2 = 1.f - t * t;
  // du/dx, where 0.1070322243 ~= 3 * 0.044715 * 0.79788456.
  const float du_dx = 0.79788456f + 0.1070322243f * x * x;
  return 0.5f * x * (one_minus_t2 * du_dx) + 0.5f * (1.f + t);
}

33
34
// Logistic sigmoid 1 / (1 + exp(-x)), evaluated in float and converted to
// OType on return. The Empty argument is unused.
template <typename OType, typename IType>
__device__ inline OType sigmoid(const IType val, const Empty&) {
  const float x = val;
  const float denom = 1.f + expf(-x);
  return 1.f / denom;
}

// Derivative of the logistic sigmoid, s * (1 - s) with s = sigmoid(val).
// Computed in float and converted to OType on return.
template <typename OType, typename IType>
__device__ inline OType dsigmoid(const IType val, const Empty& e) {
  const float x = val;
  const float sig = sigmoid<float, float>(x, e);
  const float one_minus_sig = 1.f - sig;
  return sig * one_minus_sig;
}

46
47
48
49
50
51
52
// Sigmoid-gated activation x * sigmoid(alpha * x) with a caller-supplied alpha.
// Computed in float and converted to OType on return.
template <typename OType, typename IType>
__device__ inline OType qgelu_with_alpha(const IType val, const float alpha) {
  const float x = val;
  const Empty tag{};
  const float gate = sigmoid<float, float>(alpha * x, tag);
  return x * gate;
}

53
54
// QuickGELU: qgelu_with_alpha with alpha fixed to 1.702.
// The Empty argument is accepted but not used.
template <typename OType, typename IType>
__device__ inline OType qgelu(const IType val, const Empty& e) {
  constexpr float kAlpha = 1.702f;
  return qgelu_with_alpha<OType, IType>(val, kAlpha);
}

// Derivative of x * sigmoid(alpha * x) with respect to x:
//   alpha * x * sigmoid'(alpha * x) + sigmoid(alpha * x).
// Computed in float and converted to OType on return.
template <typename OType, typename IType>
__device__ inline OType dqgelu_with_alpha(const IType val, const float alpha) {
  const float x = val;
  const float scaled = alpha * x;
  const Empty tag{};
  const float dsig = dsigmoid<float, float>(scaled, tag);
  const float sig = sigmoid<float, float>(scaled, tag);
  return scaled * dsig + sig;
}

// Derivative of QuickGELU: dqgelu_with_alpha with alpha fixed to 1.702.
// The Empty argument is accepted but not used.
template <typename OType, typename IType>
__device__ inline OType dqgelu(const IType val, const Empty& e) {
  constexpr float kAlpha = 1.702f;
  return dqgelu_with_alpha<OType, IType>(val, kAlpha);
}

71
// SiLU: x * sigmoid(x), computed in float and converted to OType on return.
template <typename OType, typename IType>
__device__ inline OType silu(const IType val, const Empty& e) {
  const float x = val;
  const float gate = sigmoid<float, float>(x, e);
  return x * gate;
}

77
78
79
80
81
82
// Clamped variant: the input is capped from above at p.limit, then passed to
// qgelu_with_alpha with p.alpha as the sigmoid scale.
template <typename OType, typename IType>
__device__ inline OType clamped_silu(const IType val, const ClampedSwiGLUParam& p) {
  const float x = static_cast<float>(val);
  const float capped = min(p.limit, x);  // upper clamp only
  return qgelu_with_alpha<OType, float>(capped, p.alpha);
}

83
// Derivative of SiLU: x * sigmoid'(x) + sigmoid(x).
// Computed in float and converted to OType on return.
template <typename OType, typename IType>
__device__ inline OType dsilu(const IType val, const Empty& e) {
  const float x = val;
  const float dsig = dsigmoid<float, float>(x, e);
  const float sig = sigmoid<float, float>(x, e);
  return x * dsig + sig;
}

89
90
91
92
93
94
95
96
// Derivative of clamped_silu with respect to its input.
// Returns dqgelu_with_alpha at the input when the input is <= p.limit,
// and 0 when the input exceeds p.limit (the clamp is flat there).
template <typename OType, typename IType>
__device__ inline OType clamped_dsilu(const IType val, const ClampedSwiGLUParam& p) {
  const float x = static_cast<float>(val);
  const bool below_limit = x <= p.limit;
  const float capped = min(x, p.limit);
  // Evaluated unconditionally; the select below discards it above the limit.
  const float grad = dqgelu_with_alpha<OType, float>(capped, p.alpha);
  return below_limit ? grad : 0.0f;
}

97
// ReLU: max(value, 0), evaluated with fmaxf and converted to OType on return.
// The Empty argument is unused.
template <typename OType, typename IType>
__device__ inline OType relu(IType value, const Empty&) {
  const float x = value;
  return fmaxf(x, 0.f);
}

// Derivative of ReLU: 1 for strictly positive input, 0 otherwise.
// The Empty argument is unused.
template <typename OType, typename IType>
__device__ inline OType drelu(IType value, const Empty&) {
  const bool is_positive = value > 0.f;
  return is_positive ? 1.f : 0.f;
}

107
// Squared ReLU: value^2 for strictly positive input, 0 otherwise.
// The input is promoted to float before comparing and squaring, matching the
// other helpers in this header, so the square is not formed in IType
// arithmetic. The result is converted to OType on return. The Empty argument
// is unused.
template <typename OType, typename IType>
__device__ inline OType srelu(IType value, const Empty&) {
  const float cval = value;
  return cval > 0.f ? cval * cval : 0.f;
}

// Derivative of squared ReLU: max(2 * value, 0), converted to OType on return.
// The Empty argument is unused.
template <typename OType, typename IType>
__device__ inline OType dsrelu(IType value, const Empty&) {
  const float doubled = 2.f * value;
  return fmaxf(doubled, 0.f);
}
116

117
118
119
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_