Commit 9bf8406f authored by chenyue3's avatar chenyue3
Browse files

feat(qwen3): 支持 rot_dim=64 的 fused RMS+RoPE 优化路径

  launch_opt_rms_rope 增加 rot_dim=64/128 分发路径
  放宽 supports_qwen3_opt 条件,允许 rot_dim=64 进入优化分支
  qwen3 / qwen3_moe 将 q_bias、k_bias 参数统一为 q_residual、k_residual
  qwen3_moe 的自定义 op 注册增加重复注册保护"`
parent 588538f5
......@@ -297,6 +297,174 @@ __global__ void opt_rms_rope_qwen3(
}
}
template <typename T_ACC, typename scalar_t, int VEC_SIZE, bool HAS_RESIDUAL>
__device__ __forceinline__ T_ACC apply_residual_and_calc_sq_4vec(
scalar_t* v0, scalar_t* v1, scalar_t* v2, scalar_t* v3,
scalar_t* res_ptr, const int o0, const int o1, const int o2, const int o3) {
T_ACC local_sum = 0;
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
if constexpr (HAS_RESIDUAL) {
const scalar_t r0 = res_ptr[o0 + i] + v0[i];
const scalar_t r1 = res_ptr[o1 + i] + v1[i];
const scalar_t r2 = res_ptr[o2 + i] + v2[i];
const scalar_t r3 = res_ptr[o3 + i] + v3[i];
res_ptr[o0 + i] = r0;
res_ptr[o1 + i] = r1;
res_ptr[o2 + i] = r2;
res_ptr[o3 + i] = r3;
v0[i] = r0;
v1[i] = r1;
v2[i] = r2;
v3[i] = r3;
}
local_sum += static_cast<T_ACC>(v0[i]) * static_cast<T_ACC>(v0[i]);
local_sum += static_cast<T_ACC>(v1[i]) * static_cast<T_ACC>(v1[i]);
local_sum += static_cast<T_ACC>(v2[i]) * static_cast<T_ACC>(v2[i]);
local_sum += static_cast<T_ACC>(v3[i]) * static_cast<T_ACC>(v3[i]);
}
return local_sum;
}
template <typename T_ACC, typename scalar_t, bool HAS_RESIDUAL, bool IS_NEOX,
int VEC_SIZE, int THREAD_PER_HEAD>
__global__ void opt_rms_rope_qwen3_rot_dim64(
const int64_t* __restrict__ positions, scalar_t* __restrict__ query,
scalar_t* __restrict__ key, const scalar_t* __restrict__ cos_sin_cache,
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int64_t head_stride_q, const int64_t head_stride_k,
const scalar_t* __restrict__ gamma_q,
const scalar_t* __restrict__ gamma_k, scalar_t* residual_q,
scalar_t* residual_k, const scalar_t eps, const int num_tokens,
const int num_heads, const int num_kv_heads, const int threads_per_token,
const int tokens_per_block) {
extern __shared__ char smem_buffer[];
scalar_t* s_cos_sin_base = reinterpret_cast<scalar_t*>(smem_buffer);
constexpr int HEAD_SIZE = 128;
const int tid = threadIdx.x;
const int local_token_idx = tid / threads_per_token;
const int lane = tid % threads_per_token;
if (local_token_idx >= tokens_per_block) return;
const int global_token_idx = blockIdx.x * tokens_per_block + local_token_idx;
if (global_token_idx >= num_tokens) return;
scalar_t* my_s_cos_sin = s_cos_sin_base + local_token_idx * rot_dim;
const int64_t pos = positions[global_token_idx];
for (int i = lane; i < rot_dim; i += threads_per_token) {
my_s_cos_sin[i] = cos_sin_cache[pos * rot_dim + i];
}
__syncthreads();
const int q_boundary = num_heads * THREAD_PER_HEAD;
const int total_threads = (num_heads + num_kv_heads) * THREAD_PER_HEAD;
if (lane < total_threads) {
const bool is_query = lane < q_boundary;
const int head_idx =
is_query ? (lane / THREAD_PER_HEAD) : ((lane - q_boundary) / THREAD_PER_HEAD);
const int lane_in_head =
is_query ? (lane % THREAD_PER_HEAD) : ((lane - q_boundary) % THREAD_PER_HEAD);
scalar_t* head_ptr = is_query
? (query + global_token_idx * query_stride +
head_idx * head_stride_q)
: (key + global_token_idx * key_stride +
head_idx * head_stride_k);
scalar_t* res_head_ptr = nullptr;
if constexpr (HAS_RESIDUAL) {
res_head_ptr = is_query
? (residual_q + global_token_idx * query_stride +
head_idx * head_stride_q)
: (residual_k + global_token_idx * key_stride +
head_idx * head_stride_k);
}
const scalar_t* gamma_ptr = is_query ? gamma_q : gamma_k;
const int o0 = lane_in_head * VEC_SIZE;
const int o1 = o0 + 32;
const int o2 = o0 + 64;
const int o3 = o0 + 96;
using LoadT = at::native::memory::aligned_vector<scalar_t, VEC_SIZE>;
scalar_t v0[VEC_SIZE], v1[VEC_SIZE], v2[VEC_SIZE], v3[VEC_SIZE];
*(LoadT*)v0 = *(LoadT*)(head_ptr + o0);
*(LoadT*)v1 = *(LoadT*)(head_ptr + o1);
*(LoadT*)v2 = *(LoadT*)(head_ptr + o2);
*(LoadT*)v3 = *(LoadT*)(head_ptr + o3);
T_ACC sum_sq =
apply_residual_and_calc_sq_4vec<T_ACC, scalar_t, VEC_SIZE, HAS_RESIDUAL>(
v0, v1, v2, v3, res_head_ptr, o0, o1, o2, o3);
sum_sq = warp_reduce_sum_xor<T_ACC, THREAD_PER_HEAD>(sum_sq);
const T_ACC inv_rms =
c10::cuda::compat::rsqrt(sum_sq / HEAD_SIZE + static_cast<T_ACC>(eps));
if constexpr (IS_NEOX) {
scalar_t r_cos[VEC_SIZE], r_sin[VEC_SIZE];
*(LoadT*)r_cos = *(LoadT*)(my_s_cos_sin + o0);
*(LoadT*)r_sin = *(LoadT*)(my_s_cos_sin + 32 + o0);
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
const T_ACC s0 = static_cast<T_ACC>(v0[i]) * inv_rms *
static_cast<T_ACC>(gamma_ptr[o0 + i]);
const T_ACC s1 = static_cast<T_ACC>(v1[i]) * inv_rms *
static_cast<T_ACC>(gamma_ptr[o1 + i]);
const T_ACC s2 = static_cast<T_ACC>(v2[i]) * inv_rms *
static_cast<T_ACC>(gamma_ptr[o2 + i]);
const T_ACC s3 = static_cast<T_ACC>(v3[i]) * inv_rms *
static_cast<T_ACC>(gamma_ptr[o3 + i]);
v0[i] = s0 * r_cos[i] - s1 * r_sin[i];
v1[i] = s0 * r_sin[i] + s1 * r_cos[i];
v2[i] = s2;
v3[i] = s3;
}
} else {
#pragma unroll
for (int i = 0; i < VEC_SIZE; i += 2) {
const int idx_c0 = (o0 + i) / 2;
const scalar_t cos0 = my_s_cos_sin[idx_c0];
const scalar_t sin0 = my_s_cos_sin[32 + idx_c0];
const T_ACC s0_0 = static_cast<T_ACC>(v0[i]) * inv_rms *
static_cast<T_ACC>(gamma_ptr[o0 + i]);
const T_ACC s0_1 = static_cast<T_ACC>(v0[i + 1]) * inv_rms *
static_cast<T_ACC>(gamma_ptr[o0 + i + 1]);
v0[i] = s0_0 * cos0 - s0_1 * sin0;
v0[i + 1] = s0_1 * cos0 + s0_0 * sin0;
const int idx_c1 = (o1 + i) / 2;
const scalar_t cos1 = my_s_cos_sin[idx_c1];
const scalar_t sin1 = my_s_cos_sin[32 + idx_c1];
const T_ACC s1_0 = static_cast<T_ACC>(v1[i]) * inv_rms *
static_cast<T_ACC>(gamma_ptr[o1 + i]);
const T_ACC s1_1 = static_cast<T_ACC>(v1[i + 1]) * inv_rms *
static_cast<T_ACC>(gamma_ptr[o1 + i + 1]);
v1[i] = s1_0 * cos1 - s1_1 * sin1;
v1[i + 1] = s1_1 * cos1 + s1_0 * sin1;
}
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
v2[i] = static_cast<T_ACC>(v2[i]) * inv_rms *
static_cast<T_ACC>(gamma_ptr[o2 + i]);
v3[i] = static_cast<T_ACC>(v3[i]) * inv_rms *
static_cast<T_ACC>(gamma_ptr[o3 + i]);
}
}
*(LoadT*)(head_ptr + o0) = *(LoadT*)v0;
*(LoadT*)(head_ptr + o1) = *(LoadT*)v1;
*(LoadT*)(head_ptr + o2) = *(LoadT*)v2;
*(LoadT*)(head_ptr + o3) = *(LoadT*)v3;
}
}
template <typename T_ACC, typename scalar_t>
void launch_opt_rms_rope(
const int64_t* positions, scalar_t* query, scalar_t* key,
......@@ -313,14 +481,13 @@ void launch_opt_rms_rope(
constexpr int VEC = 8;
const int threads_per_token = (num_heads + num_kv_heads) * THREAD_PER_HEAD;
// Keep the same launch heuristic as the original kernel.
const int target_block_size = 256;
const int target_block_size = 512;
int tokens_per_block = target_block_size / threads_per_token;
if (tokens_per_block < 1) tokens_per_block = 1;
const int actual_block_size = tokens_per_block * threads_per_token;
const int grid_size = (num_tokens + tokens_per_block - 1) / tokens_per_block;
if (rot_dim == 128) {
const size_t smem_size = tokens_per_block * 128 * sizeof(scalar_t);
DISPATCH_BOOL(has_residual, HAS_RESIDUAL_CONST, [&] {
DISPATCH_BOOL(is_neox, IS_NEOX_CONST, [&] {
opt_rms_rope_qwen3<T_ACC, scalar_t, HAS_RESIDUAL_CONST, IS_NEOX_CONST,
......@@ -332,6 +499,24 @@ void launch_opt_rms_rope(
num_kv_heads, threads_per_token, tokens_per_block);
});
});
return;
}
if (rot_dim == 64) {
const size_t smem_size = tokens_per_block * 64 * sizeof(scalar_t);
DISPATCH_BOOL(has_residual, HAS_RESIDUAL_CONST, [&] {
DISPATCH_BOOL(is_neox, IS_NEOX_CONST, [&] {
opt_rms_rope_qwen3_rot_dim64<T_ACC, scalar_t, HAS_RESIDUAL_CONST,
IS_NEOX_CONST, 4, THREAD_PER_HEAD>
<<<grid_size, actual_block_size, smem_size, stream>>>(
positions, query, key, cos_sin_cache, rot_dim, query_stride,
key_stride, head_stride_q, head_stride_k, gamma_q, gamma_k,
residual_q_ptr, residual_k_ptr, eps, num_tokens, num_heads,
num_kv_heads, threads_per_token, tokens_per_block);
});
});
return;
}
}
} // namespace vllm
......@@ -418,7 +603,7 @@ void rms_rotary_embedding_fuse(
const bool has_residual = residual_q.has_value() && residual_k.has_value();
const bool supports_qwen3_opt =
(key.has_value() && head_size == 128 && rot_dim == 128 &&
(key.has_value() && head_size == 128 && (rot_dim == 128 || rot_dim == 64) &&
(num_heads + num_kv_heads) <= 128 && weight_q.numel() == 128 &&
weight_k.numel() == 128);
......
......@@ -148,8 +148,8 @@ class Qwen3Attention(nn.Module):
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
q_residual: Optional[torch.Tensor],
k_residual: Optional[torch.Tensor],
epsilon: float,
) -> None:
backend = envs.VLLM_FUSED_RMS_ROPE_BACKEND
......@@ -164,8 +164,8 @@ class Qwen3Attention(nn.Module):
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
q_residual,
k_residual,
epsilon,
)
return
......@@ -194,8 +194,8 @@ class Qwen3Attention(nn.Module):
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
q_residual,
k_residual,
epsilon,
)
return
......@@ -209,8 +209,8 @@ class Qwen3Attention(nn.Module):
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
q_residual,
k_residual,
epsilon,
)
......@@ -223,8 +223,8 @@ class Qwen3Attention(nn.Module):
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
q_residual: Optional[torch.Tensor],
k_residual: Optional[torch.Tensor],
epsilon: float,
) -> None:
# Fake impl intentionally left as no-op for graph tracing modes.
......
......@@ -280,8 +280,8 @@ class Qwen3MoeAttention(nn.Module):
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
q_residual: Optional[torch.Tensor],
k_residual: Optional[torch.Tensor],
epsilon: float,
) -> None:
backend = envs.VLLM_FUSED_RMS_ROPE_BACKEND
......@@ -296,8 +296,8 @@ class Qwen3MoeAttention(nn.Module):
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
q_residual,
k_residual,
epsilon,
)
return
......@@ -325,8 +325,8 @@ class Qwen3MoeAttention(nn.Module):
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
q_residual,
k_residual,
epsilon,
)
return
......@@ -340,8 +340,8 @@ class Qwen3MoeAttention(nn.Module):
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
q_residual,
k_residual,
epsilon,
)
......@@ -356,14 +356,14 @@ class Qwen3MoeAttention(nn.Module):
is_neox_style: bool,
q_weight: torch.Tensor,
k_weight: torch.Tensor,
q_bias: Optional[torch.Tensor],
k_bias: Optional[torch.Tensor],
q_residual: Optional[torch.Tensor],
k_residual: Optional[torch.Tensor],
epsilon: float,
) -> None:
# Fake impl intentionally left as no-op for graph tracing modes.
pass
if not hasattr(torch.ops.vllm, "rms_rotary_embedding_fuse"):
direct_register_custom_op(
op_name="rms_rotary_embedding_fuse",
op_func=rms_rotary_embedding_fuse,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment