Unverified Commit 11965b0d authored by Xiaoyu Zhang, committed by GitHub

Fix sgl-kernel benchmark dead code (#11022)

parent 71959545
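
The commit applies the same guard pattern to each benchmark script below: optional third-party imports wrapped in try/except, a CI flag read from environment variables, and reduced parameter sweeps when that flag is set. A minimal sketch of the pattern, assuming only the standard library plus an optional vLLM install (the shrunken range shown here is illustrative, not the exact values used in every file):

import os

# Detect a CI environment from the variables GitHub Actions and most CI systems set.
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)

# Import an optional dependency and record whether it is usable.
try:
    from vllm import _custom_ops as ops

    VLLM_AVAILABLE = True
except ImportError:
    ops = None
    VLLM_AVAILABLE = False

# Shrink the benchmark sweep so a CI run finishes quickly.
batch_size_range = [16] if IS_CI else [16, 32, 64, 128]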
import itertools
import os
from typing import Optional, Tuple

import torch
import triton
import triton.testing
from sgl_kernel import sgl_per_token_quant_fp8

from vllm import _custom_ops as ops

# Optional vLLM import
try:
    from vllm import _custom_ops as ops

    VLLM_AVAILABLE = True
except ImportError:
    ops = None
    VLLM_AVAILABLE = False

from sglang.srt.utils import is_hip

_is_hip = is_hip()

# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)

fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

# Get correct FP8 E4M3 maximum value
...@@ -49,6 +65,9 @@ def torch_per_token_quant_fp8(
def vllm_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if not VLLM_AVAILABLE:
        # Fallback to SGLang implementation
        return sglang_per_token_quant_fp8(input)
    return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
...@@ -74,6 +93,17 @@ def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
    vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
    sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)

    if not VLLM_AVAILABLE:
        print("⚠️ vLLM not available, skipping vLLM comparison")
        # Only compare Torch vs SGLang
        torch_sglang_scale_diff = torch.abs(torch_scale - sglang_scale).mean().item()
        torch_sglang_out_diff = (
            torch.abs(torch_out.float() - sglang_out.float()).mean().item()
        )
        print(f"Scale difference (Torch vs SGLang): {torch_sglang_scale_diff:.8f}")
        print(f"Output difference (Torch vs SGLang): {torch_sglang_out_diff:.8f}")
        return

    print(f"\n=== Comparison for hidden_dim={hidden_dim} ===")

    # Compare scales
...@@ -125,9 +155,15 @@ def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
    print(f" VLLM vs SGLang: {'✅' if vllm_sglang_match else '❌'}")

batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
hidden_dim_range = [1368, 2048, 4096]

# CI environment uses simplified parameters
if IS_CI:
    batch_size_range = [16]  # Single batch size for CI
    seq_len_range = [64]  # Single sequence length for CI
    hidden_dim_range = [2048]  # Single hidden dimension for CI
else:
    batch_size_range = [16, 32, 64, 128]
    seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
    hidden_dim_range = [1368, 2048, 4096]

configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_range))
...@@ -137,9 +173,19 @@ configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_ran
        x_names=["batch_size", "seq_len", "hidden_dim"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "vllm", "sglang"],
        line_names=["Torch Reference", "VLLM", "SGL Kernel"],
        styles=[("red", "-"), ("blue", "-"), ("green", "-")],
        line_vals=(
            ["torch", "vllm", "sglang"] if VLLM_AVAILABLE else ["torch", "sglang"]
        ),
        line_names=(
            ["Torch Reference", "VLLM", "SGL Kernel"]
            if VLLM_AVAILABLE
            else ["Torch Reference", "SGL Kernel"]
        ),
        styles=(
            [("red", "-"), ("blue", "-"), ("green", "-")]
            if VLLM_AVAILABLE
            else [("red", "-"), ("green", "-")]
        ),
        ylabel="us",
        plot_name="per-token-dynamic-quant-fp8-performance",
        args={},
...@@ -156,6 +202,8 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
    if provider == "torch":
        fn = lambda: torch_per_token_quant_fp8(x.clone())
    elif provider == "vllm":
        if not VLLM_AVAILABLE:
            return (0, 0, 0)
        fn = lambda: vllm_per_token_quant_fp8(x.clone())
    elif provider == "sglang":
        fn = lambda: sglang_per_token_quant_fp8(x.clone())
...@@ -166,11 +214,16 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
if __name__ == "__main__":
    # Test various hidden dimensions for correctness
    test_dims = [1368, 2048, 4096]
    # Test various hidden dimensions for correctness - simplified for CI
    if IS_CI:
        test_dims = [2048]  # Single dimension for CI
        batch_size, seq_len = 4, 64  # Smaller values for CI
    else:
        test_dims = [1368, 2048, 4096]
        batch_size, seq_len = 4, 4096

    for dim in test_dims:
        calculate_diff(batch_size=4, seq_len=4096, hidden_dim=dim)
        calculate_diff(batch_size=batch_size, seq_len=seq_len, hidden_dim=dim)

    print("\n" + "=" * 60)
    print("Starting performance benchmark...")
......
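
One detail worth noting in the file above: the three provider lists handed to triton.testing.Benchmark (line_vals, line_names, styles) have to stay index-aligned, which is why each one is wrapped in the same VLLM_AVAILABLE conditional. A hedged alternative sketch that keeps the columns together in a single table (the helper name filter_providers is illustrative, not part of the diff):

def filter_providers(vllm_available: bool):
    # (value, display name, plot style) rows kept side by side so the three
    # lists passed to triton.testing.Benchmark cannot drift out of alignment.
    rows = [
        ("torch", "Torch Reference", ("red", "-")),
        ("vllm", "VLLM", ("blue", "-")),
        ("sglang", "SGL Kernel", ("green", "-")),
    ]
    if not vllm_available:
        rows = [row for row in rows if row[0] != "vllm"]
    line_vals, line_names, styles = (list(col) for col in zip(*rows))
    return line_vals, line_names, styles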
import argparse
import copy
import itertools
import os

import torch
import triton
...@@ -10,6 +11,12 @@ from sgl_kernel import (
    qserve_w4a8_per_group_gemm,
)

# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)

def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)

...@@ -65,10 +72,17 @@ WEIGHT_SHAPES = {
}

# CI environment uses simplified parameters
if IS_CI:
    batch_sizes = [1, 16]  # Simplified for CI
else:
    batch_sizes = [1, 16, 32, 64, 128, 256, 512, 1024, 2048]
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
        x_vals=batch_sizes,
        x_log=False,
        line_arg="provider",
        line_vals=["FP16", "W8A8", "Qserve_W4A8_Per_Channel", "Qserve_W4A8_Per_Group"],
...@@ -184,13 +198,19 @@ if __name__ == "__main__":
    )
    args = parser.parse_args()
    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            N=N,
            K=K,
        )
    print("Benchmark finished!")

    # Skip in CI environment
    if IS_CI:
        print("Skipping QServe W4A8 GEMM benchmark in CI environment")
        print("QServe operations may have compatibility issues in CI")
    else:
        KN_model_names = prepare_shapes(args)

        for K, N, model_name in KN_model_names:
            print(f"{model_name} N={N} K={K}: ")
            benchmark.run(
                print_data=True,
                N=N,
                K=K,
            )

        print("Benchmark finished!")
...@@ -2,6 +2,7 @@
# (batch_size, seq_len, hidden_size) and prints speed-up.
import argparse
import itertools
import os
import re
from typing import List, Optional, Tuple, Union
...@@ -10,9 +11,31 @@ import torch
import torch.nn as nn
import triton
import triton.testing
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from sgl_kernel.utils import is_arch_support_pdl
from vllm import _custom_ops as vllm_ops
# Optional imports
try:
    from flashinfer.norm import fused_add_rmsnorm, rmsnorm

    FLASHINFER_AVAILABLE = True
except ImportError:
    fused_add_rmsnorm = None
    rmsnorm = None
    FLASHINFER_AVAILABLE = False

try:
    from vllm import _custom_ops as vllm_ops

    VLLM_AVAILABLE = True
except ImportError:
    vllm_ops = None
    VLLM_AVAILABLE = False

# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)

def str2int_list(arg: str) -> List[int]:
...@@ -79,6 +102,10 @@ def rmsnorm_flashinfer(
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    if not FLASHINFER_AVAILABLE:
        # Fallback to naive implementation if FlashInfer is not available
        return rmsnorm_naive(x, weight, residual, eps)

    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if residual is not None:
...@@ -103,6 +130,10 @@ def rmsnorm_vllm(
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    if not VLLM_AVAILABLE:
        # Fallback to naive implementation if vLLM is not available
        return rmsnorm_naive(x, weight, residual, eps)

    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if residual is not None:
...@@ -179,37 +210,72 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
        output_sglang = output_sglang[0]

    print(f"Naive output={output_naive}")
    print(f"FlashInfer output={output_flashinfer}")
    print(f"VLLM output={output_vllm}")
    if FLASHINFER_AVAILABLE:
        print(f"FlashInfer output={output_flashinfer}")
    else:
        print("FlashInfer not available, skipped")
    if VLLM_AVAILABLE:
        print(f"VLLM output={output_vllm}")
    else:
        print("vLLM not available, skipped")
    print(f"SGLang output={output_sglang}")

    if (
        torch.allclose(output_naive, output_flashinfer, atol=1e-2, rtol=1e-2)
        and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2)
        and torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
    ):
        print("✅ All implementations match")

    # Only compare available implementations
    all_match = torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
    if FLASHINFER_AVAILABLE:
        all_match = all_match and torch.allclose(
            output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
        )
    if VLLM_AVAILABLE:
        all_match = all_match and torch.allclose(
            output_naive, output_vllm, atol=1e-2, rtol=1e-2
        )
    if all_match:
        print("✅ All available implementations match")
    else:
        print("❌ Implementations differ")

default_batch_sizes = [2**i for i in range(0, 7, 2)]  # 1, 4, 16, 64
default_seq_lens = [2**i for i in range(6, 11, 1)]  # 64, 128, 256, 512, 1024
default_hidden_sizes = [32 * 128, 48 * 128]  # 4096, 6144

# CI environment uses simplified parameters
if IS_CI:
    default_batch_sizes = [1]  # Single batch size for CI
    default_seq_lens = [64]  # Single sequence length for CI
    default_hidden_sizes = [4096]  # Single hidden size for CI
else:
    default_batch_sizes = [2**i for i in range(0, 7, 2)]  # 1, 4, 16, 64
    default_seq_lens = [2**i for i in range(6, 11, 1)]  # 64, 128, 256, 512, 1024
    default_hidden_sizes = [32 * 128, 48 * 128]  # 4096, 6144

def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]:
    return list(itertools.product(bsizes, slens, hsizes))

# Filter providers based on availability
available_providers = ["huggingface", "sglang"]
available_names = ["HuggingFace", "SGL Kernel"]
available_styles = [("blue", "-"), ("orange", "-")]

if FLASHINFER_AVAILABLE:
    available_providers.insert(-1, "flashinfer")
    available_names.insert(-1, "FlashInfer")
    available_styles.insert(-1, ("green", "-"))

if VLLM_AVAILABLE:
    available_providers.insert(-1, "vllm")
    available_names.insert(-1, "vLLM")
    available_styles.insert(-1, ("red", "-"))

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len", "hidden_size"],
        x_vals=[],
        line_arg="provider",
        line_vals=["huggingface", "flashinfer", "vllm", "sglang"],
        line_names=["HuggingFace", "FlashInfer", "vLLM", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("orange", "-")],
        line_vals=available_providers,
        line_names=available_names,
        styles=available_styles,
        ylabel="µs (median) or × (speed-up)",
        plot_name="rmsnorm-performance",
        args={},
...@@ -242,6 +308,8 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
            )
        )
    elif provider == "flashinfer":
        if not FLASHINFER_AVAILABLE:
            return (0, 0, 0)
        return timed(
            lambda: rmsnorm_flashinfer(
                x.clone(),
...@@ -250,6 +318,8 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
            )
        )
    elif provider == "vllm":
        if not VLLM_AVAILABLE:
            return (0, 0, 0)
        return timed(
            lambda: rmsnorm_vllm(
                x.clone(),
...@@ -267,13 +337,22 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
    )

    # provider == "speedup"
    t_ref, _, _ = timed(
        lambda: rmsnorm_vllm(
            x.clone(),
            weight,
            residual.clone() if residual is not None else None,
        )
    )
    if VLLM_AVAILABLE:
        t_ref, _, _ = timed(
            lambda: rmsnorm_vllm(
                x.clone(),
                weight,
                residual.clone() if residual is not None else None,
            )
        )
    else:
        t_ref, _, _ = timed(
            lambda: rmsnorm_naive(
                x.clone(),
                weight,
                residual.clone() if residual is not None else None,
            )
        )

    t_sgl, _, _ = timed(
        lambda: rmsnorm_sglang(
            x.clone(),
...@@ -281,7 +360,7 @@ def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
            residual.clone() if residual is not None else None,
        )
    )
    spd = t_ref / t_sgl
    spd = t_ref / t_sgl if t_ref > 0 else 1.0
    return (spd, spd, spd)
......
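
A note on the provider filtering in the rmsnorm benchmark above: list.insert(-1, x) inserts before the last element, so "sglang" stays in the final column and the optional FlashInfer and vLLM columns land between "huggingface" and "sglang". A small illustration:

providers = ["huggingface", "sglang"]
providers.insert(-1, "flashinfer")
providers.insert(-1, "vllm")
print(providers)  # ['huggingface', 'flashinfer', 'vllm', 'sglang']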
import itertools
import os

import torch
import triton

...@@ -12,17 +13,31 @@ from sgl_kernel.testing.rotary_embedding import (
from sglang.srt.bench_utils import bench_kineto
configs = [
    (batch_size, seq_len, save_kv_cache)
    for batch_size, seq_len in (
        (1, 1),
        (32, 1),
        (128, 1),
        (512, 1),
        (2, 512),
        (4, 4096),
    )
    for save_kv_cache in (False, True)
]

# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)

# CI environment uses simplified parameters
if IS_CI:
    batch_seq_configs = [(1, 1)]  # Single config for CI
    save_kv_configs = [False]  # Single option for CI
else:
    batch_seq_configs = [
        (1, 1),
        (32, 1),
        (128, 1),
        (512, 1),
        (2, 512),
        (4, 4096),
    ]
    save_kv_configs = [False, True]

configs = [
    (batch_size, seq_len, save_kv_cache)
    for batch_size, seq_len in batch_seq_configs
    for save_kv_cache in save_kv_configs
]
......
import itertools
import os

import sgl_kernel
import torch
import triton
import triton.testing

# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)

def torch_top_k_top_p_joint_sampling_from_probs(
    normalized_prob, top_k, top_p, eps=1e-4
...@@ -67,10 +74,16 @@ def calculate_diff(batch_size, vocab_size, p):
    )

# parameter space
batch_size_range = [16, 64, 128]
vocab_size_range = [111, 32000]
p_range = [0.1, 0.5]

# parameter space - simplified for CI
if IS_CI:
    batch_size_range = [16]  # Single batch size for CI
    vocab_size_range = [111]  # Single vocab size for CI
    p_range = [0.1]  # Single p value for CI
else:
    batch_size_range = [16, 64, 128]
    vocab_size_range = [111, 32000]
    p_range = [0.1, 0.5]

configs = list(itertools.product(batch_size_range, vocab_size_range, p_range))
...@@ -114,15 +127,19 @@ def benchmark_sampling(batch_size, vocab_size, p, provider):
            filter_apply_order="joint",
        )

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        fn, quantiles=[0.5, 0.2, 0.8]
    )
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
    # Correctness check
    for cfg in configs:
    # Correctness check - simplified for CI
    if IS_CI:
        # Only test one configuration in CI
        test_configs = [configs[0]] if configs else [(16, 111, 0.1)]
    else:
        test_configs = configs

    for cfg in test_configs:
        calculate_diff(*cfg)

    print("\n" + "=" * 60)
......
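
The sampling benchmark above also swaps triton.testing.do_bench_cudagraph for triton.testing.do_bench, which times the callable without requiring it to be capturable in a CUDA graph. A usage sketch, assuming a CUDA device is available (the matmul workload here is a stand-in for the sampling kernel):

import torch
import triton.testing


def fn():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    return a @ b


# quantiles=[0.5, 0.2, 0.8] returns the median, 20th, and 80th percentile latencies (ms).
ms, low_ms, high_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
print(f"median={ms:.3f} ms (p20={low_ms:.3f}, p80={high_ms:.3f})")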