math.h 2.8 KB
Newer Older
1
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_

namespace transformer_engine {
11
12

// Member-less tag type. Every activation function in this header accepts a
// `const Empty&` as its second parameter; it carries no data and exists only
// so all activations share one call signature.
struct Empty {
  // intentionally no members
};
13
14

// GELU, tanh approximation:
//   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// Constants: 0.79788456 ~= sqrt(2/pi), 0.03567741 ~= sqrt(2/pi) * 0.044715.
// The input is converted to float, the math is done in float, and the float
// result is converted to OType on return. The Empty argument is unused.
template <typename OType, typename IType>
__device__ inline OType gelu(const IType val, const Empty&) {
  const float x = val;
  const float tanh_arg = x * (0.79788456F + 0.03567741F * x * x);
  const float half_one_plus_tanh = 0.5F + 0.5F * tanhf(tanh_arg);
  return x * half_one_plus_tanh;
}

// Derivative of the tanh-approximated GELU with respect to its input:
//   t        = tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2))
//   gelu'(x) = 0.5 * x * (1 - t^2) * (sqrt(2/pi) + 3 * 0.044715 * sqrt(2/pi) * x^2)
//            + 0.5 * (1 + t)
// Constants: 0.79788456 ~= sqrt(2/pi), 0.1070322243 ~= 3 * 0.044715 * sqrt(2/pi).
// Computed in float; the Empty argument is unused.
template <typename OType, typename IType>
__device__ inline OType dgelu(const IType val, const Empty&) {
  const float x = val;
  const float t = tanhf(0.79788456f * x * (1.f + 0.044715f * x * x));
  const float sech_sq = 1.f - t * t;  // d/du tanh(u) = 1 - tanh(u)^2
  const float inner_slope = 0.79788456f + 0.1070322243f * x * x;
  return 0.5f * x * (sech_sq * inner_slope) + 0.5f * (1.f + t);
}

28
29
// Logistic sigmoid: sigmoid(x) = 1 / (1 + exp(-x)), computed in float.
// The Empty argument is unused.
template <typename OType, typename IType>
__device__ inline OType sigmoid(const IType val, const Empty&) {
  const float x = val;
  const float denominator = 1.f + expf(-x);
  return 1.f / denominator;
}

// Derivative of the logistic sigmoid: sigmoid'(x) = s * (1 - s), with
// s = sigmoid(x). Computed in float; `e` is forwarded unchanged to sigmoid.
template <typename OType, typename IType>
__device__ inline OType dsigmoid(const IType val, const Empty& e) {
  const float x = val;
  const float sig = sigmoid<float, float>(x, e);
  const float one_minus_sig = 1.f - sig;
  return sig * one_minus_sig;
}

41
42
// "Quick" GELU: qgelu(x) = x * sigmoid(1.702 * x), computed in float.
// `e` is forwarded unchanged to sigmoid.
template <typename OType, typename IType>
__device__ inline OType qgelu(const IType val, const Empty& e) {
  const float x = val;
  const float gate = sigmoid<float, float>(1.702f * x, e);
  return x * gate;
}

// Derivative of qgelu(x) = x * sigmoid(k * x) with k = 1.702 (product rule):
//   qgelu'(x) = (k * x) * sigmoid'(k * x) + sigmoid(k * x)
// Computed in float; `e` is forwarded unchanged to sigmoid/dsigmoid.
template <typename OType, typename IType>
__device__ inline OType dqgelu(const IType val, const Empty& e) {
  const float x = val;
  const float kx = 1.702f * x;  // scaled argument shared by both terms
  return kx * dsigmoid<float, float>(kx, e) + sigmoid<float, float>(kx, e);
}

54
// SiLU (a.k.a. swish): silu(x) = x * sigmoid(x), computed in float.
// `e` is forwarded unchanged to sigmoid.
template <typename OType, typename IType>
__device__ inline OType silu(const IType val, const Empty& e) {
  const float x = val;
  const float gate = sigmoid<float, float>(x, e);
  return x * gate;
}

// Derivative of SiLU (product rule): silu'(x) = x * sigmoid'(x) + sigmoid(x).
// Computed in float; `e` is forwarded unchanged to sigmoid/dsigmoid.
template <typename OType, typename IType>
__device__ inline OType dsilu(const IType val, const Empty& e) {
  const float x = val;
  return x * dsigmoid<float, float>(x, e) + sigmoid<float, float>(x, e);
}

// ReLU: max(x, 0) via fmaxf, so the input is handled as float.
// Per fmaxf semantics, a NaN input returns the other operand (0).
// The Empty argument is unused.
template <typename OType, typename IType>
__device__ inline OType relu(IType value, const Empty&) {
  const float floor_value = 0.f;
  return fmaxf(value, floor_value);
}

// Derivative of ReLU: 1 for strictly positive input, 0 otherwise
// (x == 0 and NaN both fall through to 0). The Empty argument is unused.
template <typename OType, typename IType>
__device__ inline OType drelu(IType value, const Empty&) {
  if (value > 0.f) {
    return 1.f;
  }
  return 0.f;
}

76
// Squared ReLU: srelu(x) = x^2 for x > 0, else 0 (NaN input yields 0).
// The input is widened to float before squaring, matching the
// `const float cval = val;` convention of the other activations. Squaring
// directly in a reduced-precision IType would round, and for half would
// overflow to inf once |x| exceeds ~256. Widening also keeps both ternary
// operands `float` instead of mixing IType and float.
// The Empty argument is unused.
template <typename OType, typename IType>
__device__ inline OType srelu(IType value, const Empty&) {
  const float cval = value;
  return cval > 0.f ? cval * cval : 0.f;
}

// Derivative of squared ReLU: 2x for x > 0, else 0 (fmaxf returns 0 for NaN).
// The input is widened to float first, matching the other activations, so
// `2 * x` is never formed in a reduced-precision IType (which could overflow,
// e.g. half for |x| > ~32752) or via mixed float/IType operator resolution.
// Results for float input are unchanged. The Empty argument is unused.
template <typename OType, typename IType>
__device__ inline OType dsrelu(IType value, const Empty&) {
  const float cval = value;
  return fmaxf(2.f * cval, 0.f);
}
85

86
87
88
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_