activation_kernels.cu 12.8 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

lvhan028's avatar
lvhan028 committed
17
18
19
20
#include "src/turbomind/kernels/activation_kernels.h"
#include "src/turbomind/utils/cuda_type_utils.cuh"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/memory_utils.h"

#include <cstdint>
Li Zhang's avatar
Li Zhang committed
21
22
23
24
25

#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#endif

lvhan028's avatar
lvhan028 committed
26
namespace turbomind {
Li Zhang's avatar
Li Zhang committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257

/* Gelu Activation */

// Returns `a` with the sign bit of `b` OR-ed into it.
// Because the sign bit is OR-ed in rather than replaced, the result equals copysignf(a, b) only when
// `a` is non-negative (its own sign bit clear) -- hence the "_pos" suffix.
__forceinline__ __device__ float copysignf_pos(float a, float b)
{
    const unsigned int sign_of_b = __float_as_int(b) & 0x80000000;
    return __int_as_float(__float_as_int(a) | sign_of_b);
}

// Fast single-precision tanh.
// On SM75+ with CUDA >= 11 this uses the PTX `tanh.approx.f32` instruction; otherwise it falls back to
// an __expf-based formula. Both are approximations, not correctly rounded results.
__inline__ __device__ float tanh_opt(float x)
{
#if (__CUDA_ARCH__ >= 750 && CUDART_VERSION >= 11000)
    float y;
    asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(y) : "f"(x));
    return y;
#else
    // tanh(|x|) = (1 - e^{-2|x|}) / (1 + e^{-2|x|}). The exponent is never positive, so __expf cannot
    // overflow, and the quotient is non-negative -- which is what copysignf_pos requires to restore
    // the sign of x.
    const float neg_two_abs_x = -1.f * fabs(2 * x);
    const float e             = __expf(neg_two_abs_x);
    const float magnitude     = (1.0f - e) / (e + 1.0f);
    return copysignf_pos(magnitude, x);
#endif
}

template<typename T>
struct GeluActivation {
    using return_type = T;

    // GELU, tanh approximation:
    //   gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    // The tanh argument is converted to float before the call to tanh_opt.
    static __device__ __forceinline__ T apply(const T& val)
    {
        constexpr float kSqrt2OverPi = 0.7978845608028654f;
        constexpr float kCubicCoeff  = 0.044715f;

        const float inner = kSqrt2OverPi * (val + kCubicCoeff * val * val * val);
        const float cdf   = 0.5f * (1.0f + tanh_opt(inner));
        return val * cdf;
    }
};

template<>
struct GeluActivation<half2> {
    using return_type = half2;

    // Lane-wise GELU (tanh approximation) on a half2.
    // x^3 is formed with packed half2 multiplies, so it is rounded to half before being widened to float.
    // The CDF factor is evaluated in float, rounded back to half (round-to-nearest) and multiplied with
    // the input as half2.
    static __device__ __forceinline__ half2 apply(const half2& val)
    {
        const float2 x   = __half22float2(val);
        const float2 x3  = __half22float2(__hmul2(val, __hmul2(val, val)));
        const float2 cdf = make_float2(cdf_lane(x.x, x3.x), cdf_lane(x.y, x3.y));
        return __hmul2(val, __float22half2_rn(cdf));
    }

    // 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) for a single lane, given x and its cube.
    static __device__ __forceinline__ float cdf_lane(float x, float x_cubed)
    {
        return 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (x + 0.044715f * x_cubed))));
    }
};

#ifdef ENABLE_BF16
template<>
struct GeluActivation<__nv_bfloat162> {
    using return_type = __nv_bfloat162;

    // Lane-wise GELU (tanh approximation) on a bfloat162.
    // x^3 is formed with packed bf16 multiplies (bf16hmul2) before being widened to float. The CDF
    // factor is evaluated in float, rounded back to bf16 (round-to-nearest) and multiplied with the input.
    static __device__ __forceinline__ __nv_bfloat162 apply(const __nv_bfloat162& val)
    {
        const float2 x     = bf1622float2(val);
        const float2 x3    = bf1622float2(bf16hmul2(val, bf16hmul2(val, val)));
        const float  cdf_x = cdf_lane(x.x, x3.x);
        const float  cdf_y = cdf_lane(x.y, x3.y);
        return bf16hmul2(val, __floats2bfloat162_rn(cdf_x, cdf_y));
    }

    // 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) for a single lane, given x and its cube.
    static __device__ __forceinline__ float cdf_lane(float x, float x_cubed)
    {
        return 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (x + 0.044715f * x_cubed))));
    }
};
#endif

/* Relu Activation */

template<typename T>
struct ReluActivation {
    using return_type = T;

    // relu(x) = max(x, 0). Inputs that do not compare greater than zero (including NaN) map to zero.
    static __device__ __forceinline__ T apply(const T& val)
    {
        const T zero = static_cast<T>(0.0f);
        if (val > zero) {
            return val;
        }
        return zero;
    }
};

template<>
struct ReluActivation<half2> {
    using return_type = half2;

    // Clamps each half lane at zero independently.
    static __device__ __forceinline__ half2 apply(const half2& val)
    {
        const half zero   = static_cast<half>(0.0f);
        const half first  = val.x > zero ? val.x : zero;
        const half second = val.y > zero ? val.y : zero;
        return make_half2(first, second);
    }
};

#ifdef ENABLE_BF16
template<>
struct ReluActivation<__nv_bfloat162> {
    using return_type = __nv_bfloat162;

    // Clamps each bf16 lane at zero independently.
    static __device__ __forceinline__ __nv_bfloat162 apply(const __nv_bfloat162& val)
    {
        const __nv_bfloat16 zero   = static_cast<__nv_bfloat16>(0.0f);
        const __nv_bfloat16 first  = val.x > zero ? val.x : zero;
        const __nv_bfloat16 second = val.y > zero ? val.y : zero;
        return make_bfloat162(first, second);
    }
};
#endif

/* Silu Activation */

template<typename T>
struct SiluActivation {
    using return_type = T;

    // silu(x) = x * sigmoid(x) = x / (1 + e^{-x}), evaluated in float with the fast __expf intrinsic
    // and converted back to T.
    static __device__ __forceinline__ T apply(const T& val)
    {
        const float x     = static_cast<float>(val);
        const float neg_x = static_cast<float>(-val);
        return static_cast<T>(x / (1.0f + __expf(neg_x)));
    }
};

template<>
struct SiluActivation<half2> {
    // The result is kept in float: both lanes are returned as a float2 with no rounding back to half.
    using return_type = float2;

    static __device__ __forceinline__ float2 apply(const half2& val)
    {
        const float2 lanes = __half22float2(val);
        return make_float2(SiluActivation<float>::apply(lanes.x), SiluActivation<float>::apply(lanes.y));
    }
};

#ifdef ENABLE_BF16
template<>
struct SiluActivation<__nv_bfloat162> {
    // The result is kept in float: both lanes are returned as a float2 with no rounding back to bf16.
    using return_type = float2;

    static __device__ __forceinline__ float2 apply(const __nv_bfloat162& val)
    {
        const float2 lanes = bf1622float2(val);
        return make_float2(SiluActivation<float>::apply(lanes.x), SiluActivation<float>::apply(lanes.y));
    }
};
#endif  // ENABLE_BF16

/* Identity Activation (= no activation) */

template<typename T>
struct IdentityActivation {
    using return_type = T;

    // Pass-through: returns its argument unchanged.
    static __device__ __forceinline__ T apply(const T& val) { return val; }
};

// clang-format off
// Element-wise, optionally gated, in-place activation over a contiguous buffer of m * n packed elements:
//   out[i] = Activation(out[i]) * gated_weights[i]   when gated_weights != nullptr
//   out[i] = Activation(out[i])                      otherwise
// T is the packed storage type (e.g. half2 for half data), and `n` counts packed elements per row.
// Each thread starts at its global thread index and advances by the total number of launched threads,
// so any non-empty launch configuration covers the whole buffer.
//
// Accepted for interface compatibility but not used by this implementation:
//   - bias, gated_bias: no bias is added.
//   - ia3_tasks, ia3_weights, padding_offset, seq_len: no IA3 scaling is applied.
//   - activation_in, activation_out: the int8 (int8_mode == 2) path is not implemented; such calls
//     return without reading or writing `out`.
template<template<typename T> class Activation, typename T, typename BT>
__global__ void generic_activation(T*                      out,
                                   const BT*  __restrict   bias,
                                   const T*   __restrict   gated_weights,
                                   const BT*  __restrict   gated_bias,
                                   const int* __restrict   ia3_tasks,
                                   const T*   __restrict   ia3_weights,
                                   const int               int8_mode,
                                   const float* __restrict activation_in,
                                   const float* __restrict activation_out,
                                   const int* __restrict padding_offset,
                                   const int seq_len,
                                   int m,
                                   int n)
{
    // The activation may return a wider type than it consumes (e.g. float2 for half2).
    using Act_T = typename Activation<T>::return_type;

    // Quantized in/out is disabled. Bail out rather than operate on an uninitialized value.
    if (int8_mode == 2) {
        return;
    }

    const bool with_gate = gated_weights != nullptr;

    // 64-bit indexing: m * n and the thread stride can exceed INT_MAX for large problems.
    const int64_t total  = static_cast<int64_t>(m) * n;
    const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

    for (int64_t id = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x; id < total; id += stride) {
        const Act_T act = Activation<T>::apply(out[id]);
        // The gate is multiplied in Act_T precision, then the result is cast back to the storage type.
        out[id] = with_gate ? cuda_cast<T>(act * cuda_cast<Act_T>(gated_weights[id])) : cuda_cast<T>(act);
    }
}
// clang-format on

// Host launcher for generic_activation: computes out = Activation(out) * gated_weights element-wise,
// in place, over an m x n buffer (assumed contiguous, m rows of n elements -- confirm with callers).
// See generic_activation for how a null `gated_weights` is handled.
//
// Data are processed in packed form (packed_type<T>, e.g. half -> half2), so `n` is expected to be a
// multiple of the packing width. A trailing remainder of n % packing-width elements per row is not
// processed. An empty problem (m <= 0, or fewer than one packed element per row) launches nothing.
// The kernel is enqueued asynchronously on `stream`.
template<template<typename T> class Activation, typename T, typename BT>
void invokeGenericActivation(T*           out,
                             const BT*    bias,
                             const T*     gated_weights,
                             const BT*    gated_bias,
                             const int*   ia3_tasks,
                             const T*     ia3_weights,
                             const int    m,
                             const int    n,
                             const int    int8_mode,
                             const float* activation_in,
                             const float* activation_out,
                             const int*   padding_offset,
                             const int    seq_len,
                             cudaStream_t stream)
{
    TM_LOG_DEBUG(__PRETTY_FUNCTION__);
    TM_LOG_DEBUG("invokeGenericActivation %d %d %d", m, n, seq_len);

    using PT                   = typename packed_type<T>::type;
    constexpr int packed_elems = num_elems<PT>::value;
    using PBT                  = typename packed_as<BT, packed_elems>::type;

    // Packed (vector) elements per row.
    const int packed_n = n / packed_elems;

    // A zero-sized grid or block is an invalid launch configuration, so skip empty problems entirely.
    if (m <= 0 || packed_n <= 0) {
        return;
    }

    const int n_threads = 512;
    // Upper bound on gridDim.x. The kernel's strided loop picks up any elements beyond the grid.
    constexpr int64_t max_grid_x = 2147483647;

    dim3 block, grid;
    if (packed_n / 4 <= n_threads) {
        // A row fits in one block: one block per row, each thread covering ~4 packed elements.
        // Keep at least one thread for rows shorter than 4 packed elements.
        const int threads_per_row = packed_n / 4;
        block.x                   = threads_per_row > 0 ? threads_per_row : 1;
        grid.x                    = m;
    }
    else {
        // Wide rows: fixed-size blocks, enough for one packed element per thread. The product is
        // formed in 64 bits because m * n can overflow int.
        const int64_t total_packed = static_cast<int64_t>(m) * packed_n;
        const int64_t blocks       = (total_packed + n_threads - 1) / n_threads;
        block.x                    = n_threads;
        grid.x                     = static_cast<unsigned int>(blocks < max_grid_x ? blocks : max_grid_x);
    }
    TM_LOG_DEBUG("%d %d", grid.x, block.x);

    sync_check_cuda_error();
    generic_activation<Activation><<<grid, block, 0, stream>>>(reinterpret_cast<PT*>(out),
                                                               reinterpret_cast<const PBT*>(bias),
                                                               reinterpret_cast<const PT*>(gated_weights),
                                                               reinterpret_cast<const PBT*>(gated_bias),
                                                               ia3_tasks,
                                                               reinterpret_cast<const PT*>(ia3_weights),
                                                               int8_mode,
                                                               activation_in,
                                                               activation_out,
                                                               padding_offset,
                                                               seq_len,
                                                               m,
                                                               packed_n);
    sync_check_cuda_error();
}

// Emits an explicit instantiation of invokeGenericActivation<Activation, T, BT>, so the launcher can be
// called from other translation units that only see its declaration.
#define INSTANTIATE_GENERIC_ACTIVATION(Activation, T, BT)                                                              \
    template void invokeGenericActivation<Activation, T, BT>(                                                          \
        T*           out,                                                                                              \
        const BT*    bias,                                                                                             \
        const T*     gated_weights,                                                                                    \
        const BT*    gated_bias,                                                                                       \
        const int*   ia3_tasks,                                                                                        \
        const T*     ia3_weights,                                                                                      \
        const int    m,                                                                                                \
        const int    n,                                                                                                \
        const int    int8_mode,                                                                                        \
        const float* activation_in,                                                                                    \
        const float* activation_out,                                                                                   \
        const int*   padding_offset,                                                                                   \
        const int    seq_len,                                                                                          \
        cudaStream_t stream);

309
310
311
312
313
314
315
316
317
318
319
// Explicit instantiations. Only SiluActivation is currently built. The GELU and ReLU variants are kept
// below (commented out) so they can be re-enabled.
// INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, float, float);
// INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, half, half);
// #ifdef ENABLE_BF16
// INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, __nv_bfloat16, __nv_bfloat16);
// #endif

// INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, float, float);
// INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, half, half);
// #ifdef ENABLE_BF16
// INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, __nv_bfloat16, __nv_bfloat16);
// #endif

INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, float, float);
INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, half, half);
#ifdef ENABLE_BF16
INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, __nv_bfloat16, __nv_bfloat16);
#endif

lvhan028's avatar
lvhan028 committed
327
}  // namespace turbomind