/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/activation.h>

#include <cuda_runtime.h>

#include <cfloat>
#include <cstdlib>
#include <iostream>

#include "../common.h"
#include "../utils.cuh"
#include "../util/math.h"
#include "../util/vectorized_pointwise.h"

namespace transformer_engine {

/* Forward GELU activation, applied elementwise.
 *
 * input:  tensor of any shape; every element is transformed independently.
 * output: pre-allocated tensor with the same shape as `input`. Its
 *         `scale`/`amax` pointers are forwarded to the launcher so FP8
 *         outputs can be scaled and their absolute maximum recorded
 *         (may be null for non-FP8 output types — handled by the launcher).
 * stream: CUDA stream on which the kernel is enqueued (call is async).
 *
 * The elementwise math is done in fp32 via the gelu<fp32, fp32> functor
 * regardless of storage type; nvec elements (32 bytes worth of input)
 * are processed per thread for vectorized memory access.
 */
void gelu(const Tensor &input,
          Tensor *output,
          cudaStream_t stream) {
  CheckInputTensor(input, "gelu_input");
  CheckOutputTensor(*output, "gelu_output");
  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
  const size_t tot_elts = product(input.data.shape);

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      // 32 bytes of input per thread -> vector width depends on dtype size.
      constexpr int nvec = 32 / sizeof(IType);
      VectorizedUnaryKernelLauncher<nvec, Empty, gelu<fp32, fp32> >(
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        reinterpret_cast<const fp32*>(output->scale.dptr),
        reinterpret_cast<fp32*>(output->amax.dptr),
        tot_elts,
        Empty(),
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

/* Backward of GELU: output = GELU'(input) * grad, elementwise.
 *
 * grad:   incoming gradient; must have the same dtype as `input`
 *         (shape is assumed to match — only dtype is checked here).
 * input:  the activation input saved from the forward pass.
 * output: pre-allocated gradient w.r.t. `input`, same shape as `input`;
 *         its `scale`/`amax` are forwarded for FP8 output support.
 * stream: CUDA stream on which the kernel is enqueued (call is async).
 *
 * Math is performed in fp32 via the dgelu<fp32, fp32> functor regardless
 * of the storage types.
 */
void dgelu(const Tensor &grad,
           const Tensor &input,
           Tensor *output,
           cudaStream_t stream) {
  CheckInputTensor(input, "dgelu_input");
  CheckInputTensor(grad, "dgelu_input_grad");
  CheckOutputTensor(*output, "dgelu_output");
  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
  NVTE_CHECK(input.data.dtype == grad.data.dtype,
             "Input and incoming gradient types must match.");
  const size_t tot_elts = product(input.data.shape);

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      // 32 bytes of input per thread -> vector width depends on dtype size.
      constexpr int nvec = 32 / sizeof(IType);
      VectorizedUnaryGradKernelLauncher<nvec, Empty, dgelu<fp32, fp32>>(
        reinterpret_cast<const IType*>(grad.data.dptr),
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        reinterpret_cast<const fp32*>(output->scale.dptr),
        reinterpret_cast<fp32*>(output->amax.dptr),
        tot_elts,
        {},
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

/* Gated GELU (GeGLU) forward.
 *
 * input:  2D tensor [N, 2*H]; the two halves along dim 1 are the gate and
 *         linear parts consumed by the gated-activation kernel.
 * output: pre-allocated 2D tensor [N, H]; `scale`/`amax` forwarded for
 *         FP8 output support.
 * stream: CUDA stream on which the kernel is enqueued (call is async).
 */
void geglu(const Tensor &input,
           Tensor *output,
           cudaStream_t stream) {
  CheckInputTensor(input, "geglu_input");
  CheckOutputTensor(*output, "geglu_output");
  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
  NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
             "Input shape[0] must be equal to output shape[0].");
  NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
             "Input shape[1] must be 2x larger than output shape[1].");

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      // 32 bytes of input per thread -> vector width depends on dtype size.
      constexpr int nvec = 32 / sizeof(IType);
      GatedActivationKernelLauncher<nvec, fp32, Empty, gelu<fp32, fp32>>(
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        reinterpret_cast<const fp32*>(output->scale.dptr),
        reinterpret_cast<fp32*>(output->amax.dptr),
        output->data.shape[0],
        output->data.shape[1],
        {},
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

/* Gated GELU (GeGLU) backward.
 *
 * grad:   incoming gradient, 2D [N, H].
 * input:  forward-pass input, 2D [N, 2*H] (same shape as `output`).
 * output: pre-allocated gradient w.r.t. `input`, 2D [N, 2*H].
 * stream: CUDA stream on which the kernel is enqueued (call is async).
 *
 * Note: unlike the unary activations, no scale/amax pointers are passed —
 * the DGated launcher takes only the data pointers and dimensions.
 */
void dgeglu(const Tensor &grad,
            const Tensor &input,
            Tensor *output,
            cudaStream_t stream) {
  CheckInputTensor(grad, "dgeglu_grad");
  CheckInputTensor(input, "dgeglu_input");
  CheckOutputTensor(*output, "dgeglu_output");
  NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions.");
  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
  NVTE_CHECK(output->data.shape[0] == grad.data.shape[0],
             "Output shape[0] must be equal to grad shape[0].");
  NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
             "Output shape[1] must be 2x larger than grad shape[1].");
  NVTE_CHECK(input.data.shape == output->data.shape,
             "Input and output shapes must match.");

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      // 32 bytes of input per thread -> vector width depends on dtype size.
      constexpr int nvec = 32 / sizeof(IType);
      DGatedActivationKernelLauncher<nvec, fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(
        reinterpret_cast<const IType*>(grad.data.dptr),
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        grad.data.shape[0],
        grad.data.shape[1],
        {},
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

/* Forward QGeLU activation (qgelu<fp32, fp32> functor from math.h),
 * applied elementwise on `stream`. Output must be pre-allocated with
 * the same shape as the input; its scale/amax pointers are forwarded
 * for FP8 output support.
 */
void qgelu(const Tensor &input,
           Tensor *output,
           cudaStream_t stream) {
  CheckInputTensor(input, "qgelu_input");
  CheckOutputTensor(*output, "qgelu_output");
  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
  const size_t num_elements = product(input.data.shape);

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      // Each thread handles 32 bytes worth of input elements.
      constexpr int vec_size = 32 / sizeof(IType);
      VectorizedUnaryKernelLauncher<vec_size, Empty, qgelu<fp32, fp32> >(
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        reinterpret_cast<const fp32*>(output->scale.dptr),
        reinterpret_cast<fp32*>(output->amax.dptr),
        num_elements,
        Empty(),
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

/* Backward of QGeLU: output = QGeLU'(input) * grad, elementwise on
 * `stream`. Grad and input must share a dtype; output must be
 * pre-allocated with the input's shape.
 */
void dqgelu(const Tensor &grad,
            const Tensor &input,
            Tensor *output,
            cudaStream_t stream) {
  CheckInputTensor(input, "dqgelu_input");
  CheckInputTensor(grad, "dqgelu_input_grad");
  CheckOutputTensor(*output, "dqgelu_output");
  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
  NVTE_CHECK(input.data.dtype == grad.data.dtype,
             "Input and incoming gradient types must match.");
  const size_t num_elements = product(input.data.shape);

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      // Each thread handles 32 bytes worth of input elements.
      constexpr int vec_size = 32 / sizeof(IType);
      VectorizedUnaryGradKernelLauncher<vec_size, Empty, dqgelu<fp32, fp32>>(
        reinterpret_cast<const IType*>(grad.data.dptr),
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        reinterpret_cast<const fp32*>(output->scale.dptr),
        reinterpret_cast<fp32*>(output->amax.dptr),
        num_elements,
        {},
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

}  // namespace transformer_engine

/* Public C API: GELU forward. Thin wrapper that casts the opaque
 * NVTETensor handles to internal Tensor pointers and dispatches to
 * transformer_engine::gelu on the given stream.
 */
void nvte_gelu(const NVTETensor input,
               NVTETensor output,
               cudaStream_t stream) {
  NVTE_API_CALL(nvte_gelu);
  using namespace transformer_engine;
  gelu(*reinterpret_cast<const Tensor*>(input),
       reinterpret_cast<Tensor*>(output),
       stream);
}

/* Public C API: GELU backward. Casts the opaque NVTETensor handles to
 * internal Tensor pointers and dispatches to transformer_engine::dgelu.
 */
void nvte_dgelu(const NVTETensor grad,
                const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgelu);
  using namespace transformer_engine;
  dgelu(*reinterpret_cast<const Tensor*>(grad),
        *reinterpret_cast<const Tensor*>(input),
        reinterpret_cast<Tensor*>(output),
        stream);
}

/* Public C API: gated GELU (GeGLU) forward. Casts the opaque NVTETensor
 * handles to internal Tensor pointers and dispatches to
 * transformer_engine::geglu.
 */
void nvte_geglu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_geglu);
  using namespace transformer_engine;
  geglu(*reinterpret_cast<const Tensor*>(input),
        reinterpret_cast<Tensor*>(output),
        stream);
}

/* Public C API: gated GELU (GeGLU) backward. Casts the opaque NVTETensor
 * handles to internal Tensor pointers and dispatches to
 * transformer_engine::dgeglu.
 */
void nvte_dgeglu(const NVTETensor grad,
                 const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgeglu);
  using namespace transformer_engine;
  dgeglu(*reinterpret_cast<const Tensor*>(grad),
         *reinterpret_cast<const Tensor*>(input),
         reinterpret_cast<Tensor*>(output),
         stream);
}

/* Public C API: QGeLU forward. Casts the opaque NVTETensor handles to
 * internal Tensor pointers and dispatches to transformer_engine::qgelu.
 */
void nvte_qgelu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_qgelu);
  using namespace transformer_engine;
  qgelu(*reinterpret_cast<const Tensor*>(input),
        reinterpret_cast<Tensor*>(output),
        stream);
}

/* Public C API: QGeLU backward. Casts the opaque NVTETensor handles to
 * internal Tensor pointers and dispatches to transformer_engine::dqgelu.
 */
void nvte_dqgelu(const NVTETensor grad,
                 const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dqgelu);
  using namespace transformer_engine;
  dqgelu(*reinterpret_cast<const Tensor*>(grad),
         *reinterpret_cast<const Tensor*>(input),
         reinterpret_cast<Tensor*>(output),
         stream);
}