// activation_kernels_impl.cuh
// Device-side activation helpers (SiLU, GELU variants) and element-wise
// activation kernels, including fused dequant + SiLU-mul + int8 requant.
#include "utils.cuh"
#include "reduction_utils.cuh"

namespace vllm {

// SiLU (a.k.a. swish): x * sigmoid(x).
// The computation is carried out in float for accuracy regardless of the
// storage type T (e.g. half), then cast back to T.
template<typename T>
__device__ __forceinline__ T silu(const T &x) {
    // x * sigmoid(x) == x / (1 + e^{-x})
    return (T)(((float)x) / (1.0f + expf((float)-x)));
}

// Fused SiLU-and-multiply: out[t, i] = silu(input[t, i]) * input[t, d + i],
// where input stores the gate half in columns [0, d) and the up half in
// columns [d, 2d).
//
// Launch contract: one block per token (blockIdx.x == token index); threads
// stride over the d elements of that token's row. 64-bit row offsets are used
// so token_idx * d cannot overflow int for large tensors.
template<typename scalar_t>
__global__ void silu_and_mul_kernel(scalar_t *__restrict__ out,         // [..., d]
                                    const scalar_t *__restrict__ input, // [..., 2 * d]
                                    const int d) {
    const int token_idx        = blockIdx.x;
    const int64_t token_idx_d  = token_idx * int64_t(d);
    const int64_t token_idx_2d = token_idx_d * 2;
    for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
        const scalar_t x       = input[token_idx_2d + idx];
        const scalar_t y       = input[token_idx_2d + d + idx];
        out[token_idx_d + idx] = silu(x) * y;
    }
}

// Dequantize int32 gate/up activations, apply silu(gate) * up, then quantize
// the result to int8.
//
// Launch contract: one block per token (blockIdx.x == token index); threads
// stride over the d elements of that token's row.
//
// If use_per_token_quant:
//   - scale_out is a [num_tokens] array; this kernel writes amax/127 per token;
//   - tmp must point to a valid [num_tokens, d] float scratch buffer (the
//     nullptr default is only meaningful for the !use_per_token_quant path).
// Otherwise scale_out is a single float divisor applied to every element.
template<typename scale_type, bool use_per_token_quant>
__global__ void dequant_silu_and_mul_quant_kernel(int8_t *__restrict__ out,          // [..., d]
                                                  const int32_t *__restrict__ input, // [..., 2 * d]
                                                  const int d,
                                                  const float scale_gate,
                                                  const float scale_up,
                                                  scale_type scale_out,             // [num_tokens]
                                                  float *__restrict__ tmp = nullptr // [num_tokens, d]
) {
    const int token_idx = blockIdx.x;
    // 64-bit row offsets: the previous token_idx * 2 * d arithmetic was done
    // in int and could overflow for large num_tokens * d. This also matches
    // the indexing convention of silu_and_mul_kernel above.
    const int64_t token_idx_d  = token_idx * int64_t(d);
    const int64_t token_idx_2d = token_idx_d * 2;
    if constexpr (use_per_token_quant) {
        float amax_val   = 0.0f;
        const float zero = 0.0f;

        // Pass 1: dequantize, activate, stash the float result in tmp, and
        // track this thread's running absolute maximum.
        for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
            const float x = (float)input[token_idx_2d + idx] * scale_gate;
            const float y = (float)input[token_idx_2d + d + idx] * scale_up;

            float t                  = silu(x) * y;
            tmp[token_idx_d + idx]   = t;
            t                        = t > zero ? t : -t; // |t|
            if (t > amax_val)
                amax_val = t;
        }

        // Block-wide reduction of the per-thread amax; thread 0 publishes the
        // per-token quantization scale.
        __shared__ float s_amax;
        const float block_amax_val = blockReduceMax(amax_val);
        if (threadIdx.x == 0) {
            s_amax               = block_amax_val;
            scale_out[token_idx] = block_amax_val / 127.0f;
        }
        __syncthreads();

        // NOTE(review): if an entire row is exactly zero, s_amax == 0 and
        // tmp_scale becomes inf (0 * inf = NaN below). Presumably upstream
        // never produces an all-zero row — confirm with callers.
        float tmp_scale = 127.0f / s_amax;
        // Pass 2: rescale the stashed floats and round-to-nearest into int8.
        for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
            out[token_idx_d + idx] = float_to_int8_rn(tmp_scale * tmp[token_idx_d + idx]);
        }
    } else {
        // Static-scale path: single pass, divide by the provided scalar scale.
        for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
            const float x = (float)input[token_idx_2d + idx] * scale_gate;
            const float y = (float)input[token_idx_2d + d + idx] * scale_up;

            out[token_idx_d + idx] = float_to_int8_rn(silu(x) * y / scale_out);
        }
    }
}
} // namespace vllm

namespace vllm {

// Element-wise activation kernel template: out[t, i] = ACT_FN(input[t, i]).
//
// Launch contract: one block per token (blockIdx.x == token index); threads
// stride over the d elements of that token's row. The row offset is computed
// in 64 bits for consistency with silu_and_mul_kernel and to avoid int
// overflow of token_idx * d on large tensors.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t &)>
__global__ void activation_kernel(scalar_t *__restrict__ out,         // [..., d]
                                  const scalar_t *__restrict__ input, // [..., d]
                                  const int d) {
    const int token_idx      = blockIdx.x;
    const int64_t token_base = token_idx * int64_t(d);
    for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
        const scalar_t x      = input[token_base + idx];
        out[token_base + idx] = ACT_FN(x);
    }
}

} // namespace vllm

namespace vllm {

// "New" GELU tanh approximation:
//   0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x^3)))
// where 0.79788456f approximates sqrt(2/pi). x^3 is computed in float; the
// tanh argument and result round-trip through T to match the storage type.
template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T &x) {
    const float x3 = (float)(x * x * x);
    const T t      = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
    return ((T)0.5) * x * (((T)1.0) + t);
}

// "Fast" GELU tanh approximation, algebraically rearranged form of:
//   0.5 * x * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x^2)))
// computed here as tanh((0.79788456 * x) * (1 + (0.044715 * x) * x)), with
// 0.79788456f approximating sqrt(2/pi).
template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T &x) {
    const float f = (float)x;
    const T t     = (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
    return ((T)0.5) * x * (((T)1.0) + t);
}

} // namespace vllm