"vscode:/vscode.git/clone" did not exist on "52f58fc42ab1f00ae3d0e0279594664c07504142"
Unverified commit 11965b0d authored by Xiaoyu Zhang, committed by GitHub

Fix sgl-kernel benchmark dead code (#11022)

parent 71959545
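The commit applies the same three-part pattern to every benchmark script touched below: optional third-party imports are wrapped in try/except and tracked with an availability flag, CI runs are detected from environment variables, and the parameter sweeps are shrunk when running in CI. A minimal, self-contained sketch of that pattern follows; the module name maybe_heavy_dep and the sweep values are illustrative placeholders, not taken from the diff.

import os

# Guarded optional import: record availability instead of failing at import time.
# "maybe_heavy_dep" is a placeholder, not an actual dependency of these benchmarks.
try:
    import maybe_heavy_dep

    HEAVY_DEP_AVAILABLE = True
except ImportError:
    maybe_heavy_dep = None
    HEAVY_DEP_AVAILABLE = False

# CI detection via the same environment variables used in the scripts below.
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)

# Shrink the sweep under CI so the benchmark finishes quickly.
batch_sizes = [16] if IS_CI else [16, 32, 64, 128]

# Only benchmark providers whose dependencies are importable.
providers = ["torch", "sglang"] + (["heavy_dep"] if HEAVY_DEP_AVAILABLE else [])

Each fallback path in the diff mirrors this: when a dependency is missing, the benchmark either skips that provider or falls back to the SGLang or naive implementation.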
import itertools
import os
from typing import Optional, Tuple
import torch
import triton
import triton.testing
from sgl_kernel import sgl_per_token_quant_fp8
# Optional vLLM import
try:
    from vllm import _custom_ops as ops

    VLLM_AVAILABLE = True
except ImportError:
    ops = None
    VLLM_AVAILABLE = False
from sglang.srt.utils import is_hip
_is_hip = is_hip()
# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
# Get correct FP8 E4M3 maximum value
@@ -49,6 +65,9 @@ def torch_per_token_quant_fp8(
def vllm_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if not VLLM_AVAILABLE:
        # Fallback to SGLang implementation
        return sglang_per_token_quant_fp8(input)
    return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
@@ -74,6 +93,17 @@ def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
    vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
    sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)

    if not VLLM_AVAILABLE:
        print("⚠️ vLLM not available, skipping vLLM comparison")
        # Only compare Torch vs SGLang
        torch_sglang_scale_diff = torch.abs(torch_scale - sglang_scale).mean().item()
        torch_sglang_out_diff = (
            torch.abs(torch_out.float() - sglang_out.float()).mean().item()
        )
        print(f"Scale difference (Torch vs SGLang): {torch_sglang_scale_diff:.8f}")
        print(f"Output difference (Torch vs SGLang): {torch_sglang_out_diff:.8f}")
        return

    print(f"\n=== Comparison for hidden_dim={hidden_dim} ===")
    # Compare scales
@@ -125,9 +155,15 @@ def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
    print(f" VLLM vs SGLang: {'✅' if vllm_sglang_match else '❌'}")
# CI environment uses simplified parameters
if IS_CI:
    batch_size_range = [16]  # Single batch size for CI
    seq_len_range = [64]  # Single sequence length for CI
    hidden_dim_range = [2048]  # Single hidden dimension for CI
else:
    batch_size_range = [16, 32, 64, 128]
    seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
    hidden_dim_range = [1368, 2048, 4096]
configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_range))
@@ -137,9 +173,19 @@ configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_ran
        x_names=["batch_size", "seq_len", "hidden_dim"],
        x_vals=configs,
        line_arg="provider",
        line_vals=(
            ["torch", "vllm", "sglang"] if VLLM_AVAILABLE else ["torch", "sglang"]
        ),
        line_names=(
            ["Torch Reference", "VLLM", "SGL Kernel"]
            if VLLM_AVAILABLE
            else ["Torch Reference", "SGL Kernel"]
        ),
        styles=(
            [("red", "-"), ("blue", "-"), ("green", "-")]
            if VLLM_AVAILABLE
            else [("red", "-"), ("green", "-")]
        ),
        ylabel="us",
        plot_name="per-token-dynamic-quant-fp8-performance",
        args={},
@@ -156,6 +202,8 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
    if provider == "torch":
        fn = lambda: torch_per_token_quant_fp8(x.clone())
    elif provider == "vllm":
        if not VLLM_AVAILABLE:
            return (0, 0, 0)
        fn = lambda: vllm_per_token_quant_fp8(x.clone())
    elif provider == "sglang":
        fn = lambda: sglang_per_token_quant_fp8(x.clone())
@@ -166,11 +214,16 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
if __name__ == "__main__":
    # Test various hidden dimensions for correctness - simplified for CI
    if IS_CI:
        test_dims = [2048]  # Single dimension for CI
        batch_size, seq_len = 4, 64  # Smaller values for CI
    else:
        test_dims = [1368, 2048, 4096]
        batch_size, seq_len = 4, 4096

    for dim in test_dims:
        calculate_diff(batch_size=batch_size, seq_len=seq_len, hidden_dim=dim)

    print("\n" + "=" * 60)
    print("Starting performance benchmark...")
......
import argparse
import copy
import itertools
import os
import torch
import triton
@@ -10,6 +11,12 @@ from sgl_kernel import (
    qserve_w4a8_per_group_gemm,
)
# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
@@ -65,10 +72,17 @@ WEIGHT_SHAPES = {
}
# CI environment uses simplified parameters
if IS_CI:
    batch_sizes = [1, 16]  # Simplified for CI
else:
    batch_sizes = [1, 16, 32, 64, 128, 256, 512, 1024, 2048]
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=batch_sizes,
        x_log=False,
        line_arg="provider",
        line_vals=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"],
@@ -184,13 +198,19 @@ if __name__ == "__main__":
    )
    args = parser.parse_args()

    # Skip in CI environment
    if IS_CI:
        print("Skipping QServe W4A8 GEMM benchmark in CI environment")
        print("QServe operations may have compatibility issues in CI")
    else:
        KN_model_names = prepare_shapes(args)

        for K, N, model_name in KN_model_names:
            print(f"{model_name} N={N} K={K}: ")
            benchmark.run(
                print_data=True,
                N=N,
                K=K,
            )

        print("Benchmark finished!")
@@ -2,6 +2,7 @@
# (batch_size, seq_len, hidden_size) and prints speed-up.
import argparse
import itertools
import os
import re
from typing import List, Optional, Tuple, Union
@@ -10,9 +11,31 @@ import torch
import torch.nn as nn
import triton
import triton.testing
from sgl_kernel.utils import is_arch_support_pdl

# Optional imports
try:
    from flashinfer.norm import fused_add_rmsnorm, rmsnorm

    FLASHINFER_AVAILABLE = True
except ImportError:
    fused_add_rmsnorm = None
    rmsnorm = None
    FLASHINFER_AVAILABLE = False

try:
    from vllm import _custom_ops as vllm_ops

    VLLM_AVAILABLE = True
except ImportError:
    vllm_ops = None
    VLLM_AVAILABLE = False

# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def str2int_list(arg: str) -> List[int]:
@@ -79,6 +102,10 @@ def rmsnorm_flashinfer(
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    if not FLASHINFER_AVAILABLE:
        # Fallback to naive implementation if FlashInfer is not available
        return rmsnorm_naive(x, weight, residual, eps)

    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if residual is not None:
@@ -103,6 +130,10 @@ def rmsnorm_vllm(
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    if not VLLM_AVAILABLE:
        # Fallback to naive implementation if vLLM is not available
        return rmsnorm_naive(x, weight, residual, eps)

    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if residual is not None:
@@ -179,37 +210,72 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
        output_sglang = output_sglang[0]

    print(f"Naive output={output_naive}")
    if FLASHINFER_AVAILABLE:
        print(f"FlashInfer output={output_flashinfer}")
    else:
        print("FlashInfer not available, skipped")
    if VLLM_AVAILABLE:
        print(f"VLLM output={output_vllm}")
    else:
        print("vLLM not available, skipped")
    print(f"SGLang output={output_sglang}")

    # Only compare available implementations
    all_match = torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
    if FLASHINFER_AVAILABLE:
        all_match = all_match and torch.allclose(
            output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
        )
    if VLLM_AVAILABLE:
        all_match = all_match and torch.allclose(
            output_naive, output_vllm, atol=1e-2, rtol=1e-2
        )

    if all_match:
        print("✅ All available implementations match")
    else:
        print("❌ Implementations differ")
# CI environment uses simplified parameters
if IS_CI:
    default_batch_sizes = [1]  # Single batch size for CI
    default_seq_lens = [64]  # Single sequence length for CI
    default_hidden_sizes = [4096]  # Single hidden size for CI
else:
    default_batch_sizes = [2**i for i in range(0, 7, 2)]  # 1, 4, 16, 64
    default_seq_lens = [2**i for i in range(6, 11, 1)]  # 64, 128, 256, 512, 1024
    default_hidden_sizes = [32 * 128, 48 * 128]  # 4096, 6144
def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]:
    return list(itertools.product(bsizes, slens, hsizes))
# Filter providers based on availability
available_providers = ["huggingface", "sglang"]
available_names = ["HuggingFace", "SGL Kernel"]
available_styles = [("blue", "-"), ("orange", "-")]
if FLASHINFER_AVAILABLE:
    available_providers.insert(-1, "flashinfer")
    available_names.insert(-1, "FlashInfer")
    available_styles.insert(-1, ("green", "-"))

if VLLM_AVAILABLE:
    available_providers.insert(-1, "vllm")
    available_names.insert(-1, "vLLM")
    available_styles.insert(-1, ("red", "-"))
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len", "hidden_size"],
        x_vals=[],
        line_arg="provider",
        line_vals=available_providers,
        line_names=available_names,
        styles=available_styles,
        ylabel="µs (median) or × (speed-up)",
        plot_name="rmsnorm-performance",
        args={},
@@ -242,6 +308,8 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
            )
        )
    elif provider == "flashinfer":
        if not FLASHINFER_AVAILABLE:
            return (0, 0, 0)
        return timed(
            lambda: rmsnorm_flashinfer(
                x.clone(),
@@ -250,6 +318,8 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
            )
        )
    elif provider == "vllm":
        if not VLLM_AVAILABLE:
            return (0, 0, 0)
        return timed(
            lambda: rmsnorm_vllm(
                x.clone(),
@@ -267,13 +337,22 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
        )

    # provider == "speedup"
    if VLLM_AVAILABLE:
        t_ref, _, _ = timed(
            lambda: rmsnorm_vllm(
                x.clone(),
                weight,
                residual.clone() if residual is not None else None,
            )
        )
    else:
        t_ref, _, _ = timed(
            lambda: rmsnorm_naive(
                x.clone(),
                weight,
                residual.clone() if residual is not None else None,
            )
        )

    t_sgl, _, _ = timed(
        lambda: rmsnorm_sglang(
            x.clone(),
@@ -281,7 +360,7 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
            residual.clone() if residual is not None else None,
        )
    )

    spd = t_ref / t_sgl if t_ref > 0 else 1.0
    return (spd, spd, spd)
......
import itertools
import os
import torch
import triton
@@ -12,17 +13,31 @@ from sgl_kernel.testing.rotary_embedding import (
from sglang.srt.bench_utils import bench_kineto
# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)

# CI environment uses simplified parameters
if IS_CI:
    batch_seq_configs = [(1, 1)]  # Single config for CI
    save_kv_configs = [False]  # Single option for CI
else:
    batch_seq_configs = [
        (1, 1),
        (32, 1),
        (128, 1),
        (512, 1),
        (2, 512),
        (4, 4096),
    ]
    save_kv_configs = [False, True]

configs = [
    (batch_size, seq_len, save_kv_cache)
    for batch_size, seq_len in batch_seq_configs
    for save_kv_cache in save_kv_configs
]
......
import itertools
import os
import sgl_kernel
import torch
import triton
import triton.testing
# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def torch_top_k_top_p_joint_sampling_from_probs(
    normalized_prob, top_k, top_p, eps=1e-4
@@ -67,10 +74,16 @@ def calculate_diff(batch_size, vocab_size, p):
    )


# parameter space - simplified for CI
if IS_CI:
    batch_size_range = [16]  # Single batch size for CI
    vocab_size_range = [111]  # Single vocab size for CI
    p_range = [0.1]  # Single p value for CI
else:
    batch_size_range = [16, 64, 128]
    vocab_size_range = [111, 32000]
    p_range = [0.1, 0.5]
configs = list(itertools.product(batch_size_range, vocab_size_range, p_range))
@@ -114,15 +127,19 @@ def benchmark_sampling(batch_size, vocab_size, p, provider):
            filter_apply_order="joint",
        )

    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
# Correctness check
for cfg in configs:
# Correctness check - simplified for CI
if IS_CI:
# Only test one configuration in CI
test_configs = [configs[0]] if configs else [(16, 111, 0.1)]
else:
test_configs = configs
for cfg in test_configs:
calculate_diff(*cfg)
print("\n" + "=" * 60)
......