"vscode:/vscode.git/clone" did not exist on "aa49f148322a39727be110da51a6782d43a2f5d8"
Commit 588538f5 authored by laibao's avatar laibao
Browse files

• feat(qwen3):新增 vLLM 内置 RMS+RoPE 融合算子,并支持 LightOp 后端切换

  - 在 vLLM _C 扩展中新增 rms_rotary_embedding_fuse(注册 op + CUDA kernel),减少对 LightOp 的硬依赖
  - 新增环境变量 VLLM_FUSED_RMS_ROPE_BACKEND=auto|vllm|lightop,auto 优先走 vLLM,缺失时回退 LightOp
  - 更新 Qwen3 / Qwen3-MoE 的 fused 路径按后端选择执行
  - 补充 tc_opt benchmark 结果解析脚本 benchmarks/tc_opt/test/parse_bench_results.py
parent 70506d98
......@@ -256,6 +256,7 @@ set(VLLM_EXT_SRC
"csrc/attention/merge_attn_states.cu"
"csrc/attention/vertical_slash_index.cu"
"csrc/pos_encoding_kernels.cu"
"csrc/fuse_rms_rope_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/opt/transpose_kernels.cu"
......@@ -978,4 +979,4 @@ if (VLLM_GPU_LANG STREQUAL "CUDA")
# vllm-flash-attn should be last as it overwrites some CMake functions
include(cmake/external_projects/vllm_flash_attn.cmake)
endif ()
\ No newline at end of file
endif ()
#!/usr/bin/env python3
import argparse
import csv
import io
import json
import re
import sys
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
# Matches a file-name suffix of the form `_bs<digits>_in<digits>_out<digits>`
# followed by a `.log` or `.txt` extension at the end of the string, capturing
# the three integers as the named groups `bs`, `in_len` and `out_len`.
_FILENAME_RE = re.compile(r"_bs(?P<bs>\d+)_in(?P<in_len>\d+)_out(?P<out_len>\d+)\.(?P<ext>log|txt)$")
# Banner line that opens the result summary inside a benchmark log.
_BLOCK_START = "============ Serving Benchmark Result ============"
# Banner line that closes the result summary.
_BLOCK_END = "=================================================="
def _try_parse_number(value):
v = value.strip()
if not v:
return v
m = re.match(r"^-?\d+$", v)
if m:
try:
return int(v)
except ValueError:
return v
m = re.match(r"^-?\d+(?:\.\d+)?$", v)
if m:
try:
return float(v)
except ValueError:
return v
return v
def _parse_serving_result_block(text):
    """Extract the ``key: value`` metrics of the first serving result block.

    The block starts at the first line equal to ``_BLOCK_START`` and runs up
    to the next line equal to ``_BLOCK_END`` (or to the end of the text when
    no closing banner is present). Both banners are compared after stripping
    surrounding whitespace, so trailing spaces or a stray carriage return on
    the opening banner no longer cause the whole block to be missed.

    Args:
        text: Full contents of a benchmark log.

    Returns:
        A dict mapping each metric label to the first whitespace-separated
        token of its value, converted with ``_try_parse_number``. The dict is
        empty when no opening banner is found.
    """
    lines = text.splitlines()
    start_idx = next(
        (i for i, line in enumerate(lines) if line.strip() == _BLOCK_START),
        None,
    )
    if start_idx is None:
        return {}

    metrics = {}
    for raw in lines[start_idx + 1:]:
        if raw.strip() == _BLOCK_END:
            break
        # Lines without a colon (e.g. section sub-headers) carry no metric.
        key, sep, value = raw.partition(":")
        if not sep:
            continue
        key = key.strip()
        value = value.strip()
        if not key or not value:
            continue
        # Values are padded; take the first token if it looks numeric.
        metrics[key] = _try_parse_number(value.split()[0])
    return metrics
def _extract_case_from_path(path):
    """Read the benchmark case encoded in a log file name.

    Returns a ``(bs, in_len, out_len)`` tuple of ints taken from the named
    groups of ``_FILENAME_RE``, or ``(None, None, None)`` when ``path.name``
    does not match that pattern.
    """
    match = _FILENAME_RE.search(path.name)
    if match is None:
        return (None, None, None)
    return tuple(int(match.group(name)) for name in ("bs", "in_len", "out_len"))
class BenchResult:
    """One parsed log: its path, benchmark case and metric dict."""

    def __init__(self, path, bs, in_len, out_len, metrics):
        self.path, self.metrics = path, metrics
        self.bs, self.in_len, self.out_len = bs, in_len, out_len

    def key(self):
        """Return the sort key ``(bs, in_len, out_len, file name)``.

        A case value that is ``None`` is replaced by ``0`` so that results
        with an unknown case remain comparable with the others.
        """
        case = tuple(
            0 if part is None else part
            for part in (self.bs, self.in_len, self.out_len)
        )
        return case + (self.path.name,)
def _find_logs(paths):
logs = []
for p in paths:
if p.is_dir():
logs.extend(sorted(p.glob("bench_*.log")))
else:
logs.append(p)
return [p for p in logs if p.exists()]
def _fmt_float(v):
if isinstance(v, float):
return f"{v:.2f}"
if isinstance(v, int):
return str(v)
if v is None:
return "NA"
return str(v)
def _md_line(r):
    """Build the one-line markdown bullet summarising a benchmark result.

    The line lists the case (bs / in / out), the request and output-token
    throughput, and mean/P99 values for TTFT, TPOT and ITL. Metrics missing
    from ``r.metrics`` are rendered by ``_fmt_float`` as ``NA``.
    """
    stats = r.metrics

    def show(label):
        # Look the metric up by its label and format it for display.
        return _fmt_float(stats.get(label))

    case = f"- bs={r.bs} in={r.in_len} out={r.out_len}: "
    throughput = (
        f"req/s={show('Request throughput (req/s)')}, "
        f"out_tok/s={show('Output token throughput (tok/s)')}, "
    )
    ttft = f"TTFT mean/p99={show('Mean TTFT (ms)')}/{show('P99 TTFT (ms)')} ms, "
    tpot = f"TPOT mean/p99={show('Mean TPOT (ms)')}/{show('P99 TPOT (ms)')} ms, "
    itl = f"ITL mean/p99={show('Mean ITL (ms)')}/{show('P99 ITL (ms)')} ms"
    return "".join((case, throughput, ttft, tpot, itl))
# (CSV column name, metric label in the result block), in output column order.
_CSV_METRIC_COLUMNS = (
    ("successful_requests", "Successful requests"),
    ("benchmark_duration_s", "Benchmark duration (s)"),
    ("req_per_s", "Request throughput (req/s)"),
    ("out_tok_per_s", "Output token throughput (tok/s)"),
    ("total_tok_per_s", "Total Token throughput (tok/s)"),
    ("ttft_mean_ms", "Mean TTFT (ms)"),
    ("ttft_p99_ms", "P99 TTFT (ms)"),
    ("tpot_mean_ms", "Mean TPOT (ms)"),
    ("tpot_p99_ms", "P99 TPOT (ms)"),
    ("itl_mean_ms", "Mean ITL (ms)"),
    ("itl_p99_ms", "P99 ITL (ms)"),
)


def _render_markdown(results):
    """Render one markdown bullet per result, newline-terminated."""
    return "\n".join([_md_line(r) for r in results]) + "\n"


def _render_jsonl(results):
    """Render one JSON object per line carrying the case and all metrics."""
    rows = [
        {
            "file": str(r.path),
            "bs": r.bs,
            "in_len": r.in_len,
            "out_len": r.out_len,
            "metrics": r.metrics,
        }
        for r in results
    ]
    return "\n".join(json.dumps(row, ensure_ascii=False) for row in rows) + "\n"


def _render_csv(results):
    """Render a CSV table with the key metrics flattened into columns."""
    fields = ["file", "bs", "in_len", "out_len"]
    fields.extend(column for column, _ in _CSV_METRIC_COLUMNS)
    buf = io.StringIO()
    writer = csv.DictWriter(buf, fieldnames=fields)
    writer.writeheader()
    for r in results:
        row = {
            "file": str(r.path),
            "bs": r.bs,
            "in_len": r.in_len,
            "out_len": r.out_len,
        }
        for column, label in _CSV_METRIC_COLUMNS:
            # Metrics absent from the log are written as empty cells.
            row[column] = r.metrics.get(label)
        writer.writerow(row)
    return buf.getvalue()


def main(argv):
    """Command-line entry point: parse benchmark logs and emit a report.

    Args:
        argv: Argument list without the program name.

    Returns:
        Process exit status: ``0`` on success, ``2`` when no input log file
        exists, and ``1`` when log files exist but none of them could be
        read (previously an empty report was produced with status ``0``).
    """
    parser = argparse.ArgumentParser(
        description="Parse vLLM benchmark-serving bench_*.log files and output a request-result list."
    )
    parser.add_argument(
        "paths",
        nargs="+",
        help="One or more bench_*.log files or a directory containing them.",
    )
    parser.add_argument(
        "--format",
        choices=("markdown", "csv", "jsonl"),
        default="markdown",
        help="Output format (default: markdown).",
    )
    parser.add_argument(
        "--output",
        help="Write output to a file instead of stdout.",
    )
    args = parser.parse_args(argv)

    logs = _find_logs([Path(p).expanduser() for p in args.paths])
    if not logs:
        print("No bench_*.log files found.", file=sys.stderr)
        return 2

    results = []
    for log in logs:
        try:
            text = log.read_text(encoding="utf-8", errors="replace")
        except OSError as e:
            # Best effort: report the unreadable file and keep going.
            print(f"Failed to read {log}: {e}", file=sys.stderr)
            continue
        metrics = _parse_serving_result_block(text)
        bs, in_len, out_len = _extract_case_from_path(log)
        results.append(
            BenchResult(path=log, bs=bs, in_len=in_len, out_len=out_len, metrics=metrics)
        )
    if not results:
        # Every read failed; an empty report with status 0 would hide that.
        print("None of the input files could be read.", file=sys.stderr)
        return 1
    results.sort(key=lambda r: r.key())

    renderers = {
        "markdown": _render_markdown,
        "jsonl": _render_jsonl,
        "csv": _render_csv,
    }
    output_text = renderers[args.format](results)

    if args.output:
        out_path = Path(args.output).expanduser()
        out_path.write_text(output_text, encoding="utf-8")
    else:
        sys.stdout.write(output_text)
    return 0
# Run the CLI when executed as a script; main()'s return value is the exit status.
if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
This diff is collapsed.
......@@ -126,6 +126,16 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
// Fused RMSNorm + RoPE entry point, registered as the CUDA implementation of
// the `rms_rotary_embedding_fuse` op. The op schema marks `query`, `key`,
// `residual_q` and `residual_k` as mutable tensors, so they may be written in
// place; `key`, `residual_q` and `residual_k` are optional.
// NOTE(review): the Python call sites that use the lightop kernel of the same
// name pass their `q_bias` / `k_bias` values in the `residual_q` /
// `residual_k` positions — confirm the intended meaning of these arguments.
void rms_rotary_embedding_fuse(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key,
int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox,
torch::Tensor& weight_q,
torch::Tensor& weight_k,
std::optional<torch::Tensor> residual_q,
std::optional<torch::Tensor> residual_k,
double epsilon);
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
// void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
......
......@@ -245,6 +245,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
// Fused RMSNorm + RoPE (in-place on query/key).
ops.def(
"rms_rotary_embedding_fuse(Tensor positions, Tensor! query,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox,"
" Tensor weight_q, Tensor weight_k,"
" Tensor!? residual_q, Tensor!? residual_k,"
" float epsilon) -> ()");
ops.impl("rms_rotary_embedding_fuse", torch::kCUDA,
&rms_rotary_embedding_fuse);
// trans w16
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);
......
......@@ -244,6 +244,11 @@ if TYPE_CHECKING:
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = False
VLLM_USE_TOPK_RENORM: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False
# Backend for fused RMS+RoPE in Qwen3/Qwen3-MoE:
# "vllm" -> use `torch.ops._C.rms_rotary_embedding_fuse`
# "lightop"-> use `lightop.rms_rotary_embedding_fuse`
# "auto" -> prefer vllm, fallback to lightop
VLLM_FUSED_RMS_ROPE_BACKEND: str = "auto"
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_W8A8_BACKEND: int = 3
......@@ -1689,6 +1694,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_RMS_ROPE":
lambda: (os.environ.get("VLLM_USE_FUSED_RMS_ROPE", "False").lower() in
("true", "1")),
# Backend for fused RMS + RoPE (Qwen3/Qwen3-MoE):
# auto: prefer vllm, fallback to lightop
# vllm: use `torch.ops._C.rms_rotary_embedding_fuse`
# lightop: use `lightop.rms_rotary_embedding_fuse`
"VLLM_FUSED_RMS_ROPE_BACKEND":
lambda: os.getenv("VLLM_FUSED_RMS_ROPE_BACKEND", "auto").lower(),
# vLLM will use fast token id copy
"VLLM_V1_FAST_TOKEN_ID_COPY":
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
......
......@@ -152,8 +152,55 @@ class Qwen3Attention(nn.Module):
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
backend = envs.VLLM_FUSED_RMS_ROPE_BACKEND
if backend == "lightop":
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
epsilon,
)
return
if backend not in ("vllm", "auto"):
raise ValueError(
"VLLM_FUSED_RMS_ROPE_BACKEND must be one of "
"('auto', 'vllm', 'lightop'), got: %r" % backend)
# Ensure vLLM extension ops are loaded before checking/calling them.
try:
import vllm._C # noqa: F401
except Exception:
if backend == "vllm":
raise
if backend == "auto" and not hasattr(torch.ops._C,
"rms_rotary_embedding_fuse"):
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
epsilon,
)
return
torch.ops._C.rms_rotary_embedding_fuse(
positions,
query,
key,
......
......@@ -284,8 +284,54 @@ class Qwen3MoeAttention(nn.Module):
k_bias: Optional[torch.Tensor],
epsilon: float,
) -> None:
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
backend = envs.VLLM_FUSED_RMS_ROPE_BACKEND
if backend == "lightop":
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
epsilon,
)
return
if backend not in ("vllm", "auto"):
raise ValueError(
"VLLM_FUSED_RMS_ROPE_BACKEND must be one of "
"('auto', 'vllm', 'lightop'), got: %r" % backend)
try:
import vllm._C # noqa: F401
except Exception:
if backend == "vllm":
raise
if backend == "auto" and not hasattr(torch.ops._C,
"rms_rotary_embedding_fuse"):
from lightop import rms_rotary_embedding_fuse as fused_kernel
fused_kernel(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style,
q_weight,
k_weight,
q_bias,
k_bias,
epsilon,
)
return
torch.ops._C.rms_rotary_embedding_fuse(
positions,
query,
key,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment