Commit b65d0556 authored by chenyue3's avatar chenyue3
Browse files

fix(qwen3): 在 fused RMS+RoPE 算子内支持非连续输入

  - 在 C++ 算子中移除 query/key/residual 的 contiguous 强校验
  - 对非连续输入使用工作张量执行计算(优化路径与 fallback 路径统一)
  - 计算完成后将结果 copy_ 回原张量,保持 in-place 语义
  - 移除 qwen3 / qwen3_moe Python 前向中的 q、k.contiguous() 预处理
parent 76687ddd
......@@ -552,9 +552,6 @@ void rms_rotary_embedding_fuse(
TORCH_CHECK(positions.is_cuda(), "positions must be CUDA");
TORCH_CHECK(weight_q.is_cuda() && weight_k.is_cuda(),
"weights must be CUDA");
TORCH_CHECK(query.is_contiguous(), "query must be contiguous");
TORCH_CHECK(!key.has_value() || key->is_contiguous(),
"key must be contiguous");
TORCH_CHECK(cos_sin_cache.is_contiguous(), "cos_sin_cache must be contiguous");
TORCH_CHECK(weight_q.is_contiguous() && weight_k.is_contiguous(),
"weights must be contiguous");
......@@ -573,87 +570,129 @@ void rms_rotary_embedding_fuse(
"residual_q and residual_k must be both provided or both None");
TORCH_CHECK(residual_q->is_cuda() && residual_k->is_cuda(),
"residual tensors must be CUDA");
TORCH_CHECK(residual_q->is_contiguous() && residual_k->is_contiguous(),
"residual tensors must be contiguous");
TORCH_CHECK(residual_q->scalar_type() == query.scalar_type() &&
residual_k->scalar_type() == query.scalar_type(),
"residual tensors must have same dtype as query");
}
const int query_hidden_size = query.numel() / num_tokens;
const int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
const bool has_residual = residual_q.has_value() && residual_k.has_value();
const bool query_needs_copy_back = !query.is_contiguous();
torch::Tensor query_work =
query_needs_copy_back ? query.contiguous() : query;
std::optional<torch::Tensor> key_work = std::nullopt;
bool key_needs_copy_back = false;
if (key.has_value()) {
key_needs_copy_back = !key->is_contiguous();
key_work = key_needs_copy_back ? key->contiguous() : *key;
}
std::optional<torch::Tensor> residual_q_work = std::nullopt;
std::optional<torch::Tensor> residual_k_work = std::nullopt;
bool residual_q_needs_copy_back = false;
bool residual_k_needs_copy_back = false;
if (has_residual) {
residual_q_needs_copy_back = !residual_q->is_contiguous();
residual_k_needs_copy_back = !residual_k->is_contiguous();
residual_q_work =
residual_q_needs_copy_back ? residual_q->contiguous() : *residual_q;
residual_k_work =
residual_k_needs_copy_back ? residual_k->contiguous() : *residual_k;
}
const int query_hidden_size = query_work.numel() / num_tokens;
const int key_hidden_size =
key_work.has_value() ? key_work->numel() / num_tokens : 0;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(!key.has_value() || (key_hidden_size % head_size == 0));
TORCH_CHECK(!key_work.has_value() || (key_hidden_size % head_size == 0));
const int num_heads = query_hidden_size / head_size;
const int num_kv_heads = key.has_value() ? (key_hidden_size / head_size)
: num_heads;
const int num_kv_heads =
key_work.has_value() ? (key_hidden_size / head_size) : num_heads;
TORCH_CHECK(num_heads % num_kv_heads == 0);
const int rot_dim = cos_sin_cache.size(1);
const int seq_dim_idx = positions_ndim - 1;
const int64_t query_stride = query.stride(seq_dim_idx);
const int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
const int query_ndim = query.dim();
const int64_t query_stride = query_work.stride(seq_dim_idx);
const int64_t key_stride =
key_work.has_value() ? key_work->stride(seq_dim_idx) : 0;
const int query_ndim = query_work.dim();
const int64_t head_stride_q =
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
(query_ndim == positions_ndim + 2) ? query_work.stride(-2) : head_size;
const int64_t head_stride_k =
(key.has_value() && key->dim() == positions_ndim + 2) ? key->stride(-2)
: head_size;
(key_work.has_value() && key_work->dim() == positions_ndim + 2)
? key_work->stride(-2)
: head_size;
const bool has_residual = residual_q.has_value() && residual_k.has_value();
const bool supports_qwen3_opt =
(key.has_value() && head_size == 128 && (rot_dim == 128 || rot_dim == 64) &&
(key_work.has_value() && head_size == 128 &&
(rot_dim == 128 || rot_dim == 64) &&
(num_heads + num_kv_heads) <= 128 && weight_q.numel() == 128 &&
weight_k.numel() == 128);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query_work));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
bool used_qwen3_opt = false;
if (supports_qwen3_opt) {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16, query.scalar_type(),
at::ScalarType::Half, at::ScalarType::BFloat16, query_work.scalar_type(),
"vllm_rms_rotary_embedding_fuse_qwen3", [&] {
using T_ACC = at::acc_type<scalar_t, true>;
scalar_t* res_q_ptr =
has_residual ? residual_q->data_ptr<scalar_t>() : nullptr;
has_residual ? residual_q_work->data_ptr<scalar_t>() : nullptr;
scalar_t* res_k_ptr =
has_residual ? residual_k->data_ptr<scalar_t>() : nullptr;
has_residual ? residual_k_work->data_ptr<scalar_t>() : nullptr;
vllm::launch_opt_rms_rope<T_ACC, scalar_t>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key->data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
positions.data_ptr<int64_t>(), query_work.data_ptr<scalar_t>(),
key_work->data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
rot_dim, query_stride, key_stride, head_stride_q, head_stride_k,
weight_q.data_ptr<scalar_t>(), weight_k.data_ptr<scalar_t>(),
res_q_ptr, res_k_ptr, static_cast<scalar_t>(epsilon),
static_cast<int>(num_tokens), is_neox, num_heads, num_kv_heads,
stream);
});
return;
used_qwen3_opt = true;
}
// Fallback: use existing kernels (still removes lightop dependency).
// Apply per-head RMSNorm to Q/K and then call the existing RoPE kernel.
{
if (!used_qwen3_opt) {
// Fallback: use existing kernels (still removes lightop dependency).
// Apply per-head RMSNorm to Q/K and then call the existing RoPE kernel.
TORCH_CHECK(weight_q.numel() == head_size && weight_k.numel() == head_size,
"weight_q/weight_k must have shape [head_size]");
auto q_heads = query.view({num_tokens * num_heads, head_size});
auto q_heads = query_work.view({num_tokens * num_heads, head_size});
if (has_residual) {
auto rq_heads = residual_q->view({num_tokens * num_heads, head_size});
auto rq_heads = residual_q_work->view({num_tokens * num_heads, head_size});
fused_add_rms_norm_opt(q_heads, rq_heads, weight_q, epsilon);
} else {
rms_norm_opt(q_heads, q_heads, weight_q, epsilon);
}
if (key.has_value()) {
auto k_heads = key->view({num_tokens * num_kv_heads, head_size});
if (key_work.has_value()) {
auto k_heads = key_work->view({num_tokens * num_kv_heads, head_size});
if (has_residual) {
auto rk_heads =
residual_k->view({num_tokens * num_kv_heads, head_size});
residual_k_work->view({num_tokens * num_kv_heads, head_size});
fused_add_rms_norm_opt(k_heads, rk_heads, weight_k, epsilon);
} else {
rms_norm_opt(k_heads, k_heads, weight_k, epsilon);
}
}
rotary_embedding(positions, query_work, key_work, head_size, cos_sin_cache,
is_neox);
}
if (query_needs_copy_back) {
query.copy_(query_work);
}
if (key_needs_copy_back) {
key->copy_(*key_work);
}
if (residual_q_needs_copy_back) {
residual_q->copy_(*residual_q_work);
}
if (residual_k_needs_copy_back) {
residual_k->copy_(*residual_k_work);
}
rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox);
}
......@@ -256,8 +256,6 @@ class Qwen3Attention(nn.Module):
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache
q = q.contiguous()
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
positions,
q,
......
......@@ -451,9 +451,6 @@ class Qwen3MoeAttention(nn.Module):
# Persist the converted cache so we don't re-copy/re-allocate
# on every forward when the original buffer starts on CPU.
self.rotary_emb.cos_sin_cache = cos_sin_cache
# # q, k 使用 continuous
q = q.contiguous()
k = k.contiguous()
torch.ops.vllm.rms_rotary_embedding_fuse(
positions,
q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment