// activation_kernels_impl.cuh
#include "utils.cuh"
#include "reduction_utils.cuh"

namespace vllm {

template<typename T>
__device__ __forceinline__ T silu(const T &x) {
    // x * sigmoid(x)
    return (T)(((float)x) / (1.0f + expf((float)-x)));
}

// Fused SwiGLU-style gate: out[token][i] = silu(x[i]) * y[i], where x and y
// are the two halves of each token's input row. One block per token.
template<typename scalar_t>
__global__ void silu_and_mul_kernel(scalar_t *__restrict__ out,         // [..., d]
                                    const scalar_t *__restrict__ input, // [..., 2 * d]
                                    const int d) {
    const int token_idx = blockIdx.x;
    // 64-bit offsets so num_tokens * d can exceed INT_MAX.
    const int64_t token_idx_d  = token_idx * int64_t(d);
    const int64_t token_idx_2d = token_idx_d * 2;
    for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
        const scalar_t x       = __ldg(&input[token_idx_2d + idx]);
        const scalar_t y       = __ldg(&input[token_idx_2d + d + idx]);
        out[token_idx_d + idx] = silu(x) * y;
    }
}
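
// A minimal host-side launch sketch for silu_and_mul_kernel. This helper is an
// illustration, not part of the original file: the function name and the
// explicit `stream` parameter are assumptions. One block per token; the
// kernel's stride loop covers d even when it exceeds the block size.
template<typename scalar_t>
void launch_silu_and_mul(scalar_t *out,         // [num_tokens, d]
                         const scalar_t *input, // [num_tokens, 2 * d]
                         int num_tokens, int d, cudaStream_t stream) {
    const dim3 grid(num_tokens);
    const dim3 block(d < 1024 ? d : 1024);
    silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(out, input, d);
}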

// Dequantize int32 GEMM outputs, apply silu-and-mul, then requantize to int8.
// When use_per_token_quant is true, a per-token scale is computed and written
// to scale_out ([num_tokens]) and `tmp` must be a [num_tokens, d] fp32 scratch
// buffer; otherwise scale_out is a single per-tensor divisor.
template<typename scale_type, bool use_per_token_quant>
__global__ void dequant_silu_and_mul_quant_kernel(int8_t *__restrict__ out,          // [..., d]
                                                  const int32_t *__restrict__ input, // [..., 2 * d]
                                                  const int d,
                                                  const float scale_gate,
                                                  const float scale_up,
                                                  scale_type scale_out,             // [num_tokens]
                                                  float *__restrict__ tmp = nullptr // [num_tokens, d]
) {
    const int token_idx = blockIdx.x;
    // 64-bit offsets, mirroring silu_and_mul_kernel above.
    const int64_t token_idx_d  = token_idx * int64_t(d);
    const int64_t token_idx_2d = token_idx_d * 2;
    if constexpr (use_per_token_quant) {
        float amax_val = 0.0f;

        // Pass 1: dequantize, activate, stash the fp32 result in tmp, and
        // track this thread's running max(|t|).
        for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
            const float x = (float)__ldg(&input[token_idx_2d + idx]) * scale_gate;
            const float y = (float)__ldg(&input[token_idx_2d + d + idx]) * scale_up;
            const float t = silu(x) * y;
            tmp[token_idx_d + idx] = t;
            amax_val = fmaxf(amax_val, fabsf(t));
        }

        // Reduce to a block-wide amax and derive this token's int8 scale.
        __shared__ float s_amax;
        const float block_amax_val = blockReduceMax(amax_val);
        if (threadIdx.x == 0) {
            s_amax               = block_amax_val;
            scale_out[token_idx] = block_amax_val / 127.0f;
        }
        __syncthreads();

        // Pass 2: requantize the stashed activations with the shared scale.
        const float tmp_scale = 127.0f / s_amax;
        for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
            out[token_idx_d + idx] = float_to_int8_rn(tmp_scale * tmp[token_idx_d + idx]);
        }
    } else {
        // Per-tensor path: scale_out is a single fp32 divisor; no scratch needed.
        for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
            const float x = (float)__ldg(&input[token_idx_2d + idx]) * scale_gate;
            const float y = (float)__ldg(&input[token_idx_2d + d + idx]) * scale_up;
            out[token_idx_d + idx] = float_to_int8_rn(silu(x) * y / scale_out);
        }
    }
}
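
// Launch sketch for the per-token-quant path (again an illustration, not the
// original dispatch code): scale_type is instantiated as float* so each token
// receives its own scale in scale_out, and `tmp` must point to a
// [num_tokens, d] fp32 scratch buffer used between the kernel's two passes.
inline void launch_dequant_silu_and_mul_quant_per_token(
    int8_t *out, const int32_t *input, int num_tokens, int d,
    float scale_gate, float scale_up, float *scale_out, float *tmp,
    cudaStream_t stream) {
    const dim3 grid(num_tokens);
    const dim3 block(d < 1024 ? d : 1024);
    dequant_silu_and_mul_quant_kernel<float *, true><<<grid, block, 0, stream>>>(
        out, input, d, scale_gate, scale_up, scale_out, tmp);
}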

// Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t &)>
__global__ void activation_kernel(scalar_t *__restrict__ out,         // [..., d]
                                  const scalar_t *__restrict__ input, // [..., d]
                                  const int d) {
    const int token_idx       = blockIdx.x;
    const int64_t token_idx_d = token_idx * int64_t(d); // 64-bit offset, as above
    for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
        const scalar_t x       = __ldg(&input[token_idx_d + idx]);
        out[token_idx_d + idx] = ACT_FN(x);
    }
}

// "New" GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// with sqrt(2/pi) ~= 0.79788456. Intermediate math runs in fp32 so half
// inputs are not truncated by round-trip casts.
template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T &x) {
    const float f  = (float)x;
    const float x3 = f * f * f;
    const float t  = tanhf(0.79788456f * (f + 0.044715f * x3));
    return (T)(0.5f * f * (1.0f + t));
}

// "Fast" GELU: algebraically the same approximation with the cubic term
// factored as x * (1 + 0.044715 * x^2).
template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T &x) {
    // tanh(0.79788456 * x * (1 + 0.044715 * x^2)), evaluated in fp32.
    const float f = (float)x;
    const float t = tanhf(0.79788456f * f * (1.0f + 0.044715f * f * f));
    return (T)(0.5f * f * (1.0f + t));
}
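
// Dispatch sketch (an assumption, not original host code): because the
// activation is a compile-time template argument, a fused elementwise GELU
// kernel is a single instantiation of activation_kernel with one of the
// device functions above.
template<typename scalar_t>
void launch_gelu_new(scalar_t *out, const scalar_t *input,
                     int num_tokens, int d, cudaStream_t stream) {
    const dim3 grid(num_tokens);
    const dim3 block(d < 1024 ? d : 1024);
    activation_kernel<scalar_t, gelu_new_kernel<scalar_t>>
        <<<grid, block, 0, stream>>>(out, input, d);
}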

} // namespace vllm