#include #include #include #include #include #include #include #include #include // Forward declarations for fallback to existing vLLM kernels. void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double epsilon); void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); namespace vllm { template __device__ __forceinline__ T warp_reduce_sum_xor(T val) { #pragma unroll for (int mask = WIDTH / 2; mask > 0; mask >>= 1) { val += __shfl_xor(val, mask); } return val; } template __device__ __forceinline__ T_ACC apply_residual_and_calc_sq( scalar_t* r_data_low, scalar_t* r_data_high, scalar_t* res_head_ptr, int offset_low, int offset_high) { using LoadT = at::native::memory::aligned_vector; if constexpr (HAS_RESIDUAL) { scalar_t r_res_low[VEC_SIZE]; scalar_t r_res_high[VEC_SIZE]; *(LoadT*)r_res_low = *(LoadT*)(res_head_ptr + offset_low); *(LoadT*)r_res_high = *(LoadT*)(res_head_ptr + offset_high); #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { r_res_low[i] = r_res_low[i] + r_data_low[i]; r_res_high[i] = r_res_high[i] + r_data_high[i]; r_data_low[i] = r_res_low[i]; r_data_high[i] = r_res_high[i]; } *(LoadT*)(res_head_ptr + offset_low) = *(LoadT*)r_res_low; *(LoadT*)(res_head_ptr + offset_high) = *(LoadT*)r_res_high; } T_ACC local_sum_sq = 0; #pragma unroll VEC_SIZE for (int i = 0; i < VEC_SIZE; ++i) { T_ACC low = static_cast(r_data_low[i]); T_ACC high = static_cast(r_data_high[i]); local_sum_sq += low * low; local_sum_sq += high * high; } return local_sum_sq; } #define DISPATCH_BOOL(VAL, NAME, ...) \ if (VAL) { \ constexpr bool NAME = true; \ __VA_ARGS__(); \ } else { \ constexpr bool NAME = false; \ __VA_ARGS__(); \ } template __global__ void opt_rms_rope_qwen3( const int64_t* __restrict__ positions, scalar_t* __restrict__ query, scalar_t* __restrict__ key, const scalar_t* __restrict__ cos_sin_cache, const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int64_t head_stride_q, const int64_t head_stride_k, const scalar_t* __restrict__ gamma_q, const scalar_t* __restrict__ gamma_k, scalar_t* residual_q, scalar_t* residual_k, const scalar_t eps, const int num_tokens, const int num_heads, const int num_kv_heads, const int threads_per_token, const int tokens_per_block) { extern __shared__ char smem_buffer[]; scalar_t* s_cos_sin_base = reinterpret_cast(smem_buffer); constexpr int HEAD_SIZE = 128; constexpr int HALF_ROT = 64; const int tid = threadIdx.x; const int local_token_idx = tid / threads_per_token; const int lane = tid % threads_per_token; if (local_token_idx >= tokens_per_block) return; const int global_token_idx = blockIdx.x * tokens_per_block + local_token_idx; if (global_token_idx >= num_tokens) return; scalar_t* my_s_cos_sin = s_cos_sin_base + local_token_idx * HEAD_SIZE; const int64_t pos = positions[global_token_idx]; for (int i = lane; i < HEAD_SIZE; i += threads_per_token) { my_s_cos_sin[i] = cos_sin_cache[pos * HEAD_SIZE + i]; } __syncthreads(); const int q_boundary = num_heads * THREAD_PER_HEAD; if (lane < q_boundary) { const int q_head_idx = lane / THREAD_PER_HEAD; const int q_lane_in_head = lane % THREAD_PER_HEAD; scalar_t* q_head_ptr = query + global_token_idx * query_stride + q_head_idx * head_stride_q; scalar_t* res_q_head_ptr = HAS_RESIDUAL ? (residual_q + global_token_idx * query_stride + q_head_idx * head_stride_q) : nullptr; using LoadT = at::native::memory::aligned_vector; scalar_t r_q_low[VEC_SIZE]; scalar_t r_q_high[VEC_SIZE]; const int offset_low = q_lane_in_head * VEC_SIZE; const int offset_high = HALF_ROT + q_lane_in_head * VEC_SIZE; *(LoadT*)r_q_low = *(LoadT*)(q_head_ptr + offset_low); *(LoadT*)r_q_high = *(LoadT*)(q_head_ptr + offset_high); T_ACC sum_sq = apply_residual_and_calc_sq( r_q_low, r_q_high, res_q_head_ptr, offset_low, offset_high); sum_sq = warp_reduce_sum_xor(sum_sq); const T_ACC inv_rms = c10::cuda::compat::rsqrt(sum_sq / HEAD_SIZE + static_cast(eps)); const scalar_t* cache_ptr = my_s_cos_sin; if constexpr (IS_NEOX) { scalar_t r_cos_low[VEC_SIZE], r_sin_low[VEC_SIZE]; *(LoadT*)r_cos_low = *(LoadT*)(cache_ptr + offset_low); *(LoadT*)r_sin_low = *(LoadT*)(cache_ptr + rot_dim / 2 + offset_low); #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { r_q_low[i] = static_cast(r_q_low[i]) * inv_rms * static_cast(gamma_q[offset_low + i]); r_q_high[i] = static_cast(r_q_high[i]) * inv_rms * static_cast(gamma_q[offset_high + i]); const scalar_t q_l = r_q_low[i]; const scalar_t q_h = r_q_high[i]; const scalar_t c = r_cos_low[i]; const scalar_t s = r_sin_low[i]; r_q_low[i] = q_l * c - q_h * s; r_q_high[i] = q_l * s + q_h * c; } } else { using LoadCacheT = at::native::memory::aligned_vector; scalar_t c_low[VEC_SIZE / 2], s_low[VEC_SIZE / 2]; scalar_t c_high[VEC_SIZE / 2], s_high[VEC_SIZE / 2]; const int cache_offset_low = offset_low / 2; const int cache_offset_high = offset_high / 2; *(LoadCacheT*)c_low = *(LoadCacheT*)(cache_ptr + cache_offset_low); *(LoadCacheT*)s_low = *(LoadCacheT*)(cache_ptr + rot_dim / 2 + cache_offset_low); *(LoadCacheT*)c_high = *(LoadCacheT*)(cache_ptr + cache_offset_high); *(LoadCacheT*)s_high = *(LoadCacheT*)(cache_ptr + rot_dim / 2 + cache_offset_high); #pragma unroll for (int i = 0; i < VEC_SIZE; i += 2) { const int c_idx = i / 2; r_q_low[i] = static_cast(r_q_low[i]) * inv_rms * static_cast(gamma_q[offset_low + i]); r_q_low[i + 1] = static_cast(r_q_low[i + 1]) * inv_rms * static_cast(gamma_q[offset_low + i + 1]); const scalar_t q0 = r_q_low[i]; const scalar_t q1 = r_q_low[i + 1]; const scalar_t c = c_low[c_idx]; const scalar_t s = s_low[c_idx]; r_q_low[i] = q0 * c - q1 * s; r_q_low[i + 1] = q1 * c + q0 * s; r_q_high[i] = static_cast(r_q_high[i]) * inv_rms * static_cast(gamma_q[offset_high + i]); r_q_high[i + 1] = static_cast(r_q_high[i + 1]) * inv_rms * static_cast(gamma_q[offset_high + i + 1]); const scalar_t qh0 = r_q_high[i]; const scalar_t qh1 = r_q_high[i + 1]; const scalar_t ch = c_high[c_idx]; const scalar_t sh = s_high[c_idx]; r_q_high[i] = qh0 * ch - qh1 * sh; r_q_high[i + 1] = qh1 * ch + qh0 * sh; } } *(LoadT*)(q_head_ptr + offset_low) = *(LoadT*)r_q_low; *(LoadT*)(q_head_ptr + offset_high) = *(LoadT*)r_q_high; } const int total_threads_needed = (num_heads + num_kv_heads) * THREAD_PER_HEAD; if (lane >= q_boundary && lane < total_threads_needed && key != nullptr) { const int k_lane_abs = lane - q_boundary; const int kv_head_idx = k_lane_abs / THREAD_PER_HEAD; const int k_lane_in_head = k_lane_abs % THREAD_PER_HEAD; scalar_t* k_head_ptr = key + global_token_idx * key_stride + kv_head_idx * head_stride_k; scalar_t* res_k_head_ptr = HAS_RESIDUAL ? (residual_k + global_token_idx * key_stride + kv_head_idx * head_stride_k) : nullptr; using LoadTK = at::native::memory::aligned_vector; scalar_t r_k_low[VEC_SIZE]; scalar_t r_k_high[VEC_SIZE]; const int offset_low = k_lane_in_head * VEC_SIZE; const int offset_high = HALF_ROT + k_lane_in_head * VEC_SIZE; *(LoadTK*)r_k_low = *(LoadTK*)(k_head_ptr + offset_low); *(LoadTK*)r_k_high = *(LoadTK*)(k_head_ptr + offset_high); T_ACC sum_sq_k = apply_residual_and_calc_sq( r_k_low, r_k_high, res_k_head_ptr, offset_low, offset_high); sum_sq_k = warp_reduce_sum_xor(sum_sq_k); const T_ACC inv_rms_k = c10::cuda::compat::rsqrt(sum_sq_k / HEAD_SIZE + static_cast(eps)); const scalar_t* cache_ptr_k = my_s_cos_sin; if constexpr (IS_NEOX) { scalar_t r_cos_low[VEC_SIZE], r_sin_low[VEC_SIZE]; scalar_t r_gamma_k_low[VEC_SIZE], r_gamma_k_high[VEC_SIZE]; *(LoadTK*)r_cos_low = *(LoadTK*)(cache_ptr_k + offset_low); *(LoadTK*)r_sin_low = *(LoadTK*)(cache_ptr_k + rot_dim / 2 + offset_low); *(LoadTK*)r_gamma_k_low = *(LoadTK*)(gamma_k + offset_low); *(LoadTK*)r_gamma_k_high = *(LoadTK*)(gamma_k + offset_high); #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { r_k_low[i] = static_cast(r_k_low[i]) * inv_rms_k * static_cast(r_gamma_k_low[i]); r_k_high[i] = static_cast(r_k_high[i]) * inv_rms_k * static_cast(r_gamma_k_high[i]); const scalar_t k_l = r_k_low[i]; const scalar_t k_h = r_k_high[i]; const scalar_t c = r_cos_low[i]; const scalar_t s = r_sin_low[i]; r_k_low[i] = k_l * c - k_h * s; r_k_high[i] = k_l * s + k_h * c; } } else { using LoadCacheTK = at::native::memory::aligned_vector; scalar_t r_cos_low[VEC_SIZE / 2], r_sin_low[VEC_SIZE / 2]; scalar_t r_cos_high[VEC_SIZE / 2], r_sin_high[VEC_SIZE / 2]; const int cache_offset_low = offset_low / 2; const int cache_offset_high = offset_high / 2; *(LoadCacheTK*)r_cos_low = *(LoadCacheTK*)(cache_ptr_k + cache_offset_low); *(LoadCacheTK*)r_sin_low = *(LoadCacheTK*)(cache_ptr_k + rot_dim / 2 + cache_offset_low); *(LoadCacheTK*)r_cos_high = *(LoadCacheTK*)(cache_ptr_k + cache_offset_high); *(LoadCacheTK*)r_sin_high = *(LoadCacheTK*)(cache_ptr_k + rot_dim / 2 + cache_offset_high); #pragma unroll for (int i = 0; i < VEC_SIZE; i += 2) { const int c_idx = i / 2; r_k_low[i] = static_cast(r_k_low[i]) * inv_rms_k * static_cast(gamma_k[offset_low + i]); r_k_low[i + 1] = static_cast(r_k_low[i + 1]) * inv_rms_k * static_cast(gamma_k[offset_low + i + 1]); const scalar_t k0 = r_k_low[i]; const scalar_t k1 = r_k_low[i + 1]; const scalar_t c = r_cos_low[c_idx]; const scalar_t s = r_sin_low[c_idx]; r_k_low[i] = k0 * c - k1 * s; r_k_low[i + 1] = k1 * c + k0 * s; r_k_high[i] = static_cast(r_k_high[i]) * inv_rms_k * static_cast(gamma_k[offset_high + i]); r_k_high[i + 1] = static_cast(r_k_high[i + 1]) * inv_rms_k * static_cast(gamma_k[offset_high + i + 1]); const scalar_t kh0 = r_k_high[i]; const scalar_t kh1 = r_k_high[i + 1]; const scalar_t ch = r_cos_high[c_idx]; const scalar_t sh = r_sin_high[c_idx]; r_k_high[i] = kh0 * ch - kh1 * sh; r_k_high[i + 1] = kh1 * ch + kh0 * sh; } } *(LoadTK*)(k_head_ptr + offset_low) = *(LoadTK*)r_k_low; *(LoadTK*)(k_head_ptr + offset_high) = *(LoadTK*)r_k_high; } } template __device__ __forceinline__ T_ACC apply_residual_and_calc_sq_4vec( scalar_t* v0, scalar_t* v1, scalar_t* v2, scalar_t* v3, scalar_t* res_ptr, const int o0, const int o1, const int o2, const int o3) { T_ACC local_sum = 0; #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { if constexpr (HAS_RESIDUAL) { const scalar_t r0 = res_ptr[o0 + i] + v0[i]; const scalar_t r1 = res_ptr[o1 + i] + v1[i]; const scalar_t r2 = res_ptr[o2 + i] + v2[i]; const scalar_t r3 = res_ptr[o3 + i] + v3[i]; res_ptr[o0 + i] = r0; res_ptr[o1 + i] = r1; res_ptr[o2 + i] = r2; res_ptr[o3 + i] = r3; v0[i] = r0; v1[i] = r1; v2[i] = r2; v3[i] = r3; } local_sum += static_cast(v0[i]) * static_cast(v0[i]); local_sum += static_cast(v1[i]) * static_cast(v1[i]); local_sum += static_cast(v2[i]) * static_cast(v2[i]); local_sum += static_cast(v3[i]) * static_cast(v3[i]); } return local_sum; } template __global__ void opt_rms_rope_qwen3_rot_dim64( const int64_t* __restrict__ positions, scalar_t* __restrict__ query, scalar_t* __restrict__ key, const scalar_t* __restrict__ cos_sin_cache, const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int64_t head_stride_q, const int64_t head_stride_k, const scalar_t* __restrict__ gamma_q, const scalar_t* __restrict__ gamma_k, scalar_t* residual_q, scalar_t* residual_k, const scalar_t eps, const int num_tokens, const int num_heads, const int num_kv_heads, const int threads_per_token, const int tokens_per_block) { extern __shared__ char smem_buffer[]; scalar_t* s_cos_sin_base = reinterpret_cast(smem_buffer); constexpr int HEAD_SIZE = 128; const int tid = threadIdx.x; const int local_token_idx = tid / threads_per_token; const int lane = tid % threads_per_token; if (local_token_idx >= tokens_per_block) return; const int global_token_idx = blockIdx.x * tokens_per_block + local_token_idx; if (global_token_idx >= num_tokens) return; scalar_t* my_s_cos_sin = s_cos_sin_base + local_token_idx * rot_dim; const int64_t pos = positions[global_token_idx]; for (int i = lane; i < rot_dim; i += threads_per_token) { my_s_cos_sin[i] = cos_sin_cache[pos * rot_dim + i]; } __syncthreads(); const int q_boundary = num_heads * THREAD_PER_HEAD; const int total_threads = (num_heads + num_kv_heads) * THREAD_PER_HEAD; if (lane < total_threads) { const bool is_query = lane < q_boundary; const int head_idx = is_query ? (lane / THREAD_PER_HEAD) : ((lane - q_boundary) / THREAD_PER_HEAD); const int lane_in_head = is_query ? (lane % THREAD_PER_HEAD) : ((lane - q_boundary) % THREAD_PER_HEAD); scalar_t* head_ptr = is_query ? (query + global_token_idx * query_stride + head_idx * head_stride_q) : (key + global_token_idx * key_stride + head_idx * head_stride_k); scalar_t* res_head_ptr = nullptr; if constexpr (HAS_RESIDUAL) { res_head_ptr = is_query ? (residual_q + global_token_idx * query_stride + head_idx * head_stride_q) : (residual_k + global_token_idx * key_stride + head_idx * head_stride_k); } const scalar_t* gamma_ptr = is_query ? gamma_q : gamma_k; const int o0 = lane_in_head * VEC_SIZE; const int o1 = o0 + 32; const int o2 = o0 + 64; const int o3 = o0 + 96; using LoadT = at::native::memory::aligned_vector; scalar_t v0[VEC_SIZE], v1[VEC_SIZE], v2[VEC_SIZE], v3[VEC_SIZE]; *(LoadT*)v0 = *(LoadT*)(head_ptr + o0); *(LoadT*)v1 = *(LoadT*)(head_ptr + o1); *(LoadT*)v2 = *(LoadT*)(head_ptr + o2); *(LoadT*)v3 = *(LoadT*)(head_ptr + o3); T_ACC sum_sq = apply_residual_and_calc_sq_4vec( v0, v1, v2, v3, res_head_ptr, o0, o1, o2, o3); sum_sq = warp_reduce_sum_xor(sum_sq); const T_ACC inv_rms = c10::cuda::compat::rsqrt(sum_sq / HEAD_SIZE + static_cast(eps)); if constexpr (IS_NEOX) { scalar_t r_cos[VEC_SIZE], r_sin[VEC_SIZE]; *(LoadT*)r_cos = *(LoadT*)(my_s_cos_sin + o0); *(LoadT*)r_sin = *(LoadT*)(my_s_cos_sin + 32 + o0); #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { const T_ACC s0 = static_cast(v0[i]) * inv_rms * static_cast(gamma_ptr[o0 + i]); const T_ACC s1 = static_cast(v1[i]) * inv_rms * static_cast(gamma_ptr[o1 + i]); const T_ACC s2 = static_cast(v2[i]) * inv_rms * static_cast(gamma_ptr[o2 + i]); const T_ACC s3 = static_cast(v3[i]) * inv_rms * static_cast(gamma_ptr[o3 + i]); v0[i] = s0 * r_cos[i] - s1 * r_sin[i]; v1[i] = s0 * r_sin[i] + s1 * r_cos[i]; v2[i] = s2; v3[i] = s3; } } else { #pragma unroll for (int i = 0; i < VEC_SIZE; i += 2) { const int idx_c0 = (o0 + i) / 2; const scalar_t cos0 = my_s_cos_sin[idx_c0]; const scalar_t sin0 = my_s_cos_sin[32 + idx_c0]; const T_ACC s0_0 = static_cast(v0[i]) * inv_rms * static_cast(gamma_ptr[o0 + i]); const T_ACC s0_1 = static_cast(v0[i + 1]) * inv_rms * static_cast(gamma_ptr[o0 + i + 1]); v0[i] = s0_0 * cos0 - s0_1 * sin0; v0[i + 1] = s0_1 * cos0 + s0_0 * sin0; const int idx_c1 = (o1 + i) / 2; const scalar_t cos1 = my_s_cos_sin[idx_c1]; const scalar_t sin1 = my_s_cos_sin[32 + idx_c1]; const T_ACC s1_0 = static_cast(v1[i]) * inv_rms * static_cast(gamma_ptr[o1 + i]); const T_ACC s1_1 = static_cast(v1[i + 1]) * inv_rms * static_cast(gamma_ptr[o1 + i + 1]); v1[i] = s1_0 * cos1 - s1_1 * sin1; v1[i + 1] = s1_1 * cos1 + s1_0 * sin1; } #pragma unroll for (int i = 0; i < VEC_SIZE; ++i) { v2[i] = static_cast(v2[i]) * inv_rms * static_cast(gamma_ptr[o2 + i]); v3[i] = static_cast(v3[i]) * inv_rms * static_cast(gamma_ptr[o3 + i]); } } *(LoadT*)(head_ptr + o0) = *(LoadT*)v0; *(LoadT*)(head_ptr + o1) = *(LoadT*)v1; *(LoadT*)(head_ptr + o2) = *(LoadT*)v2; *(LoadT*)(head_ptr + o3) = *(LoadT*)v3; } } template void launch_opt_rms_rope( const int64_t* positions, scalar_t* query, scalar_t* key, const scalar_t* cos_sin_cache, const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int64_t head_stride_q, const int64_t head_stride_k, const scalar_t* gamma_q, const scalar_t* gamma_k, scalar_t* residual_q_ptr, scalar_t* residual_k_ptr, const scalar_t eps, const int num_tokens, const bool is_neox, const int num_heads, const int num_kv_heads, const cudaStream_t stream) { const bool has_residual = (residual_q_ptr != nullptr && residual_k_ptr != nullptr); constexpr int THREAD_PER_HEAD = 8; constexpr int VEC = 8; const int threads_per_token = (num_heads + num_kv_heads) * THREAD_PER_HEAD; const int target_block_size = 512; int tokens_per_block = target_block_size / threads_per_token; if (tokens_per_block < 1) tokens_per_block = 1; const int actual_block_size = tokens_per_block * threads_per_token; const int grid_size = (num_tokens + tokens_per_block - 1) / tokens_per_block; if (rot_dim == 128) { const size_t smem_size = tokens_per_block * 128 * sizeof(scalar_t); DISPATCH_BOOL(has_residual, HAS_RESIDUAL_CONST, [&] { DISPATCH_BOOL(is_neox, IS_NEOX_CONST, [&] { opt_rms_rope_qwen3 <<>>( positions, query, key, cos_sin_cache, rot_dim, query_stride, key_stride, head_stride_q, head_stride_k, gamma_q, gamma_k, residual_q_ptr, residual_k_ptr, eps, num_tokens, num_heads, num_kv_heads, threads_per_token, tokens_per_block); }); }); return; } if (rot_dim == 64) { const size_t smem_size = tokens_per_block * 64 * sizeof(scalar_t); DISPATCH_BOOL(has_residual, HAS_RESIDUAL_CONST, [&] { DISPATCH_BOOL(is_neox, IS_NEOX_CONST, [&] { opt_rms_rope_qwen3_rot_dim64 <<>>( positions, query, key, cos_sin_cache, rot_dim, query_stride, key_stride, head_stride_q, head_stride_k, gamma_q, gamma_k, residual_q_ptr, residual_k_ptr, eps, num_tokens, num_heads, num_kv_heads, threads_per_token, tokens_per_block); }); }); return; } } } // namespace vllm void rms_rotary_embedding_fuse( torch::Tensor& positions, torch::Tensor& query, std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox, torch::Tensor& weight_q, torch::Tensor& weight_k, std::optional residual_q, std::optional residual_k, double epsilon) { // Basic validation (mirrors rotary_embedding + layernorm checks). const int64_t num_tokens = positions.numel(); const int positions_ndim = positions.dim(); TORCH_CHECK(positions_ndim == 1 || positions_ndim == 2, "positions must have shape [num_tokens] or [batch_size, seq_len]"); if (positions_ndim == 1) { TORCH_CHECK(query.size(0) == positions.size(0) && (!key.has_value() || key->size(0) == positions.size(0)), "query, key and positions must have the same number of tokens"); } else { TORCH_CHECK( query.size(0) == positions.size(0) && (!key.has_value() || key->size(0) == positions.size(0)) && query.size(1) == positions.size(1) && (!key.has_value() || key->size(1) == positions.size(1)), "query, key and positions must have the same batch_size and seq_len"); } TORCH_CHECK(query.is_cuda(), "query must be CUDA"); TORCH_CHECK(!key.has_value() || key->is_cuda(), "key must be CUDA"); TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA"); TORCH_CHECK(positions.is_cuda(), "positions must be CUDA"); TORCH_CHECK(weight_q.is_cuda() && weight_k.is_cuda(), "weights must be CUDA"); TORCH_CHECK(cos_sin_cache.is_contiguous(), "cos_sin_cache must be contiguous"); TORCH_CHECK(weight_q.is_contiguous() && weight_k.is_contiguous(), "weights must be contiguous"); TORCH_CHECK(positions.scalar_type() == at::kLong, "positions must be int64"); TORCH_CHECK(query.scalar_type() == cos_sin_cache.scalar_type(), "cos_sin_cache must have same dtype as query"); TORCH_CHECK(weight_q.scalar_type() == query.scalar_type() && weight_k.scalar_type() == query.scalar_type(), "weights must have same dtype as query"); TORCH_CHECK(!key.has_value() || key->scalar_type() == query.scalar_type(), "key must have same dtype as query"); if (residual_q.has_value() || residual_k.has_value()) { TORCH_CHECK(residual_q.has_value() && residual_k.has_value(), "residual_q and residual_k must be both provided or both None"); TORCH_CHECK(residual_q->is_cuda() && residual_k->is_cuda(), "residual tensors must be CUDA"); TORCH_CHECK(residual_q->scalar_type() == query.scalar_type() && residual_k->scalar_type() == query.scalar_type(), "residual tensors must have same dtype as query"); } const bool has_residual = residual_q.has_value() && residual_k.has_value(); const bool query_needs_copy_back = !query.is_contiguous(); torch::Tensor query_work = query_needs_copy_back ? query.contiguous() : query; std::optional key_work = std::nullopt; bool key_needs_copy_back = false; if (key.has_value()) { key_needs_copy_back = !key->is_contiguous(); key_work = key_needs_copy_back ? key->contiguous() : *key; } std::optional residual_q_work = std::nullopt; std::optional residual_k_work = std::nullopt; bool residual_q_needs_copy_back = false; bool residual_k_needs_copy_back = false; if (has_residual) { residual_q_needs_copy_back = !residual_q->is_contiguous(); residual_k_needs_copy_back = !residual_k->is_contiguous(); residual_q_work = residual_q_needs_copy_back ? residual_q->contiguous() : *residual_q; residual_k_work = residual_k_needs_copy_back ? residual_k->contiguous() : *residual_k; } const int query_hidden_size = query_work.numel() / num_tokens; const int key_hidden_size = key_work.has_value() ? key_work->numel() / num_tokens : 0; TORCH_CHECK(query_hidden_size % head_size == 0); TORCH_CHECK(!key_work.has_value() || (key_hidden_size % head_size == 0)); const int num_heads = query_hidden_size / head_size; const int num_kv_heads = key_work.has_value() ? (key_hidden_size / head_size) : num_heads; TORCH_CHECK(num_heads % num_kv_heads == 0); const int rot_dim = cos_sin_cache.size(1); const int seq_dim_idx = positions_ndim - 1; const int64_t query_stride = query_work.stride(seq_dim_idx); const int64_t key_stride = key_work.has_value() ? key_work->stride(seq_dim_idx) : 0; const int query_ndim = query_work.dim(); const int64_t head_stride_q = (query_ndim == positions_ndim + 2) ? query_work.stride(-2) : head_size; const int64_t head_stride_k = (key_work.has_value() && key_work->dim() == positions_ndim + 2) ? key_work->stride(-2) : head_size; const bool supports_qwen3_opt = (key_work.has_value() && head_size == 128 && (rot_dim == 128 || rot_dim == 64) && (num_heads + num_kv_heads) <= 128 && weight_q.numel() == 128 && weight_k.numel() == 128); const at::cuda::OptionalCUDAGuard device_guard(device_of(query_work)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); bool used_qwen3_opt = false; if (supports_qwen3_opt) { AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, query_work.scalar_type(), "vllm_rms_rotary_embedding_fuse_qwen3", [&] { using T_ACC = at::acc_type; scalar_t* res_q_ptr = has_residual ? residual_q_work->data_ptr() : nullptr; scalar_t* res_k_ptr = has_residual ? residual_k_work->data_ptr() : nullptr; vllm::launch_opt_rms_rope( positions.data_ptr(), query_work.data_ptr(), key_work->data_ptr(), cos_sin_cache.data_ptr(), rot_dim, query_stride, key_stride, head_stride_q, head_stride_k, weight_q.data_ptr(), weight_k.data_ptr(), res_q_ptr, res_k_ptr, static_cast(epsilon), static_cast(num_tokens), is_neox, num_heads, num_kv_heads, stream); }); used_qwen3_opt = true; } if (!used_qwen3_opt) { // Fallback: use existing kernels (still removes lightop dependency). // Apply per-head RMSNorm to Q/K and then call the existing RoPE kernel. TORCH_CHECK(weight_q.numel() == head_size && weight_k.numel() == head_size, "weight_q/weight_k must have shape [head_size]"); auto q_heads = query_work.view({num_tokens * num_heads, head_size}); if (has_residual) { auto rq_heads = residual_q_work->view({num_tokens * num_heads, head_size}); fused_add_rms_norm_opt(q_heads, rq_heads, weight_q, epsilon); } else { rms_norm_opt(q_heads, q_heads, weight_q, epsilon); } if (key_work.has_value()) { auto k_heads = key_work->view({num_tokens * num_kv_heads, head_size}); if (has_residual) { auto rk_heads = residual_k_work->view({num_tokens * num_kv_heads, head_size}); fused_add_rms_norm_opt(k_heads, rk_heads, weight_k, epsilon); } else { rms_norm_opt(k_heads, k_heads, weight_k, epsilon); } } rotary_embedding(positions, query_work, key_work, head_size, cos_sin_cache, is_neox); } if (query_needs_copy_back) { query.copy_(query_work); } if (key_needs_copy_back) { key->copy_(*key_work); } if (residual_q_needs_copy_back) { residual_q->copy_(*residual_q_work); } if (residual_k_needs_copy_back) { residual_k->copy_(*residual_k_work); } }