Commit 588538f5 authored by laibao's avatar laibao
Browse files

• feat(qwen3):新增 vLLM 内置 RMS+RoPE 融合算子,并支持 LightOp 后端切换

  - 在 vLLM _C 扩展中新增 rms_rotary_embedding_fuse(注册 op + CUDA kernel),减少对 LightOp 的硬依赖
  - 新增环境变量 VLLM_FUSED_RMS_ROPE_BACKEND=auto|vllm|lightop,auto 优先走 vLLM,缺失时回退 LightOp
  - 更新 Qwen3 / Qwen3-MoE 的 fused 路径按后端选择执行
  - 补充 tc_opt benchmark 结果解析脚本 benchmarks/tc_opt/test/parse_bench_results.py
parent 70506d98
...@@ -256,6 +256,7 @@ set(VLLM_EXT_SRC ...@@ -256,6 +256,7 @@ set(VLLM_EXT_SRC
"csrc/attention/merge_attn_states.cu" "csrc/attention/merge_attn_states.cu"
"csrc/attention/vertical_slash_index.cu" "csrc/attention/vertical_slash_index.cu"
"csrc/pos_encoding_kernels.cu" "csrc/pos_encoding_kernels.cu"
"csrc/fuse_rms_rope_kernels.cu"
"csrc/activation_kernels.cu" "csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu" "csrc/layernorm_kernels.cu"
"csrc/opt/transpose_kernels.cu" "csrc/opt/transpose_kernels.cu"
...@@ -978,4 +979,4 @@ if (VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -978,4 +979,4 @@ if (VLLM_GPU_LANG STREQUAL "CUDA")
# vllm-flash-attn should be last as it overwrites some CMake functions # vllm-flash-attn should be last as it overwrites some CMake functions
include(cmake/external_projects/vllm_flash_attn.cmake) include(cmake/external_projects/vllm_flash_attn.cmake)
endif () endif ()
\ No newline at end of file
#!/usr/bin/env python3
import argparse
import csv
import io
import json
import re
import sys
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
# Pulls batch size, input length and output length out of a log file name that
# ends in "_bs<N>_in<N>_out<N>.log" (or ".txt").
_FILENAME_RE = re.compile(r"_bs(?P<bs>\d+)_in(?P<in_len>\d+)_out(?P<out_len>\d+)\.(?P<ext>log|txt)$")
# Banner lines that open and close the metrics section searched for in a log.
_BLOCK_START = "============ Serving Benchmark Result ============"
_BLOCK_END = "=================================================="
def _try_parse_number(value):
v = value.strip()
if not v:
return v
m = re.match(r"^-?\d+$", v)
if m:
try:
return int(v)
except ValueError:
return v
m = re.match(r"^-?\d+(?:\.\d+)?$", v)
if m:
try:
return float(v)
except ValueError:
return v
return v
def _parse_serving_result_block(text):
    """Extract ``key: value`` metrics from the first serving-result block.

    The block starts at the first line equal to ``_BLOCK_START`` and ends at
    the next line equal to ``_BLOCK_END`` (or at end of text when no closing
    banner exists). Both banners are compared after stripping surrounding
    whitespace, so a padded banner no longer causes the whole block to be
    missed.

    Returns a dict mapping each metric label to its first whitespace-separated
    value token, converted by ``_try_parse_number``. Returns ``{}`` when no
    start banner is found.
    """
    lines = text.splitlines()
    start_idx = next(
        (i for i, line in enumerate(lines) if line.strip() == _BLOCK_START),
        None,
    )
    if start_idx is None:
        return {}

    metrics = {}
    for raw in lines[start_idx + 1:]:
        if raw.strip() == _BLOCK_END:
            break
        # Section sub-headers inside the block carry no colon and are skipped.
        key, sep, value = raw.partition(":")
        key = key.strip()
        value = value.strip()
        if not sep or not key or not value:
            continue
        # Values are padded; only the first token is kept.
        metrics[key] = _try_parse_number(value.split()[0])
    return metrics
def _extract_case_from_path(path):
    """Return ``(bs, in_len, out_len)`` parsed from ``path.name``.

    All three entries are ``None`` when the file name does not match
    ``_FILENAME_RE``.
    """
    found = _FILENAME_RE.search(path.name)
    if found is None:
        return (None, None, None)
    return tuple(int(found.group(name)) for name in ("bs", "in_len", "out_len"))
class BenchResult:
    """Parsed metrics of one benchmark log plus its case parameters."""

    def __init__(self, path, bs, in_len, out_len, metrics):
        self.path, self.metrics = path, metrics
        self.bs, self.in_len, self.out_len = bs, in_len, out_len

    def key(self):
        """Return the sort key ``(bs, in_len, out_len, file name)``.

        Case parameters that are ``None`` sort as ``0``.
        """
        case = (self.bs, self.in_len, self.out_len)
        return tuple(0 if part is None else part for part in case) + (self.path.name,)
def _find_logs(paths):
logs = []
for p in paths:
if p.is_dir():
logs.extend(sorted(p.glob("bench_*.log")))
else:
logs.append(p)
return [p for p in logs if p.exists()]
def _fmt_float(v):
if isinstance(v, float):
return f"{v:.2f}"
if isinstance(v, int):
return str(v)
if v is None:
return "NA"
return str(v)
def _md_line(r):
    """Format one result as a single markdown bullet line.

    Metrics missing from ``r.metrics`` are shown the way ``_fmt_float``
    renders ``None``.
    """

    def shown(metric_name):
        return _fmt_float(r.metrics.get(metric_name))

    parts = [
        f"req/s={shown('Request throughput (req/s)')}",
        f"out_tok/s={shown('Output token throughput (tok/s)')}",
        f"TTFT mean/p99={shown('Mean TTFT (ms)')}/{shown('P99 TTFT (ms)')} ms",
        f"TPOT mean/p99={shown('Mean TPOT (ms)')}/{shown('P99 TPOT (ms)')} ms",
        f"ITL mean/p99={shown('Mean ITL (ms)')}/{shown('P99 ITL (ms)')} ms",
    ]
    return f"- bs={r.bs} in={r.in_len} out={r.out_len}: " + ", ".join(parts)
def main(argv):
    """Command-line entry point.

    Parses the benchmark logs named in *argv*, sorts the results by case and
    writes them as markdown, CSV or JSON lines to stdout or to ``--output``.

    Returns 0 on success and 2 when no log file is found. Logs that cannot be
    read are reported on stderr and skipped.
    """
    parser = argparse.ArgumentParser(
        description="Parse vLLM benchmark-serving bench_*.log files and output a request-result list."
    )
    parser.add_argument(
        "paths",
        nargs="+",
        help="One or more bench_*.log files or a directory containing them.",
    )
    parser.add_argument(
        "--format",
        choices=("markdown", "csv", "jsonl"),
        default="markdown",
        help="Output format (default: markdown).",
    )
    parser.add_argument(
        "--output",
        help="Write output to a file instead of stdout.",
    )
    args = parser.parse_args(argv)

    logs = _find_logs([Path(raw).expanduser() for raw in args.paths])
    if not logs:
        print("No bench_*.log files found.", file=sys.stderr)
        return 2

    results = []
    for log in logs:
        try:
            text = log.read_text(encoding="utf-8", errors="replace")
        except OSError as e:
            print(f"Failed to read {log}: {e}", file=sys.stderr)
            continue
        metrics = _parse_serving_result_block(text)
        bs, in_len, out_len = _extract_case_from_path(log)
        results.append(
            BenchResult(path=log, bs=bs, in_len=in_len, out_len=out_len, metrics=metrics)
        )
    results.sort(key=BenchResult.key)

    if args.format == "csv":
        # Flatten the key metrics into columns: (column name, metric label).
        metric_columns = (
            ("successful_requests", "Successful requests"),
            ("benchmark_duration_s", "Benchmark duration (s)"),
            ("req_per_s", "Request throughput (req/s)"),
            ("out_tok_per_s", "Output token throughput (tok/s)"),
            ("total_tok_per_s", "Total Token throughput (tok/s)"),
            ("ttft_mean_ms", "Mean TTFT (ms)"),
            ("ttft_p99_ms", "P99 TTFT (ms)"),
            ("tpot_mean_ms", "Mean TPOT (ms)"),
            ("tpot_p99_ms", "P99 TPOT (ms)"),
            ("itl_mean_ms", "Mean ITL (ms)"),
            ("itl_p99_ms", "P99 ITL (ms)"),
        )
        header = ["file", "bs", "in_len", "out_len"]
        header += [column for column, _ in metric_columns]
        buf = io.StringIO()
        writer = csv.DictWriter(buf, fieldnames=header)
        writer.writeheader()
        for r in results:
            row = {
                "file": str(r.path),
                "bs": r.bs,
                "in_len": r.in_len,
                "out_len": r.out_len,
            }
            for column, label in metric_columns:
                row[column] = r.metrics.get(label)
            writer.writerow(row)
        output_text = buf.getvalue()
    elif args.format == "jsonl":
        encoded = [
            json.dumps(
                {
                    "file": str(r.path),
                    "bs": r.bs,
                    "in_len": r.in_len,
                    "out_len": r.out_len,
                    "metrics": r.metrics,
                },
                ensure_ascii=False,
            )
            for r in results
        ]
        output_text = "\n".join(encoded) + "\n"
    else:  # markdown
        output_text = "\n".join(_md_line(r) for r in results) + "\n"

    if not args.output:
        sys.stdout.write(output_text)
    else:
        Path(args.output).expanduser().write_text(output_text, encoding="utf-8")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main(sys.argv[1:]))
#include <torch/all.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <cstdint>
#include <optional>
// Forward declarations for fallback to existing vLLM kernels.
void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon);
void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
namespace vllm {
// XOR-shuffle sum reduction: each step adds the value held by the lane whose
// index differs in one bit, for strides WIDTH/2, WIDTH/4, ..., 1.
// NOTE(review): this uses the non-`_sync` __shfl_xor intrinsic; presumably the
// build targets HIP/ROCm or an older CUDA toolchain - confirm.
template <typename T, int WIDTH>
__device__ __forceinline__ T warp_reduce_sum_xor(T val) {
  T total = val;
#pragma unroll
  for (int stride = WIDTH >> 1; stride >= 1; stride >>= 1) {
    total += __shfl_xor(total, stride);
  }
  return total;
}
// Returns this thread's partial sum of squares over the 2 * VEC_SIZE elements
// held in r_data_low / r_data_high, accumulated in T_ACC.
//
// When HAS_RESIDUAL is true, the residual values at res_head_ptr + offset_low
// and res_head_ptr + offset_high are first added to the data; the sums are
// written back to the residual buffer and also replace r_data_low/high, so the
// squares are taken over (data + residual).
template <typename T_ACC, typename scalar_t, int VEC_SIZE, bool HAS_RESIDUAL>
__device__ __forceinline__ T_ACC apply_residual_and_calc_sq(
    scalar_t* r_data_low, scalar_t* r_data_high, scalar_t* res_head_ptr,
    int offset_low, int offset_high) {
  using VecT = at::native::memory::aligned_vector<scalar_t, VEC_SIZE>;
  if constexpr (HAS_RESIDUAL) {
    scalar_t* const res_low_ptr = res_head_ptr + offset_low;
    scalar_t* const res_high_ptr = res_head_ptr + offset_high;
    scalar_t res_low[VEC_SIZE];
    scalar_t res_high[VEC_SIZE];
    // One vector-wide load per half.
    *reinterpret_cast<VecT*>(res_low) = *reinterpret_cast<VecT*>(res_low_ptr);
    *reinterpret_cast<VecT*>(res_high) = *reinterpret_cast<VecT*>(res_high_ptr);
#pragma unroll
    for (int k = 0; k < VEC_SIZE; ++k) {
      const scalar_t new_low = res_low[k] + r_data_low[k];
      const scalar_t new_high = res_high[k] + r_data_high[k];
      res_low[k] = new_low;
      res_high[k] = new_high;
      r_data_low[k] = new_low;
      r_data_high[k] = new_high;
    }
    // Persist the updated residual.
    *reinterpret_cast<VecT*>(res_low_ptr) = *reinterpret_cast<VecT*>(res_low);
    *reinterpret_cast<VecT*>(res_high_ptr) = *reinterpret_cast<VecT*>(res_high);
  }
  T_ACC acc = 0;
#pragma unroll VEC_SIZE
  for (int k = 0; k < VEC_SIZE; ++k) {
    const T_ACC lo = static_cast<T_ACC>(r_data_low[k]);
    const T_ACC hi = static_cast<T_ACC>(r_data_high[k]);
    acc += lo * lo;
    acc += hi * hi;
  }
  return acc;
}
// Turns the runtime boolean VAL into a `constexpr bool NAME` and invokes the
// callable passed as the trailing argument in whichever branch matches, so the
// callable can use NAME as a compile-time constant (e.g. a template argument).
//
// The body is wrapped in do { ... } while (0) so that the expansion is a
// single statement: it must be followed by ';' and cannot capture a following
// `else` or break an unbraced `if` at the call site.
#define DISPATCH_BOOL(VAL, NAME, ...) \
  do {                                \
    if (VAL) {                        \
      constexpr bool NAME = true;     \
      __VA_ARGS__();                  \
    } else {                          \
      constexpr bool NAME = false;    \
      __VA_ARGS__();                  \
    }                                 \
  } while (0)
// Fused per-head RMSNorm + rotary embedding, applied in place to `query` and
// `key`, specialized for a head size of 128 (see HEAD_SIZE / HALF_ROT below).
//
// Thread mapping: every token owns `threads_per_token` consecutive threads of
// the block. The first num_heads * THREAD_PER_HEAD lanes of a token process
// query heads, the following num_kv_heads * THREAD_PER_HEAD lanes process key
// heads. Each thread holds VEC_SIZE elements from the lower half of its head
// and VEC_SIZE elements from the upper half.
//
// When HAS_RESIDUAL is true, residual_q / residual_k are added to the head
// before normalization and updated in place; they are addressed with the same
// strides as query / key.
//
// Dynamic shared memory holds one HEAD_SIZE-element cos/sin row per token
// handled by the block.
template <typename T_ACC, typename scalar_t, bool HAS_RESIDUAL, bool IS_NEOX,
int VEC_SIZE, int THREAD_PER_HEAD>
__global__ void opt_rms_rope_qwen3(
const int64_t* __restrict__ positions, scalar_t* __restrict__ query,
scalar_t* __restrict__ key, const scalar_t* __restrict__ cos_sin_cache,
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int64_t head_stride_q, const int64_t head_stride_k,
const scalar_t* __restrict__ gamma_q,
const scalar_t* __restrict__ gamma_k, scalar_t* residual_q,
scalar_t* residual_k, const scalar_t eps, const int num_tokens,
const int num_heads, const int num_kv_heads, const int threads_per_token,
const int tokens_per_block) {
extern __shared__ char smem_buffer[];
scalar_t* s_cos_sin_base = reinterpret_cast<scalar_t*>(smem_buffer);
constexpr int HEAD_SIZE = 128;
constexpr int HALF_ROT = 64;
const int tid = threadIdx.x;
const int local_token_idx = tid / threads_per_token;
const int lane = tid % threads_per_token;
// NOTE(review): both early returns happen before the __syncthreads() below,
// so in a partially filled last block some threads never reach the barrier.
// Confirm this is acceptable on the target hardware.
if (local_token_idx >= tokens_per_block) return;
const int global_token_idx = blockIdx.x * tokens_per_block + local_token_idx;
if (global_token_idx >= num_tokens) return;
// Stage this token's cos/sin row into its slice of shared memory.
// NOTE(review): the row stride into cos_sin_cache is HEAD_SIZE while the sin
// offset used later is rot_dim / 2; these agree only when rot_dim == 128,
// which is presumably enforced by the caller - confirm.
scalar_t* my_s_cos_sin = s_cos_sin_base + local_token_idx * HEAD_SIZE;
const int64_t pos = positions[global_token_idx];
for (int i = lane; i < HEAD_SIZE; i += threads_per_token) {
my_s_cos_sin[i] = cos_sin_cache[pos * HEAD_SIZE + i];
}
__syncthreads();
// ---- Query heads ----
const int q_boundary = num_heads * THREAD_PER_HEAD;
if (lane < q_boundary) {
const int q_head_idx = lane / THREAD_PER_HEAD;
const int q_lane_in_head = lane % THREAD_PER_HEAD;
scalar_t* q_head_ptr =
query + global_token_idx * query_stride + q_head_idx * head_stride_q;
scalar_t* res_q_head_ptr =
HAS_RESIDUAL
? (residual_q + global_token_idx * query_stride +
q_head_idx * head_stride_q)
: nullptr;
using LoadT = at::native::memory::aligned_vector<scalar_t, VEC_SIZE>;
scalar_t r_q_low[VEC_SIZE];
scalar_t r_q_high[VEC_SIZE];
// Element i of the lower half is paired with element HALF_ROT + i.
const int offset_low = q_lane_in_head * VEC_SIZE;
const int offset_high = HALF_ROT + q_lane_in_head * VEC_SIZE;
*(LoadT*)r_q_low = *(LoadT*)(q_head_ptr + offset_low);
*(LoadT*)r_q_high = *(LoadT*)(q_head_ptr + offset_high);
// Per-thread sum of squares (after the optional residual add), then summed
// across the THREAD_PER_HEAD threads that share this head.
T_ACC sum_sq =
apply_residual_and_calc_sq<T_ACC, scalar_t, VEC_SIZE, HAS_RESIDUAL>(
r_q_low, r_q_high, res_q_head_ptr, offset_low, offset_high);
sum_sq = warp_reduce_sum_xor<T_ACC, THREAD_PER_HEAD>(sum_sq);
const T_ACC inv_rms =
c10::cuda::compat::rsqrt(sum_sq / HEAD_SIZE + static_cast<T_ACC>(eps));
const scalar_t* cache_ptr = my_s_cos_sin;
if constexpr (IS_NEOX) {
// NeoX layout: element i rotates with element HALF_ROT + i and both use
// cos/sin entry offset_low + i.
scalar_t r_cos_low[VEC_SIZE], r_sin_low[VEC_SIZE];
*(LoadT*)r_cos_low = *(LoadT*)(cache_ptr + offset_low);
*(LoadT*)r_sin_low = *(LoadT*)(cache_ptr + rot_dim / 2 + offset_low);
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
// Normalize and scale in T_ACC; the result is stored back as scalar_t
// before the rotation is applied.
r_q_low[i] = static_cast<T_ACC>(r_q_low[i]) * inv_rms *
static_cast<T_ACC>(gamma_q[offset_low + i]);
r_q_high[i] = static_cast<T_ACC>(r_q_high[i]) * inv_rms *
static_cast<T_ACC>(gamma_q[offset_high + i]);
const scalar_t q_l = r_q_low[i];
const scalar_t q_h = r_q_high[i];
const scalar_t c = r_cos_low[i];
const scalar_t s = r_sin_low[i];
r_q_low[i] = q_l * c - q_h * s;
r_q_high[i] = q_l * s + q_h * c;
}
} else {
// Interleaved layout: adjacent elements (i, i + 1) form a rotation pair,
// so each half needs only VEC_SIZE / 2 cos/sin entries.
using LoadCacheT =
at::native::memory::aligned_vector<scalar_t, VEC_SIZE / 2>;
scalar_t c_low[VEC_SIZE / 2], s_low[VEC_SIZE / 2];
scalar_t c_high[VEC_SIZE / 2], s_high[VEC_SIZE / 2];
const int cache_offset_low = offset_low / 2;
const int cache_offset_high = offset_high / 2;
*(LoadCacheT*)c_low = *(LoadCacheT*)(cache_ptr + cache_offset_low);
*(LoadCacheT*)s_low =
*(LoadCacheT*)(cache_ptr + rot_dim / 2 + cache_offset_low);
*(LoadCacheT*)c_high = *(LoadCacheT*)(cache_ptr + cache_offset_high);
*(LoadCacheT*)s_high =
*(LoadCacheT*)(cache_ptr + rot_dim / 2 + cache_offset_high);
#pragma unroll
for (int i = 0; i < VEC_SIZE; i += 2) {
const int c_idx = i / 2;
r_q_low[i] = static_cast<T_ACC>(r_q_low[i]) * inv_rms *
static_cast<T_ACC>(gamma_q[offset_low + i]);
r_q_low[i + 1] = static_cast<T_ACC>(r_q_low[i + 1]) * inv_rms *
static_cast<T_ACC>(gamma_q[offset_low + i + 1]);
const scalar_t q0 = r_q_low[i];
const scalar_t q1 = r_q_low[i + 1];
const scalar_t c = c_low[c_idx];
const scalar_t s = s_low[c_idx];
r_q_low[i] = q0 * c - q1 * s;
r_q_low[i + 1] = q1 * c + q0 * s;
r_q_high[i] = static_cast<T_ACC>(r_q_high[i]) * inv_rms *
static_cast<T_ACC>(gamma_q[offset_high + i]);
r_q_high[i + 1] = static_cast<T_ACC>(r_q_high[i + 1]) * inv_rms *
static_cast<T_ACC>(gamma_q[offset_high + i + 1]);
const scalar_t qh0 = r_q_high[i];
const scalar_t qh1 = r_q_high[i + 1];
const scalar_t ch = c_high[c_idx];
const scalar_t sh = s_high[c_idx];
r_q_high[i] = qh0 * ch - qh1 * sh;
r_q_high[i + 1] = qh1 * ch + qh0 * sh;
}
}
// Write the normalized, rotated query values back in place.
*(LoadT*)(q_head_ptr + offset_low) = *(LoadT*)r_q_low;
*(LoadT*)(q_head_ptr + offset_high) = *(LoadT*)r_q_high;
}
// ---- Key heads ----
// Same procedure as for query, using key, gamma_k and residual_k. Skipped
// entirely when key is null.
const int total_threads_needed = (num_heads + num_kv_heads) * THREAD_PER_HEAD;
if (lane >= q_boundary && lane < total_threads_needed && key != nullptr) {
const int k_lane_abs = lane - q_boundary;
const int kv_head_idx = k_lane_abs / THREAD_PER_HEAD;
const int k_lane_in_head = k_lane_abs % THREAD_PER_HEAD;
scalar_t* k_head_ptr =
key + global_token_idx * key_stride + kv_head_idx * head_stride_k;
scalar_t* res_k_head_ptr =
HAS_RESIDUAL
? (residual_k + global_token_idx * key_stride +
kv_head_idx * head_stride_k)
: nullptr;
using LoadTK = at::native::memory::aligned_vector<scalar_t, VEC_SIZE>;
scalar_t r_k_low[VEC_SIZE];
scalar_t r_k_high[VEC_SIZE];
const int offset_low = k_lane_in_head * VEC_SIZE;
const int offset_high = HALF_ROT + k_lane_in_head * VEC_SIZE;
*(LoadTK*)r_k_low = *(LoadTK*)(k_head_ptr + offset_low);
*(LoadTK*)r_k_high = *(LoadTK*)(k_head_ptr + offset_high);
T_ACC sum_sq_k =
apply_residual_and_calc_sq<T_ACC, scalar_t, VEC_SIZE, HAS_RESIDUAL>(
r_k_low, r_k_high, res_k_head_ptr, offset_low, offset_high);
sum_sq_k = warp_reduce_sum_xor<T_ACC, THREAD_PER_HEAD>(sum_sq_k);
const T_ACC inv_rms_k =
c10::cuda::compat::rsqrt(sum_sq_k / HEAD_SIZE + static_cast<T_ACC>(eps));
const scalar_t* cache_ptr_k = my_s_cos_sin;
if constexpr (IS_NEOX) {
// Unlike the query path, gamma_k is fetched with vector loads here.
scalar_t r_cos_low[VEC_SIZE], r_sin_low[VEC_SIZE];
scalar_t r_gamma_k_low[VEC_SIZE], r_gamma_k_high[VEC_SIZE];
*(LoadTK*)r_cos_low = *(LoadTK*)(cache_ptr_k + offset_low);
*(LoadTK*)r_sin_low = *(LoadTK*)(cache_ptr_k + rot_dim / 2 + offset_low);
*(LoadTK*)r_gamma_k_low = *(LoadTK*)(gamma_k + offset_low);
*(LoadTK*)r_gamma_k_high = *(LoadTK*)(gamma_k + offset_high);
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
r_k_low[i] = static_cast<T_ACC>(r_k_low[i]) * inv_rms_k *
static_cast<T_ACC>(r_gamma_k_low[i]);
r_k_high[i] = static_cast<T_ACC>(r_k_high[i]) * inv_rms_k *
static_cast<T_ACC>(r_gamma_k_high[i]);
const scalar_t k_l = r_k_low[i];
const scalar_t k_h = r_k_high[i];
const scalar_t c = r_cos_low[i];
const scalar_t s = r_sin_low[i];
r_k_low[i] = k_l * c - k_h * s;
r_k_high[i] = k_l * s + k_h * c;
}
} else {
using LoadCacheTK =
at::native::memory::aligned_vector<scalar_t, VEC_SIZE / 2>;
scalar_t r_cos_low[VEC_SIZE / 2], r_sin_low[VEC_SIZE / 2];
scalar_t r_cos_high[VEC_SIZE / 2], r_sin_high[VEC_SIZE / 2];
const int cache_offset_low = offset_low / 2;
const int cache_offset_high = offset_high / 2;
*(LoadCacheTK*)r_cos_low = *(LoadCacheTK*)(cache_ptr_k + cache_offset_low);
*(LoadCacheTK*)r_sin_low =
*(LoadCacheTK*)(cache_ptr_k + rot_dim / 2 + cache_offset_low);
*(LoadCacheTK*)r_cos_high =
*(LoadCacheTK*)(cache_ptr_k + cache_offset_high);
*(LoadCacheTK*)r_sin_high =
*(LoadCacheTK*)(cache_ptr_k + rot_dim / 2 + cache_offset_high);
#pragma unroll
for (int i = 0; i < VEC_SIZE; i += 2) {
const int c_idx = i / 2;
r_k_low[i] = static_cast<T_ACC>(r_k_low[i]) * inv_rms_k *
static_cast<T_ACC>(gamma_k[offset_low + i]);
r_k_low[i + 1] = static_cast<T_ACC>(r_k_low[i + 1]) * inv_rms_k *
static_cast<T_ACC>(gamma_k[offset_low + i + 1]);
const scalar_t k0 = r_k_low[i];
const scalar_t k1 = r_k_low[i + 1];
const scalar_t c = r_cos_low[c_idx];
const scalar_t s = r_sin_low[c_idx];
r_k_low[i] = k0 * c - k1 * s;
r_k_low[i + 1] = k1 * c + k0 * s;
r_k_high[i] = static_cast<T_ACC>(r_k_high[i]) * inv_rms_k *
static_cast<T_ACC>(gamma_k[offset_high + i]);
r_k_high[i + 1] = static_cast<T_ACC>(r_k_high[i + 1]) * inv_rms_k *
static_cast<T_ACC>(gamma_k[offset_high + i + 1]);
const scalar_t kh0 = r_k_high[i];
const scalar_t kh1 = r_k_high[i + 1];
const scalar_t ch = r_cos_high[c_idx];
const scalar_t sh = r_sin_high[c_idx];
r_k_high[i] = kh0 * ch - kh1 * sh;
r_k_high[i + 1] = kh1 * ch + kh0 * sh;
}
}
// Write the normalized, rotated key values back in place.
*(LoadTK*)(k_head_ptr + offset_low) = *(LoadTK*)r_k_low;
*(LoadTK*)(k_head_ptr + offset_high) = *(LoadTK*)r_k_high;
}
}
// Host-side launcher for opt_rms_rope_qwen3.
//
// Picks the launch shape (THREAD_PER_HEAD = 8 threads per head, VEC = 8
// elements per vector access, about 256 threads per block), sizes the dynamic
// shared memory at 128 elements per token in the block, and dispatches on the
// two runtime booleans:
//   - has_residual: true only when BOTH residual pointers are non-null;
//   - is_neox: selects the rotation layout used by the kernel.
//
// Returns without launching when there are no tokens or no heads, because a
// zero-sized grid is an invalid launch configuration and a zero
// threads_per_token would divide by zero below.
template <typename T_ACC, typename scalar_t>
void launch_opt_rms_rope(
    const int64_t* positions, scalar_t* query, scalar_t* key,
    const scalar_t* cos_sin_cache, const int rot_dim, const int64_t query_stride,
    const int64_t key_stride, const int64_t head_stride_q,
    const int64_t head_stride_k, const scalar_t* gamma_q,
    const scalar_t* gamma_k, scalar_t* residual_q_ptr,
    scalar_t* residual_k_ptr, const scalar_t eps, const int num_tokens,
    const bool is_neox, const int num_heads, const int num_kv_heads,
    const cudaStream_t stream) {
  const bool has_residual =
      (residual_q_ptr != nullptr && residual_k_ptr != nullptr);
  constexpr int THREAD_PER_HEAD = 8;
  constexpr int VEC = 8;
  const int threads_per_token = (num_heads + num_kv_heads) * THREAD_PER_HEAD;

  // Nothing to do for an empty batch or when there are no heads.
  if (num_tokens <= 0 || threads_per_token <= 0) return;

  // Keep the same launch heuristic as the original kernel.
  const int target_block_size = 256;
  int tokens_per_block = target_block_size / threads_per_token;
  if (tokens_per_block < 1) tokens_per_block = 1;
  const int actual_block_size = tokens_per_block * threads_per_token;
  const int grid_size = (num_tokens + tokens_per_block - 1) / tokens_per_block;
  // One 128-element cos/sin row per token handled by a block.
  const size_t smem_size = tokens_per_block * 128 * sizeof(scalar_t);
  DISPATCH_BOOL(has_residual, HAS_RESIDUAL_CONST, [&] {
    DISPATCH_BOOL(is_neox, IS_NEOX_CONST, [&] {
      opt_rms_rope_qwen3<T_ACC, scalar_t, HAS_RESIDUAL_CONST, IS_NEOX_CONST,
                         VEC, THREAD_PER_HEAD>
          <<<grid_size, actual_block_size, smem_size, stream>>>(
              positions, query, key, cos_sin_cache, rot_dim, query_stride,
              key_stride, head_stride_q, head_stride_k, gamma_q, gamma_k,
              residual_q_ptr, residual_k_ptr, eps, num_tokens, num_heads,
              num_kv_heads, threads_per_token, tokens_per_block);
    });
  });
}
} // namespace vllm
// Fused per-head RMSNorm + rotary embedding, applied in place to query/key.
//
// Arguments:
//   positions      int64 tensor, [num_tokens] or [batch_size, seq_len].
//   query, key     contiguous tensors updated in place; key is optional.
//   head_size      size of one attention head; must be positive.
//   cos_sin_cache  contiguous cache with the same dtype as query; size(1) is
//                  taken as the rotary dimension.
//   is_neox        selects the rotation layout.
//   weight_q/k     RMSNorm weights, same dtype as query.
//   residual_q/k   optional; must be given together. When present they are
//                  added to query/key before normalization and updated in
//                  place, and must match query/key in element count.
//   epsilon        RMSNorm epsilon.
//
// Dispatch: the specialized kernel is used when key is present, head_size and
// the rotary dimension are both 128, both weights have 128 elements and
// num_heads + num_kv_heads <= 128. Otherwise the function falls back to the
// existing RMSNorm kernels followed by rotary_embedding.
//
// An empty batch (positions.numel() == 0) returns early after validation.
void rms_rotary_embedding_fuse(
    torch::Tensor& positions, torch::Tensor& query,
    std::optional<torch::Tensor> key, int64_t head_size,
    torch::Tensor& cos_sin_cache, bool is_neox, torch::Tensor& weight_q,
    torch::Tensor& weight_k, std::optional<torch::Tensor> residual_q,
    std::optional<torch::Tensor> residual_k, double epsilon) {
  // Basic validation (mirrors rotary_embedding + layernorm checks).
  const int64_t num_tokens = positions.numel();
  const int positions_ndim = positions.dim();
  TORCH_CHECK(positions_ndim == 1 || positions_ndim == 2,
              "positions must have shape [num_tokens] or [batch_size, seq_len]");
  if (positions_ndim == 1) {
    TORCH_CHECK(query.size(0) == positions.size(0) &&
                    (!key.has_value() || key->size(0) == positions.size(0)),
                "query, key and positions must have the same number of tokens");
  } else {
    TORCH_CHECK(
        query.size(0) == positions.size(0) &&
            (!key.has_value() || key->size(0) == positions.size(0)) &&
            query.size(1) == positions.size(1) &&
            (!key.has_value() || key->size(1) == positions.size(1)),
        "query, key and positions must have the same batch_size and seq_len");
  }
  TORCH_CHECK(head_size > 0, "head_size must be positive");

  TORCH_CHECK(query.is_cuda(), "query must be CUDA");
  TORCH_CHECK(!key.has_value() || key->is_cuda(), "key must be CUDA");
  TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA");
  TORCH_CHECK(positions.is_cuda(), "positions must be CUDA");
  TORCH_CHECK(weight_q.is_cuda() && weight_k.is_cuda(),
              "weights must be CUDA");

  TORCH_CHECK(query.is_contiguous(), "query must be contiguous");
  TORCH_CHECK(!key.has_value() || key->is_contiguous(),
              "key must be contiguous");
  TORCH_CHECK(cos_sin_cache.is_contiguous(), "cos_sin_cache must be contiguous");
  TORCH_CHECK(weight_q.is_contiguous() && weight_k.is_contiguous(),
              "weights must be contiguous");
  // The kernel reads positions through a raw pointer with a flat token index.
  TORCH_CHECK(positions.is_contiguous(), "positions must be contiguous");

  TORCH_CHECK(positions.scalar_type() == at::kLong,
              "positions must be int64");
  TORCH_CHECK(query.scalar_type() == cos_sin_cache.scalar_type(),
              "cos_sin_cache must have same dtype as query");
  TORCH_CHECK(weight_q.scalar_type() == query.scalar_type() &&
                  weight_k.scalar_type() == query.scalar_type(),
              "weights must have same dtype as query");
  TORCH_CHECK(!key.has_value() || key->scalar_type() == query.scalar_type(),
              "key must have same dtype as query");

  if (residual_q.has_value() || residual_k.has_value()) {
    TORCH_CHECK(residual_q.has_value() && residual_k.has_value(),
                "residual_q and residual_k must be both provided or both None");
    TORCH_CHECK(residual_q->is_cuda() && residual_k->is_cuda(),
                "residual tensors must be CUDA");
    TORCH_CHECK(residual_q->is_contiguous() && residual_k->is_contiguous(),
                "residual tensors must be contiguous");
    TORCH_CHECK(residual_q->scalar_type() == query.scalar_type() &&
                    residual_k->scalar_type() == query.scalar_type(),
                "residual tensors must have same dtype as query");
    // The kernel addresses the residuals with the query/key strides, so a
    // smaller residual buffer would be read and written out of bounds.
    TORCH_CHECK(residual_q->numel() == query.numel(),
                "residual_q must have the same number of elements as query");
    TORCH_CHECK(!key.has_value() || residual_k->numel() == key->numel(),
                "residual_k must have the same number of elements as key");
  }

  // Empty batch: nothing to normalize or rotate, and the per-token divisions
  // below would divide by zero.
  if (num_tokens == 0) {
    return;
  }

  const int query_hidden_size = query.numel() / num_tokens;
  const int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
  TORCH_CHECK(query_hidden_size % head_size == 0);
  TORCH_CHECK(!key.has_value() || (key_hidden_size % head_size == 0));

  const int num_heads = query_hidden_size / head_size;
  const int num_kv_heads = key.has_value() ? (key_hidden_size / head_size)
                                           : num_heads;
  TORCH_CHECK(num_heads % num_kv_heads == 0);

  const int rot_dim = cos_sin_cache.size(1);
  const int seq_dim_idx = positions_ndim - 1;
  const int64_t query_stride = query.stride(seq_dim_idx);
  const int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
  const int query_ndim = query.dim();
  // If the tensor carries an explicit head dimension, use its stride;
  // otherwise heads are packed back to back.
  const int64_t head_stride_q =
      (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
  const int64_t head_stride_k =
      (key.has_value() && key->dim() == positions_ndim + 2) ? key->stride(-2)
                                                            : head_size;

  const bool has_residual = residual_q.has_value() && residual_k.has_value();
  const bool supports_qwen3_opt =
      (key.has_value() && head_size == 128 && rot_dim == 128 &&
       (num_heads + num_kv_heads) <= 128 && weight_q.numel() == 128 &&
       weight_k.numel() == 128);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (supports_qwen3_opt) {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16, query.scalar_type(),
        "vllm_rms_rotary_embedding_fuse_qwen3", [&] {
          using T_ACC = at::acc_type<scalar_t, true>;
          scalar_t* res_q_ptr =
              has_residual ? residual_q->data_ptr<scalar_t>() : nullptr;
          scalar_t* res_k_ptr =
              has_residual ? residual_k->data_ptr<scalar_t>() : nullptr;
          vllm::launch_opt_rms_rope<T_ACC, scalar_t>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key->data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              rot_dim, query_stride, key_stride, head_stride_q, head_stride_k,
              weight_q.data_ptr<scalar_t>(), weight_k.data_ptr<scalar_t>(),
              res_q_ptr, res_k_ptr, static_cast<scalar_t>(epsilon),
              static_cast<int>(num_tokens), is_neox, num_heads, num_kv_heads,
              stream);
        });
    return;
  }

  // Fallback: use existing kernels (still removes lightop dependency).
  // Apply per-head RMSNorm to Q/K and then call the existing RoPE kernel.
  {
    TORCH_CHECK(weight_q.numel() == head_size && weight_k.numel() == head_size,
                "weight_q/weight_k must have shape [head_size]");
    // View each head as one row so the norm runs per head.
    auto q_heads = query.view({num_tokens * num_heads, head_size});
    if (has_residual) {
      auto rq_heads = residual_q->view({num_tokens * num_heads, head_size});
      fused_add_rms_norm_opt(q_heads, rq_heads, weight_q, epsilon);
    } else {
      rms_norm_opt(q_heads, q_heads, weight_q, epsilon);
    }

    if (key.has_value()) {
      auto k_heads = key->view({num_tokens * num_kv_heads, head_size});
      if (has_residual) {
        auto rk_heads =
            residual_k->view({num_tokens * num_kv_heads, head_size});
        fused_add_rms_norm_opt(k_heads, rk_heads, weight_k, epsilon);
      } else {
        rms_norm_opt(k_heads, k_heads, weight_k, epsilon);
      }
    }
  }
  rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox);
}
...@@ -126,6 +126,16 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, ...@@ -126,6 +126,16 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size, std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox); torch::Tensor& cos_sin_cache, bool is_neox);
// Fused per-head RMSNorm + rotary embedding, applied in place to query and
// (optionally) key. weight_q / weight_k are the RMSNorm weights. residual_q /
// residual_k are optional and must be supplied together; when present they are
// added to query / key before normalization and are themselves updated.
void rms_rotary_embedding_fuse(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key,
int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox,
torch::Tensor& weight_q,
torch::Tensor& weight_k,
std::optional<torch::Tensor> residual_q,
std::optional<torch::Tensor> residual_k,
double epsilon);
void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
// void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, // void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
......
...@@ -245,6 +245,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -245,6 +245,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache, bool is_neox) -> ()"); " Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
// Fused RMSNorm + RoPE (in-place on query/key).
ops.def(
"rms_rotary_embedding_fuse(Tensor positions, Tensor! query,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox,"
" Tensor weight_q, Tensor weight_k,"
" Tensor!? residual_q, Tensor!? residual_k,"
" float epsilon) -> ()");
ops.impl("rms_rotary_embedding_fuse", torch::kCUDA,
&rms_rotary_embedding_fuse);
// trans w16 // trans w16
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()"); ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm); ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);
......
...@@ -244,6 +244,11 @@ if TYPE_CHECKING: ...@@ -244,6 +244,11 @@ if TYPE_CHECKING:
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False
VLLM_USE_TOPK_RENORM: bool = False VLLM_USE_TOPK_RENORM: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False VLLM_USE_FUSED_RMS_ROPE: bool = False
# Backend for fused RMS+RoPE in Qwen3/Qwen3-MoE:
# "vllm" -> use `torch.ops._C.rms_rotary_embedding_fuse`
# "lightop"-> use `lightop.rms_rotary_embedding_fuse`
# "auto" -> prefer vllm, fallback to lightop
VLLM_FUSED_RMS_ROPE_BACKEND: str = "auto"
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_W8A8_BACKEND: int = 3 VLLM_W8A8_BACKEND: int = 3
...@@ -1689,6 +1694,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1689,6 +1694,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_RMS_ROPE": "VLLM_USE_FUSED_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "False").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "False").lower() in
("true", "1")), ("true", "1")),
# Backend for fused RMS + RoPE (Qwen3/Qwen3-MoE):
# auto: prefer vllm, fallback to lightop
# vllm: use `torch.ops._C.rms_rotary_embedding_fuse`
# lightop: use `lightop.rms_rotary_embedding_fuse`
"VLLM_FUSED_RMS_ROPE_BACKEND":
lambda: os.getenv("VLLM_FUSED_RMS_ROPE_BACKEND", "auto").lower(),
# vLLM will use fast token id copy # vLLM will use fast token id copy
"VLLM_V1_FAST_TOKEN_ID_COPY": "VLLM_V1_FAST_TOKEN_ID_COPY":
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
......
...@@ -152,8 +152,55 @@ class Qwen3Attention(nn.Module): ...@@ -152,8 +152,55 @@ class Qwen3Attention(nn.Module):
k_bias: Optional[torch.Tensor], k_bias: Optional[torch.Tensor],
epsilon: float, epsilon: float,
) -> None: ) -> None:
from lightop import rms_rotary_embedding_fuse as fused_kernel backend = envs.VLLM_FUSED_RMS_ROPE_BACKEND
fused_kernel( if backend == "lightop":
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
epsilon,
)
return
if backend not in ("vllm", "auto"):
raise ValueError(
"VLLM_FUSED_RMS_ROPE_BACKEND must be one of "
"('auto', 'vllm', 'lightop'), got: %r" % backend)
# Ensure vLLM extension ops are loaded before checking/calling them.
try:
import vllm._C # noqa: F401
except Exception:
if backend == "vllm":
raise
if backend == "auto" and not hasattr(torch.ops._C,
"rms_rotary_embedding_fuse"):
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
epsilon,
)
return
torch.ops._C.rms_rotary_embedding_fuse(
positions, positions,
query, query,
key, key,
......
...@@ -284,8 +284,54 @@ class Qwen3MoeAttention(nn.Module): ...@@ -284,8 +284,54 @@ class Qwen3MoeAttention(nn.Module):
k_bias: Optional[torch.Tensor], k_bias: Optional[torch.Tensor],
epsilon: float, epsilon: float,
) -> None: ) -> None:
from lightop import rms_rotary_embedding_fuse as fused_kernel backend = envs.VLLM_FUSED_RMS_ROPE_BACKEND
fused_kernel( if backend == "lightop":
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
epsilon,
)
return
if backend not in ("vllm", "auto"):
raise ValueError(
"VLLM_FUSED_RMS_ROPE_BACKEND must be one of "
"('auto', 'vllm', 'lightop'), got: %r" % backend)
try:
import vllm._C # noqa: F401
except Exception:
if backend == "vllm":
raise
if backend == "auto" and not hasattr(torch.ops._C,
"rms_rotary_embedding_fuse"):
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
epsilon,
)
return
torch.ops._C.rms_rotary_embedding_fuse(
positions, positions,
query, query,
key, key,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment