Unverified Commit 11965b0d authored by Xiaoyu Zhang, committed by GitHub

Fix sgl-kernel benchmark dead code (#11022)

parent 71959545
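
The same two guards recur in every benchmark touched by this change: an IS_CI flag derived from the CI / GITHUB_ACTIONS environment variables, and an optional vLLM import with a VLLM_AVAILABLE flag so vLLM-backed providers can be skipped. A condensed sketch of that shared pattern, using the names that appear in the diff below:

import os

# Detect a CI run so benchmarks can shrink their parameter sweeps.
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)

# vLLM is optional; providers that need it are skipped when it is missing.
try:
    from vllm import _custom_ops as vllm_ops
    VLLM_AVAILABLE = True
except ImportError:
    vllm_ops = None
    VLLM_AVAILABLE = False

Setting CI=true in the environment reproduces the reduced parameter sweeps locally, outside of GitHub Actions.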
......@@ -155,6 +155,50 @@ jobs:
cd test/srt
python3 test_mla_deepseek_v3.py
sgl-kernel-benchmark-test:
needs: [check-changes, sgl-kernel-build-wheels]
if: always() && !failure() && !cancelled()
runs-on: 1-gpu-runner
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
CI: true
steps:
- uses: actions/checkout@v4
- name: Cleanup
run: |
ls -alh sgl-kernel/dist || true
rm -rf sgl-kernel/dist/* || true
- name: Download artifacts
uses: actions/download-artifact@v4
with:
path: sgl-kernel/dist/
merge-multiple: true
pattern: wheel-python3.10-cuda12.9
- name: Install dependencies
run: |
CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/ci_install_dependency.sh
- name: Run benchmark tests
timeout-minutes: 45
run: |
cd sgl-kernel/benchmark
echo "Running sgl-kernel benchmark tests in CI mode..."
echo "CI environment variable: $CI"
echo "GITHUB_ACTIONS environment variable: $GITHUB_ACTIONS"
for bench_file in bench_*.py; do
echo "Testing $bench_file..."
timeout 60 python3 "$bench_file" || echo "Warning: $bench_file timed out or failed, continuing..."
echo "Completed $bench_file"
echo "---"
done
echo "All benchmark tests completed!"
# =============================================== primary ====================================================
unit-test-frontend:
......@@ -647,7 +691,7 @@ jobs:
check-changes,
sgl-kernel-build-wheels,
sgl-kernel-unit-test, sgl-kernel-mla-test,
sgl-kernel-unit-test, sgl-kernel-mla-test, sgl-kernel-benchmark-test,
unit-test-frontend, unit-test-backend-1-gpu,
unit-test-backend-2-gpu, unit-test-backend-4-gpu, unit-test-backend-8-gpu,
......
......@@ -2460,7 +2460,7 @@ class BumpAllocator:
def log_info_on_rank0(logger, msg):
from sglang.srt.distributed import get_tensor_model_parallel_rank
if get_tensor_model_parallel_rank() == 0:
if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0:
logger.info(msg)
......
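
The utility change above adds a torch.distributed.is_initialized() check before querying the tensor-parallel rank, so the helper silently skips logging in single-process runs (such as standalone benchmark invocations) where no process group has been initialized and the rank query would not be usable. A minimal sketch of the guarded helper in isolation:

import logging
import torch

def log_info_on_rank0(logger: logging.Logger, msg: str) -> None:
    from sglang.srt.distributed import get_tensor_model_parallel_rank

    # Only consult the TP rank when a process group actually exists;
    # without one, fall through and log nothing.
    if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0:
        logger.info(msg)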
......@@ -2,6 +2,7 @@
# (kernel, dtype, batch_size, seq_len, dim) and prints speed-up.
import argparse
import itertools
import os
import re
from typing import List, Tuple
......@@ -11,7 +12,21 @@ import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from vllm import _custom_ops as vllm_ops
# Optional vLLM import
try:
from vllm import _custom_ops as vllm_ops
VLLM_AVAILABLE = True
except ImportError:
vllm_ops = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
# gelu_quick is only available on HIP/ROCm platforms
try:
......@@ -22,7 +37,7 @@ except ImportError:
GELU_QUICK_AVAILABLE = False
gelu_quick = None
if not hasattr(vllm_ops, "silu_and_mul"):
if VLLM_AVAILABLE and not hasattr(vllm_ops, "silu_and_mul"):
vllm_ops = torch.ops._C
......@@ -40,6 +55,13 @@ def calculate_diff(
"""Compare vLLM with SGLang for one shape."""
device = torch.device("cuda")
if not VLLM_AVAILABLE:
print(
f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
f"L={seq_len:3d} | D={dim:5d}] ⚠️ vLLM not available, skipping comparison"
)
return True
# activation-only quick GELU
if kernel == "gelu_quick":
if not GELU_QUICK_AVAILABLE:
......@@ -68,19 +90,30 @@ def calculate_diff(
return ok
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
if GELU_QUICK_AVAILABLE:
kernels.append("gelu_quick")
dtypes = [torch.float16, torch.bfloat16]
# CI environment uses simplified parameters for kernels and dtypes too
if IS_CI:
kernels = ["silu_and_mul"] # Only test one kernel in CI
dtypes = [torch.float16] # Only test one dtype in CI
else:
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
if GELU_QUICK_AVAILABLE:
kernels.append("gelu_quick")
dtypes = [torch.float16, torch.bfloat16]
def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
return list(itertools.product(kernels, dtypes, bsizes, slens, dims_))
default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16
default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64
default_dims = [2**i for i in range(10, 15)] # 1024...16384
# CI environment uses simplified parameters
if IS_CI:
default_batch_sizes = [1] # Single batch size for CI
default_seq_lens = [1] # Single sequence length for CI
default_dims = [1024] # Single dimension for CI
else:
default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16
default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64
default_dims = [2**i for i in range(10, 15)] # 1024...16384
@triton.testing.perf_report(
......@@ -102,16 +135,24 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device)
y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
vllm_kernel = getattr(vllm_ops, kernel)
if not VLLM_AVAILABLE and provider in ["vllm", "speedup"]:
# Skip vLLM-related benchmarks if vLLM is not available
return (0, 0, 0)
if VLLM_AVAILABLE:
vllm_kernel = getattr(vllm_ops, kernel)
if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE:
# Skip benchmark for gelu_quick if not available
return (0, 0, 0)
sglang_kernel = getattr(sgl_kernel, kernel)
def baseline():
tmp = y0.clone()
vllm_kernel(tmp, x)
return tmp
if VLLM_AVAILABLE:
tmp = y0.clone()
vllm_kernel(tmp, x)
return tmp
else:
return torch.zeros_like(y0)
def sglang():
return sglang_kernel(x)
......@@ -134,7 +175,7 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
# provider == "speedup"
t_ref, _, _ = timed(baseline)
t_sgl, _, _ = timed(sglang)
spd = t_ref / t_sgl
spd = t_ref / t_sgl if t_ref > 0 else 1.0
return (spd, spd, spd)
......
import itertools
import os
from typing import List, Tuple
import torch
import triton
import triton.testing
from sgl_kernel import awq_dequantize
from vllm import _custom_ops as ops
# Optional vLLM import
try:
from vllm import _custom_ops as ops
VLLM_AVAILABLE = True
except ImportError:
ops = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def vllm_awq_dequantize(
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
if not VLLM_AVAILABLE:
# Fallback to SGLang implementation
return sglang_awq_dequantize(qweight, scales, qzeros)
return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
......@@ -43,6 +61,10 @@ def calculate_diff(qweight_row: int, qweight_col: int):
device=device,
)
if not VLLM_AVAILABLE:
print("⚠️ vLLM not available, skipping comparison")
return
vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
......@@ -56,8 +78,13 @@ def calculate_diff(qweight_row: int, qweight_col: int):
print("❌ Implementations differ")
qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]
# CI environment uses simplified parameters
if IS_CI:
qweight_row_range = [128] # Single row size for CI
qweight_cols_range = [16] # Single column size for CI
else:
qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]
configs = list(itertools.product(qweight_row_range, qweight_cols_range))
......@@ -67,9 +94,9 @@ configs = list(itertools.product(qweight_row_range, qweight_cols_range))
x_names=["qweight_row", "qweight_col"],
x_vals=configs,
line_arg="provider",
line_vals=["vllm", "sglang"],
line_names=["VLLM", "SGL Kernel"],
styles=[("blue", "-"), ("green", "-")],
line_vals=["vllm", "sglang"] if VLLM_AVAILABLE else ["sglang"],
line_names=["VLLM", "SGL Kernel"] if VLLM_AVAILABLE else ["SGL Kernel"],
styles=[("blue", "-"), ("green", "-")] if VLLM_AVAILABLE else [("green", "-")],
ylabel="us",
plot_name="awq-dequantize-performance",
args={},
......@@ -100,6 +127,8 @@ def benchmark(qweight_row, qweight_col, provider):
quantiles = [0.5, 0.2, 0.8]
if provider == "vllm":
if not VLLM_AVAILABLE:
return (0, 0, 0)
fn = lambda: vllm_awq_dequantize(
qweight.clone(), scales.clone(), qzeros.clone()
)
......@@ -114,5 +143,11 @@ def benchmark(qweight_row, qweight_col, provider):
if __name__ == "__main__":
calculate_diff(qweight_row=3584, qweight_col=448)
# Simplify for CI environment
if IS_CI:
qweight_row, qweight_col = 128, 16 # Smaller values for CI
else:
qweight_row, qweight_col = 3584, 448
calculate_diff(qweight_row=qweight_row, qweight_col=qweight_col)
benchmark.run(print_data=True)
import argparse
import copy
import itertools
import os
import torch
import triton
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
bs_range = [1, 8, 32, 64, 128, 256]
qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
from sglang.srt.utils import get_device_capability
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
# CI environment uses simplified parameters
if IS_CI:
bs_range = [1] # Single batch size for CI
qlen_range = [64] # Single sequence length for CI
else:
bs_range = [1, 8, 32, 64, 128, 256]
qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
configs = list(itertools.product(bs_range, qlen_range))
......@@ -131,13 +145,34 @@ if __name__ == "__main__":
)
args = parser.parse_args()
for block_size in args.block_sizes:
for kv_split in args.num_kv_splits:
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
benchmark.run(
print_data=True,
block_size=block_size,
num_kv_splits=kv_split,
)
print("Benchmark finished!")
# Skip in CI environment or unsupported architectures
if IS_CI:
major, minor = get_device_capability()
if major is None or major < 10: # Requires compute capability 10.0+
print("Skipping Cutlass MLA benchmark in CI environment")
if major is not None:
print(
f"Cutlass MLA requires compute capability 10.0+, but found {major}.{minor}"
)
else:
print("Could not determine device capability")
else:
for block_size in args.block_sizes:
for kv_split in args.num_kv_splits:
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
benchmark.run(
print_data=True,
block_size=block_size,
num_kv_splits=kv_split,
)
print("Benchmark finished!")
else:
for block_size in args.block_sizes:
for kv_split in args.num_kv_splits:
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
benchmark.run(
print_data=True,
block_size=block_size,
num_kv_splits=kv_split,
)
print("Benchmark finished!")
import argparse
import os
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
import torch
import torch.nn.functional as F
......@@ -6,16 +13,28 @@ import triton
import triton.testing
from sgl_kernel import dsv3_fused_a_gemm
# CI environment uses simplified parameters
if IS_CI:
num_tokens_vals = [1] # Only test 1 value in CI
line_vals = ["sgl-kernel"] # Only test sgl-kernel implementation in CI
else:
num_tokens_vals = [i + 1 for i in range(16)] # Test 1-16 in full mode
line_vals = ["torch", "sgl-kernel"]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=[i + 1 for i in range(16)],
x_vals=num_tokens_vals,
x_log=False,
line_arg="impl",
line_vals=["torch", "sgl-kernel"],
line_names=["torch (bf16)", "dsv3_fused_a_gemm"],
styles=[("blue", "-"), ("orange", "-")],
line_vals=line_vals,
line_names=(
["torch (bf16)", "dsv3_fused_a_gemm"]
if not IS_CI
else ["dsv3_fused_a_gemm"]
),
styles=[("blue", "-"), ("orange", "-")] if not IS_CI else [("orange", "-")],
ylabel="TFLOPs",
plot_name="bf16 dsv3 fused a GEMM throughput",
args={},
......
import argparse
import os
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
import torch
import torch.nn.functional as F
......@@ -6,21 +13,37 @@ import triton
import triton.testing
from sgl_kernel import dsv3_router_gemm
# CI environment uses simplified parameters
if IS_CI:
num_tokens_vals = [1] # Only test 1 value in CI
line_vals = ["sgl-kernel-256"] # Only test one implementation in CI
else:
num_tokens_vals = [i + 1 for i in range(16)] # Test 1-16 in full mode
line_vals = ["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=[i + 1 for i in range(16)],
x_vals=num_tokens_vals,
x_log=False,
line_arg="impl",
line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
line_names=[
"torch-256",
"dsv3_router_gemm-256",
"torch-384",
"dsv3_router_gemm-384",
],
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
line_vals=line_vals,
line_names=(
[
"torch-256",
"dsv3_router_gemm-256",
"torch-384",
"dsv3_router_gemm-384",
]
if not IS_CI
else ["dsv3_router_gemm-256"]
),
styles=(
[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")]
if not IS_CI
else [("orange", "-")]
),
ylabel="TFLOPs",
plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
args={},
......@@ -64,17 +87,25 @@ def benchmark_bf16_output(num_tokens, impl):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=[i + 1 for i in range(16)],
x_vals=num_tokens_vals,
x_log=False,
line_arg="impl",
line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"],
line_names=[
"torch-256",
"dsv3_router_gemm-256",
"torch-384",
"dsv3_router_gemm-384",
],
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")],
line_vals=line_vals,
line_names=(
[
"torch-256",
"dsv3_router_gemm-256",
"torch-384",
"dsv3_router_gemm-384",
]
if not IS_CI
else ["dsv3_router_gemm-256"]
),
styles=(
[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")]
if not IS_CI
else [("orange", "-")]
),
ylabel="TFLOPs",
plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
args={},
......
......@@ -2,6 +2,7 @@ import argparse
import copy
import csv
import itertools
import os
import pytest
import torch
......@@ -9,6 +10,14 @@ import triton
from flashinfer import mm_fp4
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
from sglang.srt.utils import get_device_capability
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
......@@ -33,27 +42,34 @@ def get_weight_shapes(args):
]
# CI environment uses simplified parameters
if IS_CI:
batch_sizes = [1, 8] # Simplified for CI
else:
batch_sizes = [
1,
2,
4,
8,
16,
32,
64,
128,
256,
512,
1024,
2048,
3072,
4096,
8192,
16384,
]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[
1,
2,
4,
8,
16,
32,
64,
128,
256,
512,
1024,
2048,
3072,
4096,
8192,
16384,
],
x_vals=batch_sizes,
# x_vals = [64],
x_log=False,
line_arg="provider",
......@@ -188,21 +204,38 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# Simplify for CI environment
if IS_CI:
args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size
if args.csv:
with open(args.csv, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["provider", "m", "n", "k", "time_ms"])
NKs = get_weight_shapes(args)
for N, K in NKs:
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
benchmark.run(
print_data=True,
N=N,
K=K,
dtype=args.dtype,
correctness=args.correctness,
csv_file=args.csv,
)
print("Benchmark finished!")
# Check architecture compatibility - FP4 operations require sm100a/sm103a
major, minor = get_device_capability()
if major is None or major < 10: # Requires compute capability 10.0+ (sm100a/sm103a)
print("Skipping FP4 GEMM benchmark")
if major is not None:
print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}")
else:
print("Could not determine device capability")
else:
NKs = get_weight_shapes(args)
# Limit iterations in CI
if IS_CI:
NKs = NKs[:2] # Only test first 2 shapes in CI
for N, K in NKs:
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
benchmark.run(
print_data=True,
N=N,
K=K,
dtype=args.dtype,
correctness=args.correctness,
csv_file=args.csv,
)
print("Benchmark finished!")
import argparse
import copy
import itertools
import os
import deep_gemm
import torch
import triton
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
from sgl_kernel import fp8_blockwise_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
# Optional vLLM import
try:
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
VLLM_AVAILABLE = True
except ImportError:
vllm_scaled_mm = None
VLLM_AVAILABLE = False
from sglang.srt.layers.quantization.fp8_kernel import (
w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul,
)
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def get_weight_shapes(args):
models_tps = list(itertools.product(args.models, args.tp_sizes))
......@@ -80,15 +95,46 @@ def scale_shape(shape, group_shape):
return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
# CI environment uses simplified parameters
if IS_CI:
batch_sizes = [1, 8] # Simplified for CI
else:
batch_sizes = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
# Filter providers based on availability
available_providers = ["sgl-kernel"]
available_names = ["sgl-kernel"]
available_styles = [("orange", "-")]
if VLLM_AVAILABLE:
available_providers.insert(0, "vllm")
available_names.insert(0, "vllm")
available_styles.insert(0, ("blue", "-"))
available_providers.append("triton")
available_names.append("sglang triton")
available_styles.append(("red", "-"))
# Add deepgemm if available
try:
import deep_gemm
available_providers.append("deepgemm")
available_names.append("deepgemm")
available_styles.append(("yellow", "-"))
except ImportError:
pass
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
x_vals=batch_sizes,
x_log=False,
line_arg="provider",
line_vals=["vllm", "sgl-kernel", "triton", "deepgemm"],
line_names=["vllm", "sgl-kernel", "sglang triton", "deepgemm"],
styles=[("blue", "-"), ("orange", "-"), ("red", "-"), ("yellow", "-")],
line_vals=available_providers,
line_names=available_names,
styles=available_styles,
ylabel="GB/s",
plot_name="fp8 blockwise scaled matmul",
args={},
......@@ -123,14 +169,16 @@ def benchmark(batch_size, provider, N, K):
),
quantiles=quantiles,
)
if provider == "vllm":
elif provider == "vllm":
if not VLLM_AVAILABLE:
return (0, 0, 0)
scale_a = scale_a.t().contiguous().t()
b_fp8, scale_b = b_fp8.t(), scale_b.t()
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
quantiles=quantiles,
)
if provider == "triton":
elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: w8a8_block_fp8_matmul(
a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
......@@ -166,7 +214,17 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# Simplify for CI environment
if IS_CI:
args.models = [args.models[0]] # Use only first model
args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size
NK_model_names = get_weight_shapes(args)
# Limit iterations in CI
if IS_CI:
NK_model_names = NK_model_names[:2] # Only test first 2 shapes in CI
for N, K, model_name in NK_model_names:
if N % 128 != 0 or K % 128 != 0:
print(f"Skip {N=}, {K=} now")
......
import argparse
import os
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
import random
from dataclasses import dataclass
from typing import List, Tuple
......@@ -290,36 +297,44 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("--num-warmup", type=int, default=3)
parser.add_argument("--num-run", type=int, default=10)
shape_args = [
# Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8
ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256),
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 8
ShapeArg(expected_m_per_group=256, n=512, k=7168, num_groups=256),
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 16
ShapeArg(expected_m_per_group=256, n=256, k=7168, num_groups=256),
# Prefill, DeepSeek-R1, gateup, chunk_size = 16384, TP = 16
ShapeArg(expected_m_per_group=512, n=256, k=7168, num_groups=256),
# Decode, DeepSeek-R1, gateup, bs = 32, TP = 8
ShapeArg(expected_m_per_group=1, n=512, k=7168, num_groups=256),
# Decode, DeepSeek-R1, gateup, bs = 64, TP = 16
ShapeArg(expected_m_per_group=2, n=256, k=7168, num_groups=256),
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, EP = 8
ShapeArg(expected_m_per_group=256, n=4096, k=7168, num_groups=32),
# Prefill, DeepSeek-R1, gateup, chunk_size = 16384, EP = 16
ShapeArg(expected_m_per_group=512, n=4096, k=7168, num_groups=16),
# Decode, DeepSeek-R1, gateup, bs = 128, EP = 8
ShapeArg(expected_m_per_group=4, n=4096, k=7168, num_groups=32),
# Decode, DeepSeek-R1, gateup, bs = 256, EP = 16
ShapeArg(expected_m_per_group=8, n=4096, k=7168, num_groups=16),
# Prefill, Qwen3-235B-A22B-FP8, gateup, chunk_size = 16384, TP = 4
ShapeArg(expected_m_per_group=1024, n=768, k=4096, num_groups=128),
# Prefill, Qwen3-235B-A22B-FP8, down, chunk_size = 16384, TP = 4
ShapeArg(expected_m_per_group=1024, n=4096, k=384, num_groups=128),
# Decode, Qwen3-235B-A22B-FP8, gateup, bs = 256, TP = 4
ShapeArg(expected_m_per_group=16, n=768, k=4096, num_groups=128),
# Decode, Qwen3-235B-A22B-FP8, down, bs = 256, TP = 4
ShapeArg(expected_m_per_group=16, n=4096, k=384, num_groups=128),
]
# CI environment uses simplified parameters
if IS_CI:
shape_args = [
# Only test one simple shape in CI
ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256),
]
else:
shape_args = [
# Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8
ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256),
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 8
ShapeArg(expected_m_per_group=256, n=512, k=7168, num_groups=256),
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, TP = 16
ShapeArg(expected_m_per_group=256, n=256, k=7168, num_groups=256),
# Prefill, DeepSeek-R1, gateup, chunk_size = 16384, TP = 16
ShapeArg(expected_m_per_group=512, n=256, k=7168, num_groups=256),
# Decode, DeepSeek-R1, gateup, bs = 32, TP = 8
ShapeArg(expected_m_per_group=1, n=512, k=7168, num_groups=256),
# Decode, DeepSeek-R1, gateup, bs = 64, TP = 16
ShapeArg(expected_m_per_group=2, n=256, k=7168, num_groups=256),
# Prefill, DeepSeek-R1, gateup, chunk_size = 8192, EP = 8
ShapeArg(expected_m_per_group=256, n=4096, k=7168, num_groups=32),
# Prefill, DeepSeek-R1, gateup, chunk_size = 16384, EP = 16
ShapeArg(expected_m_per_group=512, n=4096, k=7168, num_groups=16),
# Decode, DeepSeek-R1, gateup, bs = 128, EP = 8
ShapeArg(expected_m_per_group=4, n=4096, k=7168, num_groups=32),
# Decode, DeepSeek-R1, gateup, bs = 256, EP = 16
ShapeArg(expected_m_per_group=8, n=4096, k=7168, num_groups=16),
# Prefill, Qwen3-235B-A22B-FP8, gateup, chunk_size = 16384, TP = 4
ShapeArg(expected_m_per_group=1024, n=768, k=4096, num_groups=128),
# Prefill, Qwen3-235B-A22B-FP8, down, chunk_size = 16384, TP = 4
ShapeArg(expected_m_per_group=1024, n=4096, k=384, num_groups=128),
# Decode, Qwen3-235B-A22B-FP8, gateup, bs = 256, TP = 4
ShapeArg(expected_m_per_group=16, n=768, k=4096, num_groups=128),
# Decode, Qwen3-235B-A22B-FP8, down, bs = 256, TP = 4
ShapeArg(expected_m_per_group=16, n=4096, k=384, num_groups=128),
]
args = parser.parse_args()
benchmark_one_shape(shape_args, args.num_warmup, args.num_run)
......
import argparse
import copy
import itertools
import os
from typing import Optional, Tuple
import torch
import triton
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from sgl_kernel import sgl_per_tensor_quant_fp8
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
# Optional vLLM import
try:
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
VLLM_AVAILABLE = True
except ImportError:
vllm_scaled_mm = None
vllm_scaled_fp8_quant = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
......@@ -86,25 +102,48 @@ def sglang_scaled_fp8_quant(
return output, scale
# CI environment uses simplified parameters
if IS_CI:
batch_sizes = [1] # Single batch size for CI
else:
batch_sizes = [1, 16, 64, 128, 256, 512, 1024, 2048]
# Filter line_vals based on vLLM availability
if VLLM_AVAILABLE:
line_vals = [
"vllm-fp8-fp16",
"vllm-fp8-bf16",
"sglang-fp8-fp16",
"sglang-fp8-bf16",
]
line_names = [
"vllm-fp8-fp16",
"vllm-fp8-bf16",
"sglang-fp8-fp16",
"sglang-fp8-bf16",
]
styles = [("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")]
else:
line_vals = [
"sglang-fp8-fp16",
"sglang-fp8-bf16",
]
line_names = [
"sglang-fp8-fp16",
"sglang-fp8-bf16",
]
styles = [("blue", "-"), ("blue", "--")]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
x_vals=batch_sizes,
x_log=False,
line_arg="provider",
line_vals=[
"vllm-fp8-fp16",
"vllm-fp8-bf16",
"sglang-fp8-fp16",
"sglang-fp8-bf16",
],
line_names=[
"vllm-fp8-fp16",
"vllm-fp8-bf16",
"sglang-fp8-fp16",
"sglang-fp8-bf16",
],
styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
line_vals=line_vals,
line_names=line_names,
styles=styles,
ylabel="GB/s",
plot_name="fp8 scaled matmul",
args={},
......@@ -115,6 +154,9 @@ def benchmark(batch_size, provider, N, K):
M = batch_size
a = torch.ones((M, K), device="cuda") * 5.0
b = torch.ones((N, K), device="cuda") * 5.0
# vLLM expects scalar scales, while sglang can handle per-token scales
scale_a_scalar = torch.randn(1, device="cuda", dtype=torch.float32)
scale_b_scalar = torch.randn(1, device="cuda", dtype=torch.float32)
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
quantiles = [0.5, 0.2, 0.8]
......@@ -122,8 +164,11 @@ def benchmark(batch_size, provider, N, K):
dtype = torch.float16 if "fp16" in provider else torch.bfloat16
if "vllm-fp8" in provider:
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
if not VLLM_AVAILABLE:
# Return zero if vLLM is not available
return (0, 0, 0)
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_scalar)
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b_scalar)
b_fp8 = b_fp8.t()
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
......@@ -174,6 +219,11 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# Simplify for CI environment
if IS_CI:
args.models = [args.models[0]] # Use only first model
args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
......
import argparse
import copy
import itertools
import os
import torch
import triton
from sgl_kernel import int8_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
# Optional vLLM import
try:
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
VLLM_AVAILABLE = True
except ImportError:
vllm_scaled_mm = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
......@@ -62,15 +77,32 @@ WEIGHT_SHAPES = {
}
# CI environment uses simplified parameters
if IS_CI:
batch_sizes = [1] # Single batch size for CI
else:
batch_sizes = [1, 16, 32, 64, 128, 256, 512, 1024, 2048]
# Filter providers based on vLLM availability
if VLLM_AVAILABLE:
line_vals = ["vllm", "sgl-kernel"]
line_names = ["vllm int8 gemm", "sgl-kernel int8 gemm"]
styles = [("blue", "-"), ("orange", "-")]
else:
line_vals = ["sgl-kernel"]
line_names = ["sgl-kernel int8 gemm"]
styles = [("orange", "-")]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
x_vals=batch_sizes,
x_log=False,
line_arg="provider",
line_vals=["vllm", "sgl-kernel"],
line_names=["vllm int8 gemm", "sgl-kernel int8 gemm"],
styles=[("blue", "-"), ("orange", "-")],
line_vals=line_vals,
line_names=line_names,
styles=styles,
ylabel="GB/s",
plot_name="int8 scaled matmul",
args={},
......@@ -90,7 +122,9 @@ def benchmark(batch_size, provider, N, K):
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
quantiles=quantiles,
)
if provider == "vllm":
elif provider == "vllm":
if not VLLM_AVAILABLE:
return (0, 0, 0)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
quantiles=quantiles,
......@@ -136,9 +170,16 @@ if __name__ == "__main__":
)
args = parser.parse_args()
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(print_data=True, N=N, K=K)
# Skip in CI environment due to architecture compatibility issues
if IS_CI:
print(
"Skipping INT8 GEMM benchmark in CI environment due to architecture compatibility issues"
)
print("INT8 operations may not be supported on all GPU architectures")
else:
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(print_data=True, N=N, K=K)
print("Benchmark finished!")
print("Benchmark finished!")
import itertools
import math
import os
import torch
import triton
import triton.language as tl
from sgl_kernel import lightning_attention_decode
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def next_power_of_2(n):
return 2 ** (int(math.ceil(math.log(n, 2))))
......@@ -207,7 +214,12 @@ def calculate_diff(batch_size):
print("❌ Implementations differ")
batch_size_range = [i for i in range(1, 65)] # 1 to 128
# Simplified for CI environment
if IS_CI:
batch_size_range = [1] # Single batch size for CI
else:
batch_size_range = [i for i in range(1, 65)] # 1 to 64
configs = [(bs,) for bs in batch_size_range]
......@@ -292,8 +304,9 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# Run correctness test
calculate_diff(batch_size=4)
# Run correctness test - simplified for CI
test_batch_size = 1 if IS_CI else 4
calculate_diff(batch_size=test_batch_size)
# Run performance benchmark
benchmark.run(print_data=True)
import argparse
import itertools
import os
import torch
import triton
......@@ -8,8 +9,17 @@ from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
try:
from vllm import _custom_ops as ops
VLLM_AVAILABLE = True
except ImportError:
ops = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
USE_RANDOM_PERM = False
......@@ -197,19 +207,23 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
num_tokens_post_pad_triton,
)
try:
ops.moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids_vllm,
expert_ids_vllm,
num_tokens_post_pad_vllm,
)
print(f"✅ VLLM implementation works with {num_experts} experts!")
vllm_works = True
except Exception as e:
print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
if VLLM_AVAILABLE:
try:
ops.moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids_vllm,
expert_ids_vllm,
num_tokens_post_pad_vllm,
)
print(f"✅ VLLM implementation works with {num_experts} experts!")
vllm_works = True
except Exception as e:
print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
vllm_works = False
else:
print("⚠️ vLLM not available, skipping vLLM test")
vllm_works = False
if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
......@@ -394,8 +408,18 @@ if __name__ == "__main__":
)
args = parser.parse_args()
calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
# Simplify for CI environment
if IS_CI:
num_tokens = 256 # Smaller for CI
num_experts = 8 # Smaller for CI
topk = 2 # Smaller for CI
else:
num_tokens = 1024
num_experts = args.num_experts
topk = args.topk
calculate_diff(num_tokens=num_tokens, num_experts=num_experts, topk=topk)
if not args.skip_full_benchmark:
if not args.skip_full_benchmark and not IS_CI: # Skip full benchmark in CI
print(f"\n📊 Running performance benchmark for {args.num_experts} experts...")
benchmark.run(print_data=True)
import os
import torch
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
import triton
from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
# CI environment uses simplified parameters
if IS_CI:
batch_sizes = [64, 128] # Only test 2 values in CI
else:
batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
configs = [(bs,) for bs in batch_sizes]
......
import itertools
import math
import os
import torch
import triton
......@@ -8,6 +9,12 @@ from sgl_kernel import moe_fused_gate
from sglang.srt.layers.moe.topk import biased_grouped_topk
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk):
return biased_grouped_topk(
......@@ -28,7 +35,12 @@ def biased_grouped_topk_org_fuse_kernel(
return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk)
seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]
# CI environment uses simplified parameters
if IS_CI:
seq_length_range = [5000] # Only test one sequence length in CI
else:
seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]
configs = [(sq,) for sq in seq_length_range]
......
import itertools
import os
import pytest
import torch
import triton
from sgl_kernel import topk_softmax
from vllm import _custom_ops as vllm_custom_ops
# Optional vLLM import
try:
from vllm import _custom_ops as vllm_custom_ops
VLLM_AVAILABLE = True
except ImportError:
vllm_custom_ops = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def vllm_topk_softmax(gating_output, topk):
if not VLLM_AVAILABLE:
# Fallback to SGLang implementation if vLLM is not available
return sglang_topk_softmax(gating_output, topk)
num_tokens, num_experts = gating_output.shape
topk_weights = torch.empty(
......@@ -54,6 +73,10 @@ def calculate_diff(num_tokens, num_experts, topk):
weights_diff = torch.abs(weights_vllm - weights_sglang).mean().item()
indices_match = torch.equal(indices_vllm, indices_sglang)
if not VLLM_AVAILABLE:
print("⚠️ vLLM not available, skipping comparison")
return
if (
torch.allclose(weights_vllm, weights_sglang, atol=1e-3, rtol=1e-3)
and indices_match
......@@ -65,21 +88,38 @@ def calculate_diff(num_tokens, num_experts, topk):
)
num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
num_experts_range = [32, 64, 128, 256, 12, 512]
topk_range = [1, 2, 4, 8]
# CI environment uses simplified parameters
if IS_CI:
num_tokens_range = [128] # Single value for CI
num_experts_range = [32] # Single value for CI
topk_range = [2] # Single value for CI
else:
num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
num_experts_range = [32, 64, 128, 256, 12, 512]
topk_range = [1, 2, 4, 8]
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
# Filter providers based on vLLM availability
if VLLM_AVAILABLE:
line_vals = ["sglang", "vllm"]
line_names = ["SGLang", "VLLM"]
styles = [("blue", "-"), ("green", "-")]
else:
line_vals = ["sglang"]
line_names = ["SGLang"]
styles = [("blue", "-")]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens", "num_experts", "topk"],
x_vals=configs,
line_arg="provider",
line_vals=["sglang", "vllm"],
line_names=["SGLang", "VLLM"],
styles=[("blue", "-"), ("green", "-")],
line_vals=line_vals,
line_names=line_names,
styles=styles,
ylabel="Latency (us)",
plot_name="topk-softmax-performance",
args={},
......@@ -92,6 +132,8 @@ def benchmark(num_tokens, num_experts, topk, provider):
)
if provider == "vllm" or provider == "vllm1":
if not VLLM_AVAILABLE:
return (0, 0, 0)
fn = lambda: vllm_topk_softmax(gating_output, topk)
elif provider == "sglang" or provider == "sglang1":
fn = lambda: sglang_topk_softmax(gating_output, topk)
......@@ -103,14 +145,19 @@ def benchmark(num_tokens, num_experts, topk, provider):
if __name__ == "__main__":
configs = [
(20, 256, 4),
(20, 256, 8),
(20, 12, 4),
(20, 12, 1),
(20, 512, 4),
(20, 512, 1),
]
for num_tokens, num_experts, topk in configs:
# Simplify configs for CI environment
if IS_CI:
test_configs = [(20, 32, 2)] # Single config for CI
else:
test_configs = [
(20, 256, 4),
(20, 256, 8),
(20, 12, 4),
(20, 12, 1),
(20, 512, 4),
(20, 512, 1),
]
for num_tokens, num_experts, topk in test_configs:
calculate_diff(num_tokens, num_experts, topk)
benchmark.run(print_data=True)
import argparse
import copy
import itertools
import os
import torch
import triton
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
from sglang.srt.utils import get_device_capability
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
......@@ -162,9 +171,22 @@ if __name__ == "__main__":
)
args = parser.parse_args()
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(print_data=True, N=N, K=K)
# Check architecture compatibility - FP4 operations require sm100a/sm103a
major, minor = get_device_capability()
if major is None or major < 10: # Requires compute capability 10.0+ (sm100a/sm103a)
print("Skipping NVIDIA FP4 scaled GEMM benchmark")
if major is not None:
print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}")
else:
print("Could not determine device capability")
else:
KN_model_names = prepare_shapes(args)
# Limit iterations in CI
if IS_CI:
KN_model_names = KN_model_names[:2] # Only test first 2 shapes in CI
print("Benchmark finished!")
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(print_data=True, N=N, K=K)
print("Benchmark finished!")
import itertools
import math
import os
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
......@@ -7,11 +8,26 @@ import torch
import triton
import triton.testing
from sgl_kernel import sgl_per_tensor_quant_fp8
from vllm import _custom_ops as ops
# Optional imports
try:
from vllm import _custom_ops as ops
VLLM_AVAILABLE = True
except ImportError:
ops = None
VLLM_AVAILABLE = False
from sglang.srt.utils import is_hip
_is_hip = is_hip()
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
......@@ -19,6 +35,9 @@ def vllm_scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if not VLLM_AVAILABLE:
# Fallback to SGLang implementation
return sglang_scaled_fp8_quant(input, scale)
return ops.scaled_fp8_quant(input, scale)
......@@ -42,6 +61,10 @@ def calculate_diff(batch_size: int, seq_len: int):
device = torch.device("cuda")
x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)
if not VLLM_AVAILABLE:
print("⚠️ vLLM not available, skipping comparison")
return
vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)
......@@ -56,8 +79,13 @@ def calculate_diff(batch_size: int, seq_len: int):
print("❌ Implementations differ")
batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048]
# CI environment uses simplified parameters
if IS_CI:
batch_size_range = [16] # Single batch size for CI
seq_len_range = [64] # Single sequence length for CI
else:
batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048]
configs = list(itertools.product(batch_size_range, seq_len_range))
......
import itertools
import os
import time
from functools import partial
from pathlib import Path
......@@ -16,15 +17,28 @@ from sglang.srt.layers.quantization.fp8_kernel import (
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
from sglang.srt.utils import is_hip
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1
group_size_range = [128] # For DeepSeek V3/R1
# TODO test int8
dst_dtype_range = [fp8_type_]
# CI environment uses simplified parameters
if IS_CI:
num_tokens_range = [64] # Single value for CI
hidden_dim_range = [1536] # Single value for CI
group_size_range = [128] # Keep as is
dst_dtype_range = [fp8_type_] # Keep as is
else:
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1
group_size_range = [128] # For DeepSeek V3/R1
# TODO test int8
dst_dtype_range = [fp8_type_]
flags_range = [
dict(
column_major_scales=False,
......@@ -82,7 +96,7 @@ def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
fn, kernel_names = {
"triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"),
"triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_8bit"),
"sglang": (
sglang_per_token_group_quant_8bit,
"per_token_group_quant_8bit_kernel",
......