/*************************************************************************
 * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/activation.h>
#include <cuda_runtime.h>
#include <cfloat>
#include <iostream>
#include "../utils.cuh"
#include "../common.h"
#include <cstdlib>
#include <../util/vectorized_pointwise.h>
15
#include "../util/math.h"
Przemek Tredak's avatar
Przemek Tredak committed
16
17
18

namespace transformer_engine {

/*
 * GELU forward pass, applied element-wise over the whole tensor.
 *
 * input  : tensor of any shape; its dtype selects IType.
 * output : tensor with exactly the same shape as `input`; its dtype selects
 *          OType. output->scale and output->amax are forwarded to the
 *          launcher (presumably for FP8 scaling / amax tracking — confirm
 *          against VectorizedUnaryKernelLauncher).
 * stream : CUDA stream the work is launched on; no synchronization is
 *          performed in this function.
 *
 * The activation functor is instantiated as gelu<fp32, fp32>.
 */
void gelu(const Tensor &input,
          Tensor *output,
          cudaStream_t stream) {
  CheckInputTensor(input, "gelu_input");
  CheckOutputTensor(*output, "gelu_output");
  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
  const size_t tot_elts = product(input.data.shape);

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      // nvec elements of IType span 32 bytes.
      constexpr int nvec = 32 / sizeof(IType);
      VectorizedUnaryKernelLauncher<nvec, Empty, gelu<fp32, fp32> >(
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        reinterpret_cast<const fp32*>(output->scale.dptr),
        reinterpret_cast<fp32*>(output->amax.dptr),
        tot_elts,
        Empty(),
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

/*
 * GELU backward pass, applied element-wise.
 *
 * grad   : incoming gradient; must match `input` in both dtype and shape.
 * input  : the tensor that was fed to the forward GELU.
 * output : gradient w.r.t. `input`; same shape as `input`. Its scale/amax
 *          pointers are forwarded to the launcher.
 * stream : CUDA stream the work is launched on.
 *
 * How grad and dgelu(input) are combined (presumably grad * gelu'(input))
 * is decided inside VectorizedUnaryGradKernelLauncher.
 */
void dgelu(const Tensor &grad,
           const Tensor &input,
           Tensor *output,
           cudaStream_t stream) {
  CheckInputTensor(input, "dgelu_input");
  CheckInputTensor(grad, "dgelu_input_grad");
  CheckOutputTensor(*output, "dgelu_output");
  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
  NVTE_CHECK(input.data.dtype == grad.data.dtype,
             "Input and incoming gradient types must match.");
  // The launcher receives a single element count (derived from input's shape)
  // for both grad and input, so a smaller grad would be read out of bounds.
  NVTE_CHECK(input.data.shape == grad.data.shape,
             "Input and incoming gradient shapes must match.");
  const size_t tot_elts = product(input.data.shape);

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      // nvec elements of IType span 32 bytes.
      constexpr int nvec = 32 / sizeof(IType);
      VectorizedUnaryGradKernelLauncher<nvec, Empty, dgelu<fp32, fp32>>(
        reinterpret_cast<const IType*>(grad.data.dptr),
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        reinterpret_cast<const fp32*>(output->scale.dptr),
        reinterpret_cast<fp32*>(output->amax.dptr),
        tot_elts,
        Empty(),
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

/*
 * Gated GELU (GeGLU) forward pass.
 *
 * input  : 2-D tensor of shape [rows, 2 * cols].
 * output : 2-D tensor of shape [rows, cols]; its scale/amax pointers are
 *          forwarded to the launcher.
 * stream : CUDA stream the work is launched on.
 *
 * How the two halves of each input row are combined (presumably
 * gelu(first half) * second half) is decided inside
 * GatedActivationKernelLauncher — confirm there.
 */
void geglu(const Tensor &input,
           Tensor *output,
           cudaStream_t stream) {
  CheckInputTensor(input, "geglu_input");
  CheckOutputTensor(*output, "geglu_output");
  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
  NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
             "Input shape[0] must be equal to output shape[0].");
  NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
             "Input shape[1] must be 2x larger than output shape[1].");

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      // nvec elements of IType span 32 bytes.
      constexpr int nvec = 32 / sizeof(IType);
      GatedActivationKernelLauncher<nvec, fp32, Empty, gelu<fp32, fp32>>(
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        reinterpret_cast<const fp32*>(output->scale.dptr),
        reinterpret_cast<fp32*>(output->amax.dptr),
        output->data.shape[0],
        output->data.shape[1],
        Empty(),
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

/*
 * Gated GELU (GeGLU) backward pass.
 *
 * grad   : 2-D incoming gradient of shape [rows, cols]; must have the same
 *          dtype as `input`.
 * input  : 2-D forward-pass input of shape [rows, 2 * cols].
 * output : 2-D gradient w.r.t. `input`, same shape as `input`.
 * stream : CUDA stream the work is launched on.
 *
 * NOTE(review): unlike the other three functions, output->scale and
 * output->amax are not passed to the launcher, so no output scaling or amax
 * update can happen here — confirm this is intended for FP8 output dtypes.
 */
void dgeglu(const Tensor &grad,
            const Tensor &input,
            Tensor *output,
            cudaStream_t stream) {
  CheckInputTensor(grad, "dgeglu_grad");
  CheckInputTensor(input, "dgeglu_input");
  CheckOutputTensor(*output, "dgeglu_output");
  NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions.");
  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
  NVTE_CHECK(output->data.shape[0] == grad.data.shape[0],
             "Output shape[0] must be equal to grad shape[0].");
  NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
             "Output shape[1] must be 2x larger than grad shape[1].");
  NVTE_CHECK(input.data.shape == output->data.shape,
             "Input and output shapes must match.");
  // grad.data.dptr is reinterpreted as IType below, and IType is selected from
  // input's dtype, so the two dtypes must agree (same check as in dgelu).
  NVTE_CHECK(input.data.dtype == grad.data.dtype,
             "Input and incoming gradient types must match.");

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      // nvec elements of IType span 32 bytes.
      constexpr int nvec = 32 / sizeof(IType);
      DGatedActivationKernelLauncher<nvec, fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(
        reinterpret_cast<const IType*>(grad.data.dptr),
        reinterpret_cast<const IType*>(input.data.dptr),
        reinterpret_cast<OType*>(output->data.dptr),
        grad.data.shape[0],
        grad.data.shape[1],
        Empty(),
        stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

}  // namespace transformer_engine

/*
 * C API entry point for the GELU forward pass.
 *
 * `input` and `output` are opaque NVTETensor handles; they are reinterpreted
 * as transformer_engine::Tensor and handed to transformer_engine::gelu.
 */
void nvte_gelu(const NVTETensor input,
               NVTETensor output,
               cudaStream_t stream) {
  NVTE_API_CALL(nvte_gelu);
  using namespace transformer_engine;
  // Unwrap the opaque handles into the internal tensor representation.
  const Tensor &src = *reinterpret_cast<const Tensor*>(input);
  Tensor *dst = reinterpret_cast<Tensor*>(output);
  gelu(src, dst, stream);
}

/*
 * C API entry point for the GELU backward pass.
 *
 * The opaque NVTETensor handles `grad`, `input` and `output` are
 * reinterpreted as transformer_engine::Tensor and forwarded to
 * transformer_engine::dgelu.
 */
void nvte_dgelu(const NVTETensor grad,
                const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgelu);
  namespace te = transformer_engine;
  const te::Tensor &grad_t = *reinterpret_cast<const te::Tensor*>(grad);
  const te::Tensor &input_t = *reinterpret_cast<const te::Tensor*>(input);
  te::Tensor *output_t = reinterpret_cast<te::Tensor*>(output);
  te::dgelu(grad_t, input_t, output_t, stream);
}

/*
 * C API entry point for the gated GELU (GeGLU) forward pass.
 *
 * `input` and `output` are opaque NVTETensor handles; they are reinterpreted
 * as transformer_engine::Tensor and handed to transformer_engine::geglu.
 */
void nvte_geglu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_geglu);
  using namespace transformer_engine;
  // Unwrap the opaque handles into the internal tensor representation.
  const Tensor &gated_in = *reinterpret_cast<const Tensor*>(input);
  Tensor *gated_out = reinterpret_cast<Tensor*>(output);
  geglu(gated_in, gated_out, stream);
}

/*
 * C API entry point for the gated GELU (GeGLU) backward pass.
 *
 * The opaque NVTETensor handles `grad`, `input` and `output` are
 * reinterpreted as transformer_engine::Tensor and forwarded to
 * transformer_engine::dgeglu.
 */
void nvte_dgeglu(const NVTETensor grad,
                 const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgeglu);
  namespace te = transformer_engine;
  const te::Tensor &grad_t = *reinterpret_cast<const te::Tensor*>(grad);
  const te::Tensor &input_t = *reinterpret_cast<const te::Tensor*>(input);
  te::Tensor *output_t = reinterpret_cast<te::Tensor*>(output);
  te::dgeglu(grad_t, input_t, output_t, stream);
}