/*************************************************************************
 * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/activation.h>
#include <cuda_runtime.h>
#include <cfloat>
#include <iostream>
#include "../utils.cuh"
#include "../common.h"
#include <cstdlib>
#include <../util/vectorized_pointwise.h>
#include "../util/math.h"

namespace transformer_engine {

namespace detail {

// Parameter type handed to VectorizedUnaryKernelLauncher alongside
// detail::gelu. GELU takes no runtime parameters, so this is empty and
// gelu_cast passes a value-initialized instance (`{}`).
struct GELUParam {};

// Tanh approximation of GELU, evaluated in fp32:
//   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// 0.79788456F ~= sqrt(2/pi) and 0.03567741F ~= 0.044715 * sqrt(2/pi), so the
// tanh argument is x * (sqrt(2/pi) + 0.044715 * sqrt(2/pi) * x^2) and the
// result is x * (0.5 + 0.5 * tanh(...)), which equals the formula above.
// The second argument is unused; it only satisfies the functor signature.
__device__ inline fp32 gelu(fp32 x, const GELUParam &) {
  const fp32 tanh_arg = x * (0.79788456F + 0.03567741F * x * x);
  return x * (0.5F + 0.5F * tanhf(tanh_arg));
}

}  // namespace detail

// Applies GELU element-wise: output[i] = detail::gelu(input[i]), computed in fp32.
//
// input   Source tensor of any rank >= 1. GELU is element-wise, so only the
//         total element count matters; a 2-D [rows, cols] tensor behaves
//         exactly as before.
// output  Destination tensor; must have exactly the same shape as input. Its
//         scale and amax buffers are forwarded to the launcher (presumably
//         for FP8 outputs -- verify in VectorizedUnaryKernelLauncher).
// stream  CUDA stream on which the work is enqueued.
void gelu_cast(const Tensor &input,
               Tensor *output,
               cudaStream_t stream) {
  CheckInputTensor(input, "gelu_input");
  CheckOutputTensor(*output, "gelu_output");
  // Rank 0 was rejected before (only rank 2 was accepted); keep rejecting it
  // so an unset shape is never treated as a one-element scalar.
  NVTE_CHECK(!input.data.shape.empty(), "Input must have at least 1 dimension.");
  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");

  // Total number of elements across all dimensions.
  size_t tot_elts = 1;
  for (const auto dim : input.data.shape) {
    tot_elts *= dim;
  }

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      // nvec elements of IType together span 32 bytes.
      constexpr int nvec = 32 / sizeof(IType);
      VectorizedUnaryKernelLauncher<nvec, detail::GELUParam, detail::gelu>(
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        reinterpret_cast<const fp32*>(output->scale.dptr),
        reinterpret_cast<fp32*>(output->amax.dptr),
        tot_elts,
        {},
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

// Gated GELU forward.
//
// input   2-D tensor of shape [rows, 2 * cols].
// output  2-D tensor of shape [rows, cols]. Its scale and amax buffers are
//         forwarded to the launcher.
// stream  CUDA stream on which the work is enqueued.
//
// The activation functor is gelu<fp32, fp32> (presumably from
// ../util/math.h). How the two column halves of the input are split between
// activation and gate is decided by GatedActivationKernelLauncher, which is
// not visible here.
void geglu_cast(const Tensor &input,
                Tensor *output,
                cudaStream_t stream) {
  CheckInputTensor(input, "geglu_input");
  CheckOutputTensor(*output, "geglu_output");

  const auto &in_shape = input.data.shape;
  const auto &out_shape = output->data.shape;
  NVTE_CHECK(in_shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(out_shape.size() == 2, "Output must have 2 dimensions.");
  NVTE_CHECK(in_shape[0] == out_shape[0],
             "Input shape[0] must be equal to output shape[0].");
  NVTE_CHECK(in_shape[1] == out_shape[1] * 2,
             "Input shape[1] must be twice than output shape[1].");

  // The launcher is sized by the (smaller) output tensor.
  const size_t rows = out_shape[0];
  const size_t cols = out_shape[1];

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      // nvec elements of IType together span 32 bytes.
      constexpr int nvec = 32 / sizeof(IType);
      const auto *src = reinterpret_cast<const IType*>(input.data.dptr);
      auto *dst = reinterpret_cast<OType*>(output->data.dptr);
      GatedActivationKernelLauncher<nvec, fp32, gelu<fp32, fp32>>(
        src,
        dst,
        reinterpret_cast<const fp32*>(output->scale.dptr),
        reinterpret_cast<fp32*>(output->amax.dptr),
        rows,
        cols,
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

// Gated GELU backward.
//
// grad    2-D incoming gradient of shape [rows, cols].
// input   2-D forward input of shape [rows, 2 * cols]; must share grad's dtype.
// output  2-D gradient w.r.t. input; same shape as input.
// stream  CUDA stream on which the work is enqueued.
//
// The launcher uses gelu<fp32, fp32> and dgelu<fp32, fp32> (presumably from
// ../util/math.h) and is sized by grad's [rows, cols].
void dgeglu(const Tensor &grad,
            const Tensor &input,
            Tensor *output,
            cudaStream_t stream) {
  CheckInputTensor(grad, "dgeglu_grad");
  CheckInputTensor(input, "dgeglu_input");
  CheckOutputTensor(*output, "dgeglu_output");
  NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions.");
  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
  NVTE_CHECK(output->data.shape[0] == grad.data.shape[0],
             "Output shape[0] must be equal to grad shape[0].");
  NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
             "Output shape[1] must be twice than grad shape[1].");
  NVTE_CHECK(input.data.shape == output->data.shape,
             "Input and output shapes must match.");
  // grad and input are both reinterpreted as IType below, and IType is chosen
  // from input's dtype alone. Without this check, a grad of a different dtype
  // (e.g. bf16 grad with fp16 input) would be silently misread.
  NVTE_CHECK(grad.data.dtype == input.data.dtype,
             "Grad and input must have the same dtype.");

  // NOTE(review): unlike gelu_cast/geglu_cast, output->scale and output->amax
  // are not passed to the launcher. If the output type switch admits FP8
  // dtypes, such outputs would be written unscaled and amax left untouched --
  // confirm which output dtypes are intended here.
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      // nvec elements of IType together span 32 bytes.
      constexpr int nvec = 32 / sizeof(IType);
      DGatedActivationKernelLauncher<nvec, fp32, gelu<fp32, fp32>, dgelu<fp32, fp32>>(
        reinterpret_cast<const IType*>(grad.data.dptr),
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        grad.data.shape[0],
        grad.data.shape[1],
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

}  // namespace transformer_engine

// C API entry point for GELU. Unwraps the opaque NVTETensor handles into
// transformer_engine::Tensor objects and forwards to gelu_cast.
// NVTE_API_CALL stays the first statement (presumably it opens a
// tracing/profiling scope covering the whole call -- verify in common.h).
void nvte_gelu(const NVTETensor input,
               NVTETensor output,
               cudaStream_t stream) {
  NVTE_API_CALL(nvte_gelu);
  using namespace transformer_engine;
  const auto &in_tensor = *reinterpret_cast<const Tensor*>(input);
  auto *out_tensor = reinterpret_cast<Tensor*>(output);
  gelu_cast(in_tensor, out_tensor, stream);
}

// C API entry point for gated GELU (forward). Unwraps the opaque NVTETensor
// handles into transformer_engine::Tensor objects and forwards to geglu_cast.
// NVTE_API_CALL stays the first statement (presumably it opens a
// tracing/profiling scope covering the whole call -- verify in common.h).
void nvte_geglu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_geglu);
  using namespace transformer_engine;
  const auto &in_tensor = *reinterpret_cast<const Tensor*>(input);
  auto *out_tensor = reinterpret_cast<Tensor*>(output);
  geglu_cast(in_tensor, out_tensor, stream);
}

// C API entry point for gated GELU (backward). Unwraps the opaque NVTETensor
// handles into transformer_engine::Tensor objects and forwards to dgeglu.
// NVTE_API_CALL stays the first statement (presumably it opens a
// tracing/profiling scope covering the whole call -- verify in common.h).
void nvte_dgeglu(const NVTETensor grad,
                 const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgeglu);
  using namespace transformer_engine;
  const auto &grad_tensor = *reinterpret_cast<const Tensor*>(grad);
  const auto &input_tensor = *reinterpret_cast<const Tensor*>(input);
  auto *output_tensor = reinterpret_cast<Tensor*>(output);
  dgeglu(grad_tensor, input_tensor, output_tensor, stream);
}