// NOTE(review): the header names and every template-argument list in this
// file were reconstructed from the symbols the code uses; confirm them against
// the build before merging.
#include <cstdint>
#include <optional>

#include <cuda_runtime.h>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <torch/all.h>

// Forward declarations for fallback to existing vLLM kernels.
void rms_norm_opt(torch::Tensor& out, torch::Tensor& input,
                  torch::Tensor& weight, double epsilon);
void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual,
                            torch::Tensor& weight, double epsilon);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                      std::optional<torch::Tensor> key, int64_t head_size,
                      torch::Tensor& cos_sin_cache, bool is_neox);

namespace vllm {

// Butterfly (XOR) sum over groups of WIDTH neighbouring lanes.
// The shuffle distances are WIDTH/2, WIDTH/4, ..., 1, so a lane only exchanges
// values with lanes inside its own WIDTH-aligned group; every lane of the
// group ends up holding the group total.
// NOTE(review): this is the mask-less `__shfl_xor` form. CUDA toolchains
// targeting Volta and newer require `__shfl_xor_sync`; the code presumably
// targets HIP/ROCm - confirm before building for NVIDIA.
template <typename T, int WIDTH>
__device__ __forceinline__ T warp_reduce_sum_xor(T val) {
#pragma unroll
  for (int mask = WIDTH / 2; mask > 0; mask >>= 1) {
    val += __shfl_xor(val, mask);
  }
  return val;
}

// Optionally adds the residual to this thread's two VEC_SIZE-wide fragments
// and returns the sum of squares of the (possibly updated) fragments.
//
// When HAS_RESIDUAL is true the residual slices at `res_head_ptr + offset_*`
// are loaded, added to `r_data_*`, and the sum is written back to BOTH the
// residual memory and the `r_data_*` arrays. When it is false `res_head_ptr`
// is never dereferenced.
template <typename scalar_t, typename T_ACC, int VEC_SIZE, bool HAS_RESIDUAL>
__device__ __forceinline__ T_ACC apply_residual_and_calc_sq(
    scalar_t* r_data_low, scalar_t* r_data_high, scalar_t* res_head_ptr,
    int offset_low, int offset_high) {
  using LoadT = at::native::memory::aligned_vector<scalar_t, VEC_SIZE>;

  if constexpr (HAS_RESIDUAL) {
    scalar_t r_res_low[VEC_SIZE];
    scalar_t r_res_high[VEC_SIZE];
    *(LoadT*)r_res_low = *(LoadT*)(res_head_ptr + offset_low);
    *(LoadT*)r_res_high = *(LoadT*)(res_head_ptr + offset_high);
#pragma unroll
    for (int i = 0; i < VEC_SIZE; ++i) {
      r_res_low[i] = r_res_low[i] + r_data_low[i];
      r_res_high[i] = r_res_high[i] + r_data_high[i];
      r_data_low[i] = r_res_low[i];
      r_data_high[i] = r_res_high[i];
    }
    *(LoadT*)(res_head_ptr + offset_low) = *(LoadT*)r_res_low;
    *(LoadT*)(res_head_ptr + offset_high) = *(LoadT*)r_res_high;
  }

  T_ACC local_sum_sq = 0;
#pragma unroll
  for (int i = 0; i < VEC_SIZE; ++i) {
    T_ACC low = static_cast<T_ACC>(r_data_low[i]);
    T_ACC high = static_cast<T_ACC>(r_data_high[i]);
    local_sum_sq += low * low;
    local_sum_sq += high * high;
  }
  return local_sum_sq;
}

// Runs the callable passed as the trailing argument with `NAME` bound to a
// constexpr bool equal to the runtime value `VAL`.
#define DISPATCH_BOOL(VAL, NAME, ...) \
  if (VAL) {                          \
    constexpr bool NAME = true;       \
    __VA_ARGS__();                    \
  } else {                            \
    constexpr bool NAME = false;      \
    __VA_ARGS__();                    \
  }

// RMS-normalizes one 128-element head (optionally after adding a residual),
// scales it by `gamma`, applies the rotary embedding and stores the result
// back to `head_ptr`. Shared by the query and key halves of the kernel.
//
// Work split: THREAD_PER_HEAD threads cooperate on a head. Thread
// `lane_in_head` owns elements [lane*VEC, lane*VEC + VEC) of the first half
// and the same range shifted by 64 in the second half.
//
//   head_ptr      start of the head in query/key memory (read and written)
//   res_head_ptr  start of the matching residual head; only dereferenced when
//                 HAS_RESIDUAL is true
//   gamma         per-element scale, indexed 0..127
//   cos_sin       cos/sin row for this token; cos values start at index 0 and
//                 sin values at index rot_dim / 2
//   eps           added to the mean of squares before the reciprocal sqrt
//
// Must be called by all THREAD_PER_HEAD threads of a head together because of
// the shuffle-based reduction.
// NOTE(review): `eps` arrives already rounded to scalar_t, which is coarser
// than the accumulator type for half/bfloat16 inputs.
template <typename scalar_t, typename T_ACC, int VEC_SIZE, int THREAD_PER_HEAD,
          bool IS_NEOX, bool HAS_RESIDUAL>
__device__ __forceinline__ void rms_norm_rope_head(
    scalar_t* head_ptr, scalar_t* res_head_ptr, const scalar_t* gamma,
    const scalar_t* cos_sin, const int lane_in_head, const int rot_dim,
    const scalar_t eps) {
  constexpr int HEAD_SIZE = 128;
  constexpr int HALF_ROT = 64;
  static_assert(2 * VEC_SIZE * THREAD_PER_HEAD == HEAD_SIZE,
                "each head must be covered exactly once by its threads");
  static_assert(VEC_SIZE % 2 == 0,
                "the interleaved rotation consumes elements in pairs");

  using LoadT = at::native::memory::aligned_vector<scalar_t, VEC_SIZE>;

  const int offset_low = lane_in_head * VEC_SIZE;
  const int offset_high = HALF_ROT + lane_in_head * VEC_SIZE;

  scalar_t r_low[VEC_SIZE];
  scalar_t r_high[VEC_SIZE];
  *(LoadT*)r_low = *(LoadT*)(head_ptr + offset_low);
  *(LoadT*)r_high = *(LoadT*)(head_ptr + offset_high);

  // Per-thread partial sum of squares, then reduced over the head's threads.
  T_ACC sum_sq =
      apply_residual_and_calc_sq<scalar_t, T_ACC, VEC_SIZE, HAS_RESIDUAL>(
          r_low, r_high, res_head_ptr, offset_low, offset_high);
  sum_sq = warp_reduce_sum_xor<T_ACC, THREAD_PER_HEAD>(sum_sq);
  const T_ACC inv_rms = c10::cuda::compat::rsqrt(sum_sq / HEAD_SIZE +
                                                 static_cast<T_ACC>(eps));

  if constexpr (IS_NEOX) {
    // Element i of the first half is rotated against element i of the second
    // half; both use cos[i] / sin[i].
    scalar_t r_cos[VEC_SIZE], r_sin[VEC_SIZE];
    *(LoadT*)r_cos = *(LoadT*)(cos_sin + offset_low);
    *(LoadT*)r_sin = *(LoadT*)(cos_sin + rot_dim / 2 + offset_low);
#pragma unroll
    for (int i = 0; i < VEC_SIZE; ++i) {
      r_low[i] = static_cast<T_ACC>(r_low[i]) * inv_rms *
                 static_cast<T_ACC>(gamma[offset_low + i]);
      r_high[i] = static_cast<T_ACC>(r_high[i]) * inv_rms *
                  static_cast<T_ACC>(gamma[offset_high + i]);
      const scalar_t x_l = r_low[i];
      const scalar_t x_h = r_high[i];
      const scalar_t c = r_cos[i];
      const scalar_t s = r_sin[i];
      r_low[i] = x_l * c - x_h * s;
      r_high[i] = x_l * s + x_h * c;
    }
  } else {
    // Adjacent elements (2j, 2j + 1) form a pair that uses cos[j] / sin[j],
    // so each fragment needs only VEC_SIZE / 2 cache entries.
    using LoadCacheT =
        at::native::memory::aligned_vector<scalar_t, VEC_SIZE / 2>;
    scalar_t c_low[VEC_SIZE / 2], s_low[VEC_SIZE / 2];
    scalar_t c_high[VEC_SIZE / 2], s_high[VEC_SIZE / 2];
    const int cache_offset_low = offset_low / 2;
    const int cache_offset_high = offset_high / 2;
    *(LoadCacheT*)c_low = *(LoadCacheT*)(cos_sin + cache_offset_low);
    *(LoadCacheT*)s_low =
        *(LoadCacheT*)(cos_sin + rot_dim / 2 + cache_offset_low);
    *(LoadCacheT*)c_high = *(LoadCacheT*)(cos_sin + cache_offset_high);
    *(LoadCacheT*)s_high =
        *(LoadCacheT*)(cos_sin + rot_dim / 2 + cache_offset_high);
#pragma unroll
    for (int i = 0; i < VEC_SIZE; i += 2) {
      const int c_idx = i / 2;

      r_low[i] = static_cast<T_ACC>(r_low[i]) * inv_rms *
                 static_cast<T_ACC>(gamma[offset_low + i]);
      r_low[i + 1] = static_cast<T_ACC>(r_low[i + 1]) * inv_rms *
                     static_cast<T_ACC>(gamma[offset_low + i + 1]);
      const scalar_t x0 = r_low[i];
      const scalar_t x1 = r_low[i + 1];
      const scalar_t c = c_low[c_idx];
      const scalar_t s = s_low[c_idx];
      r_low[i] = x0 * c - x1 * s;
      r_low[i + 1] = x1 * c + x0 * s;

      r_high[i] = static_cast<T_ACC>(r_high[i]) * inv_rms *
                  static_cast<T_ACC>(gamma[offset_high + i]);
      r_high[i + 1] = static_cast<T_ACC>(r_high[i + 1]) * inv_rms *
                      static_cast<T_ACC>(gamma[offset_high + i + 1]);
      const scalar_t xh0 = r_high[i];
      const scalar_t xh1 = r_high[i + 1];
      const scalar_t ch = c_high[c_idx];
      const scalar_t sh = s_high[c_idx];
      r_high[i] = xh0 * ch - xh1 * sh;
      r_high[i + 1] = xh1 * ch + xh0 * sh;
    }
  }

  *(LoadT*)(head_ptr + offset_low) = *(LoadT*)r_low;
  *(LoadT*)(head_ptr + offset_high) = *(LoadT*)r_high;
}

// Fused (residual add +) per-head RMSNorm + rotary embedding for
// head_size == 128, applied in place to `query` and `key`.
//
// Launch layout (as set up by launch_opt_rms_rope):
//   * blockDim.x = tokens_per_block * threads_per_token
//   * threads_per_token = (num_heads + num_kv_heads) * THREAD_PER_HEAD; the
//     first num_heads * THREAD_PER_HEAD lanes of a token handle query heads,
//     the remaining lanes handle key heads
//   * dynamic shared memory: tokens_per_block * 128 * sizeof(scalar_t), one
//     cos/sin row per token of the block
//
// The cos/sin row of a token is read from `cos_sin_cache[pos * 128 + i]`, so
// the cache rows are expected to be 128 elements long.
// All head data is accessed with VEC_SIZE-wide vector loads/stores.
// NOTE(review): that presumably requires query/key/residual heads to be
// aligned to VEC_SIZE * sizeof(scalar_t) - confirm with the callers.
template <typename scalar_t, typename T_ACC, int VEC_SIZE, int THREAD_PER_HEAD,
          bool IS_NEOX, bool HAS_RESIDUAL>
__global__ void opt_rms_rope_qwen3(
    const int64_t* __restrict__ positions, scalar_t* __restrict__ query,
    scalar_t* __restrict__ key, const scalar_t* __restrict__ cos_sin_cache,
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int64_t head_stride_q, const int64_t head_stride_k,
    const scalar_t* __restrict__ gamma_q, const scalar_t* __restrict__ gamma_k,
    scalar_t* residual_q, scalar_t* residual_k, const scalar_t eps,
    const int num_tokens, const int num_heads, const int num_kv_heads,
    const int threads_per_token, const int tokens_per_block) {
  extern __shared__ char smem_buffer[];
  scalar_t* s_cos_sin_base = reinterpret_cast<scalar_t*>(smem_buffer);

  constexpr int HEAD_SIZE = 128;

  const int tid = threadIdx.x;
  const int local_token_idx = tid / threads_per_token;
  const int lane = tid % threads_per_token;
  const int global_token_idx = blockIdx.x * tokens_per_block + local_token_idx;

  // Threads without a token (tail of the last block) must still reach the
  // barrier below, so they are masked out here instead of returning early.
  const bool active = (local_token_idx < tokens_per_block) &&
                      (global_token_idx < num_tokens);

  scalar_t* my_s_cos_sin = s_cos_sin_base + local_token_idx * HEAD_SIZE;
  if (active) {
    // Stage this token's cos/sin row in shared memory; the token's threads
    // split the copy between them.
    const int64_t pos = positions[global_token_idx];
    for (int i = lane; i < HEAD_SIZE; i += threads_per_token) {
      my_s_cos_sin[i] = cos_sin_cache[pos * HEAD_SIZE + i];
    }
  }
  __syncthreads();
  if (!active) return;

  const int q_boundary = num_heads * THREAD_PER_HEAD;
  const int total_threads_needed =
      (num_heads + num_kv_heads) * THREAD_PER_HEAD;

  if (lane < q_boundary) {
    // Query lanes.
    const int q_head_idx = lane / THREAD_PER_HEAD;
    const int q_lane_in_head = lane % THREAD_PER_HEAD;
    scalar_t* q_head_ptr =
        query + global_token_idx * query_stride + q_head_idx * head_stride_q;
    // The residual is addressed with the same strides as the query.
    scalar_t* res_q_head_ptr =
        HAS_RESIDUAL ? (residual_q + global_token_idx * query_stride +
                        q_head_idx * head_stride_q)
                     : nullptr;
    rms_norm_rope_head<scalar_t, T_ACC, VEC_SIZE, THREAD_PER_HEAD, IS_NEOX,
                       HAS_RESIDUAL>(q_head_ptr, res_q_head_ptr, gamma_q,
                                     my_s_cos_sin, q_lane_in_head, rot_dim,
                                     eps);
  } else if (lane < total_threads_needed && key != nullptr) {
    // Key lanes.
    const int k_lane_abs = lane - q_boundary;
    const int kv_head_idx = k_lane_abs / THREAD_PER_HEAD;
    const int k_lane_in_head = k_lane_abs % THREAD_PER_HEAD;
    scalar_t* k_head_ptr =
        key + global_token_idx * key_stride + kv_head_idx * head_stride_k;
    // The residual is addressed with the same strides as the key.
    scalar_t* res_k_head_ptr =
        HAS_RESIDUAL ? (residual_k + global_token_idx * key_stride +
                        kv_head_idx * head_stride_k)
                     : nullptr;
    rms_norm_rope_head<scalar_t, T_ACC, VEC_SIZE, THREAD_PER_HEAD, IS_NEOX,
                       HAS_RESIDUAL>(k_head_ptr, res_k_head_ptr, gamma_k,
                                     my_s_cos_sin, k_lane_in_head, rot_dim,
                                     eps);
  }
}

// Chooses the launch configuration for opt_rms_rope_qwen3 and enqueues it on
// `stream`.
//
// The residual path is enabled only when both residual pointers are non-null.
// Blocks target 256 threads: as many whole tokens as fit, at least one.
// Does nothing when `num_tokens` is not positive. Raises through
// C10_CUDA_KERNEL_LAUNCH_CHECK if the launch is rejected.
template <typename scalar_t, typename T_ACC>
void launch_opt_rms_rope(
    const int64_t* positions, scalar_t* query, scalar_t* key,
    const scalar_t* cos_sin_cache, const int rot_dim,
    const int64_t query_stride, const int64_t key_stride,
    const int64_t head_stride_q, const int64_t head_stride_k,
    const scalar_t* gamma_q, const scalar_t* gamma_k, scalar_t* residual_q_ptr,
    scalar_t* residual_k_ptr, const scalar_t eps, const int num_tokens,
    const bool is_neox, const int num_heads, const int num_kv_heads,
    const cudaStream_t stream) {
  // A zero-sized grid is an invalid launch configuration.
  if (num_tokens <= 0) return;

  const bool has_residual =
      (residual_q_ptr != nullptr && residual_k_ptr != nullptr);

  constexpr int THREAD_PER_HEAD = 8;
  constexpr int VEC = 8;

  const int threads_per_token = (num_heads + num_kv_heads) * THREAD_PER_HEAD;

  // Keep the same launch heuristic as the original kernel.
  const int target_block_size = 256;
  int tokens_per_block = target_block_size / threads_per_token;
  if (tokens_per_block < 1) tokens_per_block = 1;

  const int actual_block_size = tokens_per_block * threads_per_token;
  const int grid_size = (num_tokens + tokens_per_block - 1) / tokens_per_block;
  // One 128-element cos/sin row per token handled by the block.
  const size_t smem_size = tokens_per_block * 128 * sizeof(scalar_t);

  DISPATCH_BOOL(has_residual, HAS_RESIDUAL_CONST, [&] {
    DISPATCH_BOOL(is_neox, IS_NEOX_CONST, [&] {
      opt_rms_rope_qwen3<scalar_t, T_ACC, VEC, THREAD_PER_HEAD, IS_NEOX_CONST,
                         HAS_RESIDUAL_CONST>
          <<<grid_size, actual_block_size, smem_size, stream>>>(
              positions, query, key, cos_sin_cache, rot_dim, query_stride,
              key_stride, head_stride_q, head_stride_k, gamma_q, gamma_k,
              residual_q_ptr, residual_k_ptr, eps, num_tokens, num_heads,
              num_kv_heads, threads_per_token, tokens_per_block);
      // Launch-configuration errors are otherwise silent.
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  });
}

}  // namespace vllm

// Per-head RMSNorm (optionally fused with a residual add) followed by rotary
// embedding, applied in place to `query` and, when present, `key`.
//
//   positions        int64, shape [num_tokens] or [batch_size, seq_len]
//   query / key      CUDA, contiguous, same dtype; key is optional
//   head_size        elements per head
//   cos_sin_cache    CUDA, contiguous, same dtype as query; size(1) is used as
//                    the rotary dimension
//   is_neox          selects the NeoX (split-half) rotation over the
//                    interleaved one
//   weight_q/_k      RMSNorm scales, CUDA, contiguous, same dtype as query
//   residual_q/_k    both given or both absent; when given they are updated in
//                    place by the fused kernel / fused_add_rms_norm_opt
//   epsilon          RMSNorm epsilon
//
// A specialised fused kernel is used when key is present, head_size and the
// rotary dimension are 128, both weights have 128 elements and
// num_heads + num_kv_heads <= 128. Every other case falls back to
// rms_norm_opt / fused_add_rms_norm_opt followed by rotary_embedding.
// Violated preconditions raise through TORCH_CHECK. An empty `positions`
// tensor is a no-op.
void rms_rotary_embedding_fuse(
    torch::Tensor& positions, torch::Tensor& query,
    std::optional<torch::Tensor> key, int64_t head_size,
    torch::Tensor& cos_sin_cache, bool is_neox, torch::Tensor& weight_q,
    torch::Tensor& weight_k, std::optional<torch::Tensor> residual_q,
    std::optional<torch::Tensor> residual_k, double epsilon) {
  // Basic validation (mirrors rotary_embedding + layernorm checks).
  const int64_t num_tokens = positions.numel();
  const int positions_ndim = positions.dim();

  TORCH_CHECK(positions_ndim == 1 || positions_ndim == 2,
              "positions must have shape [num_tokens] or [batch_size, seq_len]");
  if (positions_ndim == 1) {
    TORCH_CHECK(query.size(0) == positions.size(0) &&
                    (!key.has_value() || key->size(0) == positions.size(0)),
                "query, key and positions must have the same number of tokens");
  } else {
    TORCH_CHECK(
        query.size(0) == positions.size(0) &&
            (!key.has_value() || key->size(0) == positions.size(0)) &&
            query.size(1) == positions.size(1) &&
            (!key.has_value() || key->size(1) == positions.size(1)),
        "query, key and positions must have the same batch_size and seq_len");
  }

  TORCH_CHECK(query.is_cuda(), "query must be CUDA");
  TORCH_CHECK(!key.has_value() || key->is_cuda(), "key must be CUDA");
  TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA");
  TORCH_CHECK(positions.is_cuda(), "positions must be CUDA");
  TORCH_CHECK(weight_q.is_cuda() && weight_k.is_cuda(), "weights must be CUDA");

  TORCH_CHECK(query.is_contiguous(), "query must be contiguous");
  TORCH_CHECK(!key.has_value() || key->is_contiguous(),
              "key must be contiguous");
  TORCH_CHECK(cos_sin_cache.is_contiguous(),
              "cos_sin_cache must be contiguous");
  TORCH_CHECK(weight_q.is_contiguous() && weight_k.is_contiguous(),
              "weights must be contiguous");

  TORCH_CHECK(positions.scalar_type() == at::kLong, "positions must be int64");
  TORCH_CHECK(query.scalar_type() == cos_sin_cache.scalar_type(),
              "cos_sin_cache must have same dtype as query");
  TORCH_CHECK(weight_q.scalar_type() == query.scalar_type() &&
                  weight_k.scalar_type() == query.scalar_type(),
              "weights must have same dtype as query");
  TORCH_CHECK(!key.has_value() || key->scalar_type() == query.scalar_type(),
              "key must have same dtype as query");

  if (residual_q.has_value() || residual_k.has_value()) {
    TORCH_CHECK(residual_q.has_value() && residual_k.has_value(),
                "residual_q and residual_k must be both provided or both None");
    TORCH_CHECK(residual_q->is_cuda() && residual_k->is_cuda(),
                "residual tensors must be CUDA");
    TORCH_CHECK(residual_q->is_contiguous() && residual_k->is_contiguous(),
                "residual tensors must be contiguous");
    TORCH_CHECK(residual_q->scalar_type() == query.scalar_type() &&
                    residual_k->scalar_type() == query.scalar_type(),
                "residual tensors must have same dtype as query");
    // The fused kernel addresses the residuals with the query/key strides, so
    // a smaller residual would be read and written out of bounds.
    TORCH_CHECK(residual_q->numel() == query.numel(),
                "residual_q must have the same number of elements as query");
    TORCH_CHECK(!key.has_value() || residual_k->numel() == key->numel(),
                "residual_k must have the same number of elements as key");
  }

  // Nothing to do for an empty batch; this also keeps the divisions below
  // away from a zero divisor.
  if (num_tokens == 0) {
    return;
  }

  const int query_hidden_size = query.numel() / num_tokens;
  const int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
  TORCH_CHECK(query_hidden_size % head_size == 0);
  TORCH_CHECK(!key.has_value() || (key_hidden_size % head_size == 0));

  const int num_heads = query_hidden_size / head_size;
  const int num_kv_heads =
      key.has_value() ? (key_hidden_size / head_size) : num_heads;
  TORCH_CHECK(num_heads % num_kv_heads == 0);

  const int rot_dim = cos_sin_cache.size(1);
  const int seq_dim_idx = positions_ndim - 1;
  const int64_t query_stride = query.stride(seq_dim_idx);
  const int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;

  // With an explicit head dimension use its stride; otherwise heads are laid
  // out back to back, head_size elements apart.
  const int query_ndim = query.dim();
  const int64_t head_stride_q =
      (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
  const int64_t head_stride_k =
      (key.has_value() && key->dim() == positions_ndim + 2) ? key->stride(-2)
                                                            : head_size;

  const bool has_residual = residual_q.has_value() && residual_k.has_value();

  // (num_heads + num_kv_heads) <= 128 keeps the 8-threads-per-head block at
  // or below 1024 threads.
  const bool supports_qwen3_opt =
      (key.has_value() && head_size == 128 && rot_dim == 128 &&
       (num_heads + num_kv_heads) <= 128 && weight_q.numel() == 128 &&
       weight_k.numel() == 128);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (supports_qwen3_opt) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16, query.scalar_type(),
        "vllm_rms_rotary_embedding_fuse_qwen3", [&] {
          using T_ACC = at::acc_type<scalar_t, true>;
          scalar_t* res_q_ptr =
              has_residual ? residual_q->data_ptr<scalar_t>() : nullptr;
          scalar_t* res_k_ptr =
              has_residual ? residual_k->data_ptr<scalar_t>() : nullptr;
          vllm::launch_opt_rms_rope<scalar_t, T_ACC>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key->data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              rot_dim, query_stride, key_stride, head_stride_q, head_stride_k,
              weight_q.data_ptr<scalar_t>(), weight_k.data_ptr<scalar_t>(),
              res_q_ptr, res_k_ptr, static_cast<scalar_t>(epsilon),
              static_cast<int>(num_tokens), is_neox, num_heads, num_kv_heads,
              stream);
        });
    return;
  }

  // Fallback: use existing kernels (still removes lightop dependency).
  // Apply per-head RMSNorm to Q/K and then call the existing RoPE kernel.
  {
    TORCH_CHECK(weight_q.numel() == head_size && weight_k.numel() == head_size,
                "weight_q/weight_k must have shape [head_size]");

    // View each head as one row so the row-wise norm kernels act per head.
    auto q_heads = query.view({num_tokens * num_heads, head_size});
    if (has_residual) {
      auto rq_heads = residual_q->view({num_tokens * num_heads, head_size});
      fused_add_rms_norm_opt(q_heads, rq_heads, weight_q, epsilon);
    } else {
      rms_norm_opt(q_heads, q_heads, weight_q, epsilon);
    }

    if (key.has_value()) {
      auto k_heads = key->view({num_tokens * num_kv_heads, head_size});
      if (has_residual) {
        auto rk_heads =
            residual_k->view({num_tokens * num_kv_heads, head_size});
        fused_add_rms_norm_opt(k_heads, rk_heads, weight_k, epsilon);
      } else {
        rms_norm_opt(k_heads, k_heads, weight_k, epsilon);
      }
    }
  }

  rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox);
}