math.h 2.88 KB
Newer Older
1
/*************************************************************************
2
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
4
5
6
7
8
9
10
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_

namespace transformer_engine {
11
12

struct Empty {};
13
14

template <typename OType, typename IType>
15
__device__ inline OType gelu(const IType val, const Empty&) {
16
17
18
19
20
    const float cval = val;
    return cval * (0.5F + 0.5F * tanhf(cval * (0.79788456F + 0.03567741F * cval * cval)));
}

template <typename OType, typename IType>
21
__device__ inline OType dgelu(const IType val, const Empty&) {
22
23
24
25
26
27
28
    const float cval = val;
    const float tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
    return 0.5f * cval * ((1.f - tanh_out * tanh_out) *
                          (0.79788456f + 0.1070322243f * cval * cval)) +
           0.5f * (1.f + tanh_out);
}

29
30
31
32
33
34
35
36
37
38
39
40
41
template <typename OType, typename IType>
__device__ inline OType sigmoid(const IType val, const Empty&) {
    const float cval = val;
    return 1.f / (1.f + expf(-cval));
}

template <typename OType, typename IType>
__device__ inline OType dsigmoid(const IType val, const Empty& e) {
    const float cval = val;
    const float s = sigmoid<float, float>(cval, e);
    return s * (1.f - s);
}

42
43
44
45
46
47
48
49
50
51
52
53
54
template <typename OType, typename IType>
__device__ inline OType qgelu(const IType val, const Empty& e) {
    const float cval = val;
    return cval * sigmoid<float, float>(1.702f * cval, e);
}

template <typename OType, typename IType>
__device__ inline OType dqgelu(const IType val, const Empty& e) {
    const float cval = val;
    return cval * dsigmoid<float, float>(1.702f * cval, e) +
                   sigmoid<float, float>(1.702f * cval, e);
}

55
template <typename OType, typename IType>
56
__device__ inline OType silu(const IType val, const Empty& e) {
57
58
59
60
61
    const float cval = val;
    return cval * sigmoid<float, float>(cval, e);
}

template <typename OType, typename IType>
62
__device__ inline OType dsilu(const IType val, const Empty& e) {
63
64
65
66
67
68
69
70
71
72
73
74
75
76
    const float cval = val;
    return cval * dsigmoid<float, float>(cval, e) + sigmoid<float, float>(cval, e);
}

template <typename OType, typename IType>
__device__ inline OType relu(IType value, const Empty &) {
    return fmaxf(value, 0.f);
}

template <typename OType, typename IType>
__device__ inline OType drelu(IType value, const Empty &) {
    return value > 0.f ? 1.f : 0.f;
}

77
78
79
80
81
82
83
84
85
template <typename OType, typename IType>
__device__ inline OType srelu(IType value, const Empty &) {
    return value > 0 ? value * value : 0.f;
}

template <typename OType, typename IType>
__device__ inline OType dsrelu(IType value, const Empty &) {
    return fmaxf(2.f * value, 0.f);
}
86

87
88
89
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_