#include "hip/hip_runtime.h"
#include "utils.cuh"
#include "reduction_utils.cuh"

namespace vllm {

template<typename T>
__device__ __forceinline__ T silu(const T &x) {
    // x * sigmoid(x)
    return (T)(((float)x) / (1.0f + expf((float)-x)));
}

template<typename scalar_t>
__global__ void silu_and_mul_kernel(scalar_t *__restrict__ out,         // [..., d]
                                    const scalar_t *__restrict__ input, // [..., 2 * d]
                                    const int d) {

    const int token_idx        = blockIdx.x;
    const int64_t token_idx_d  = token_idx * int64_t(d);
    const int64_t token_idx_2d = token_idx_d * 2;
    // input holds the gate and up projections back to back along the last dim.
    for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
        const scalar_t x       = __ldg(&input[token_idx_2d + idx]);
        const scalar_t y       = __ldg(&input[token_idx_2d + d + idx]);
        out[token_idx_d + idx] = silu(x) * y;
    }
}
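
// A minimal host-side launch sketch (an illustrative assumption, not part of
// this header): one block per token and up to 1024 threads striding over the
// hidden size d. The launcher name and the thread cap are hypothetical.
//
//   template<typename scalar_t>
//   void launch_silu_and_mul(scalar_t *out, const scalar_t *input,
//                            int num_tokens, int d, hipStream_t stream) {
//       dim3 grid(num_tokens);
//       dim3 block(std::min(d, 1024));  // <algorithm> assumed for std::min
//       hipLaunchKernelGGL((silu_and_mul_kernel<scalar_t>), grid, block, 0, stream,
//                          out, input, d);
//   }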

// Dequantize the int32 input, apply SiLU-and-mul, then quantize the result to
// int8 with either a per-token or a per-tensor scale.
template<typename scale_type, bool use_per_token_quant>
__global__ void dequant_silu_and_mul_quant_kernel(int8_t *__restrict__ out,          // [..., d]
                                                  const int32_t *__restrict__ input, // [..., 2 * d]
                                                  const int d,
                                                  const float scale_gate,
                                                  const float scale_up,
                                                  scale_type scale_out,             // [num_tokens]
                                                  float *__restrict__ tmp = nullptr // [num_tokens, d]
) {
    const int token_idx = blockIdx.x;
    if constexpr (use_per_token_quant) {
        float amax_val   = 0.0f;
        const float zero = 0.0f;

        // First pass: dequantize the gate/up halves, apply SiLU(gate) * up, stash
        // the fp32 result in tmp, and track this token's absolute maximum.
        for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
            const float x            = (float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate;
            const float y            = (float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up;
            float t                  = silu(x) * y;
            tmp[token_idx * d + idx] = t;
            t                        = t > zero ? t : -t;
            if (t > amax_val)
                amax_val = t;
        }

        __shared__ float s_amax;
        // Reduce the per-thread maxima across the block; the per-token scale is amax / 127.
        const float block_amax_val = blockReduceMax(amax_val);
        if (threadIdx.x == 0) {
            s_amax               = block_amax_val;
            scale_out[token_idx] = block_amax_val / 127.0f;
        }
        __syncthreads();

        // Second pass: quantize the stashed activations with the per-token scale.
        float tmp_scale = 127.0f / s_amax;
        for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
            out[token_idx * d + idx] = float_to_int8_rn(tmp_scale * tmp[token_idx * d + idx]);
        }
    } else {
        // Per-tensor quantization: scale_out is a single precomputed float divisor.
        for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
            const float x            = (float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate;
            const float y            = (float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up;
            out[token_idx * d + idx] = float_to_int8_rn(silu(x) * y / scale_out);
        }
    }
}
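
// A minimal host-side launch sketch (an illustrative assumption, not part of
// this header). With use_per_token_quant = true, scale_type is a float pointer
// that receives one scale per token, and a float scratch buffer tmp of shape
// [num_tokens, d] must be supplied; with use_per_token_quant = false,
// scale_type is a plain float holding the precomputed output scale.
//
//   dim3 grid(num_tokens);
//   dim3 block(std::min(d, 1024));
//   hipLaunchKernelGGL((dequant_silu_and_mul_quant_kernel<float *, true>),
//                      grid, block, 0, stream,
//                      out, input, d, scale_gate, scale_up, scale_out, tmp);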
} // namespace vllm

namespace vllm {

// Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t &)>
__global__ void activation_kernel(scalar_t *__restrict__ out,         // [..., d]
                                  const scalar_t *__restrict__ input, // [..., d]
                                  const int d) {
    const int token_idx = blockIdx.x;
    for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
        const scalar_t x         = __ldg(&input[token_idx * d + idx]);
        out[token_idx * d + idx] = ACT_FN(x);
    }
}
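
// Illustrative instantiation (an assumption, not part of this header): the
// template dispatches any elementwise device activation, e.g. the silu defined
// above or the GELU variants below, with one block per token.
//
//   dim3 grid(num_tokens);
//   dim3 block(std::min(d, 1024));
//   hipLaunchKernelGGL((activation_kernel<float, silu<float>>),
//                      grid, block, 0, stream, out, input, d);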

} // namespace vllm

namespace vllm {

// GELU, tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T &x) {
    const float x3 = (float)(x * x * x);
    const T t      = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
    return ((T)0.5) * x * (((T)1.0) + t);
}

// Fast GELU: 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x^2)))
template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T &x) {
    const float f = (float)x;
    const T t     = (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
    return ((T)0.5) * x * (((T)1.0) + t);
}

} // namespace vllm