#include "utils.cuh" #include "reduction_utils.cuh" namespace vllm { template __device__ __forceinline__ T silu(const T &x) { // x * sigmoid(x) return (T)(((float)x) / (1.0f + expf((float)-x))); } template __global__ void silu_and_mul_kernel(scalar_t *__restrict__ out, // [..., d] const scalar_t *__restrict__ input, // [..., 2 * d] const int d) { const int token_idx = blockIdx.x; const int64_t token_idx_d = token_idx * int64_t(d); const int64_t token_idx_2d = token_idx_d * 2; for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = __ldg(&input[token_idx_2d + idx]); const scalar_t y = __ldg(&input[token_idx_2d + d + idx]); out[token_idx_d + idx] = silu(x) * y; } } // dequant int32 input, apply silu and mul, then per token quant to int8 template __global__ void dequant_silu_and_mul_quant_kernel(int8_t *__restrict__ out, // [..., d] const int32_t *__restrict__ input, // [..., 2 * d] const int d, const float scale_gate, const float scale_up, scale_type scale_out, // [num_tokens] float *__restrict__ tmp = nullptr // [num_tokens, d] ) { const int token_idx = blockIdx.x; if constexpr (use_per_token_quant) { float amax_val = 0.0f; const float zero = 0.0f; for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { const float x = (float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate; const float y = (float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up; float t = silu(x) * y; tmp[token_idx * d + idx] = t; t = t > zero ? t : -t; if (t > amax_val) amax_val = t; } __shared__ float s_amax; const float block_amax_val = blockReduceMax(amax_val); if (threadIdx.x == 0) { s_amax = block_amax_val; scale_out[token_idx] = block_amax_val / 127.0f; } __syncthreads(); float tmp_scale = 127.0f / s_amax; for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { out[token_idx * d + idx] = float_to_int8_rn(tmp_scale * tmp[token_idx * d + idx]); } } else { for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { const float x = (float)__ldg(&input[token_idx * 2 * d + idx]) * scale_gate; const float y = (float)__ldg(&input[token_idx * 2 * d + d + idx]) * scale_up; out[token_idx * d + idx] = float_to_int8_rn(silu(x) * y / scale_out); } } } } // namespace vllm namespace vllm { // Element-wise activation kernel template. template __global__ void activation_kernel(scalar_t *__restrict__ out, // [..., d] const scalar_t *__restrict__ input, // [..., d] const int d) { const int token_idx = blockIdx.x; for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = __ldg(&input[token_idx * d + idx]); out[token_idx * d + idx] = ACT_FN(x); } } } // namespace vllm namespace vllm { template __device__ __forceinline__ T gelu_new_kernel(const T &x) { const float x3 = (float)(x * x * x); const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); return ((T)0.5) * x * (((T)1.0) + t); } template __device__ __forceinline__ T gelu_fast_kernel(const T &x) { const float f = (float)x; const T t = (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); return ((T)0.5) * x * (((T)1.0) + t); } } // namespace vllm