"src/vscode:/vscode.git/clone" did not exist on "0feb21a18c44cfbf76a916afead986f04b339292"
Unverified Commit 11965b0d authored by Xiaoyu Zhang, committed by GitHub

Fix sgl-kernel benchmark dead code (#11022)

parent 71959545
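The diff below applies one repeated pattern across the `bench_*.py` scripts and the CI workflow: detect a CI run from the `CI` / `GITHUB_ACTIONS` environment variables, shrink the benchmark parameter sweep, and treat the vLLM baseline as optional instead of a hard import. A condensed sketch of that pattern, assembled from the hunks below rather than copied from any single file, might look like this:

```python
import os

# Same CI check the benchmarks use: either variable set to "true".
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)

# The vLLM baseline becomes optional rather than a hard dependency.
try:
    from vllm import _custom_ops as vllm_ops  # noqa: F401
    VLLM_AVAILABLE = True
except ImportError:
    vllm_ops = None
    VLLM_AVAILABLE = False

# CI runs sweep a tiny grid; local runs keep the full one.
batch_sizes = [1] if IS_CI else [1, 16, 64, 128, 256, 512, 1024, 2048]

# Providers that depend on vLLM are dropped from the plot when it is absent.
line_vals = ["vllm", "sgl-kernel"] if VLLM_AVAILABLE else ["sgl-kernel"]

if __name__ == "__main__":
    print(f"IS_CI={IS_CI}, VLLM_AVAILABLE={VLLM_AVAILABLE}")
    print(f"batch_sizes={batch_sizes}, providers={line_vals}")
```

Inside the benchmark bodies, providers that still get asked for a missing vLLM baseline return zero timings instead of raising, so `triton.testing.perf_report` keeps running with the reduced provider list.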
@@ -155,6 +155,50 @@ jobs:
cd test/srt
python3 test_mla_deepseek_v3.py
sgl-kernel-benchmark-test:
needs: [check-changes, sgl-kernel-build-wheels]
if: always() && !failure() && !cancelled()
runs-on: 1-gpu-runner
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
CI: true
steps:
- uses: actions/checkout@v4
- name: Cleanup
run: |
ls -alh sgl-kernel/dist || true
rm -rf sgl-kernel/dist/* || true
- name: Download artifacts
uses: actions/download-artifact@v4
with:
path: sgl-kernel/dist/
merge-multiple: true
pattern: wheel-python3.10-cuda12.9
- name: Install dependencies
run: |
CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/ci_install_dependency.sh
- name: Run benchmark tests
timeout-minutes: 45
run: |
cd sgl-kernel/benchmark
echo "Running sgl-kernel benchmark tests in CI mode..."
echo "CI environment variable: $CI"
echo "GITHUB_ACTIONS environment variable: $GITHUB_ACTIONS"
for bench_file in bench_*.py; do
echo "Testing $bench_file..."
timeout 60 python3 "$bench_file" || echo "Warning: $bench_file timed out or failed, continuing..."
echo "Completed $bench_file"
echo "---"
done
echo "All benchmark tests completed!"
# =============================================== primary ====================================================
unit-test-frontend:
@@ -647,7 +691,7 @@ jobs:
check-changes,
sgl-kernel-build-wheels,
sgl-kernel-unit-test, sgl-kernel-mla-test, sgl-kernel-benchmark-test,
unit-test-frontend, unit-test-backend-1-gpu,
unit-test-backend-2-gpu, unit-test-backend-4-gpu, unit-test-backend-8-gpu,
...
@@ -2460,7 +2460,7 @@ class BumpAllocator:
def log_info_on_rank0(logger, msg):
from sglang.srt.distributed import get_tensor_model_parallel_rank
if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0:
logger.info(msg)
...
@@ -2,6 +2,7 @@
# (kernel, dtype, batch_size, seq_len, dim) and prints speed-up.
import argparse
import itertools
import os
import re
from typing import List, Tuple
@@ -11,7 +12,21 @@ import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
# Optional vLLM import
try:
from vllm import _custom_ops as vllm_ops
VLLM_AVAILABLE = True
except ImportError:
vllm_ops = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
# gelu_quick is only available on HIP/ROCm platforms
try:
@@ -22,7 +37,7 @@ except ImportError:
GELU_QUICK_AVAILABLE = False
gelu_quick = None
if VLLM_AVAILABLE and not hasattr(vllm_ops, "silu_and_mul"):
vllm_ops = torch.ops._C
@@ -40,6 +55,13 @@ def calculate_diff(
"""Compare vLLM with SGLang for one shape."""
device = torch.device("cuda")
if not VLLM_AVAILABLE:
print(
f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
f"L={seq_len:3d} | D={dim:5d}] ⚠️ vLLM not available, skipping comparison"
)
return True
# activation-only quick GELU
if kernel == "gelu_quick":
if not GELU_QUICK_AVAILABLE:
@@ -68,19 +90,30 @@ def calculate_diff(
return ok
# CI environment uses simplified parameters for kernels and dtypes too
if IS_CI:
kernels = ["silu_and_mul"] # Only test one kernel in CI
dtypes = [torch.float16] # Only test one dtype in CI
else:
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
if GELU_QUICK_AVAILABLE:
kernels.append("gelu_quick") kernels.append("gelu_quick")
dtypes = [torch.float16, torch.bfloat16] dtypes = [torch.float16, torch.bfloat16]
def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]: def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
return list(itertools.product(kernels, dtypes, bsizes, slens, dims_)) return list(itertools.product(kernels, dtypes, bsizes, slens, dims_))
default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16 # CI environment uses simplified parameters
default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64 if IS_CI:
default_dims = [2**i for i in range(10, 15)] # 1024...16384 default_batch_sizes = [1] # Single batch size for CI
default_seq_lens = [1] # Single sequence length for CI
default_dims = [1024] # Single dimension for CI
else:
default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16
default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64
default_dims = [2**i for i in range(10, 15)] # 1024...16384
@triton.testing.perf_report(
@@ -102,6 +135,11 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device)
y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
if not VLLM_AVAILABLE and provider in ["vllm", "speedup"]:
# Skip vLLM-related benchmarks if vLLM is not available
return (0, 0, 0)
if VLLM_AVAILABLE:
vllm_kernel = getattr(vllm_ops, kernel)
if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE:
# Skip benchmark for gelu_quick if not available
@@ -109,9 +147,12 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
sglang_kernel = getattr(sgl_kernel, kernel)
def baseline():
if VLLM_AVAILABLE:
tmp = y0.clone()
vllm_kernel(tmp, x)
return tmp
else:
return torch.zeros_like(y0)
def sglang():
return sglang_kernel(x)
@@ -134,7 +175,7 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
# provider == "speedup"
t_ref, _, _ = timed(baseline)
t_sgl, _, _ = timed(sglang)
spd = t_ref / t_sgl if t_ref > 0 else 1.0
return (spd, spd, spd)
...
import itertools
import os
from typing import List, Tuple
import torch
import triton
import triton.testing
from sgl_kernel import awq_dequantize
# Optional vLLM import
try:
from vllm import _custom_ops as ops
VLLM_AVAILABLE = True
except ImportError:
ops = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def vllm_awq_dequantize(
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
if not VLLM_AVAILABLE:
# Fallback to SGLang implementation
return sglang_awq_dequantize(qweight, scales, qzeros)
return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
@@ -43,6 +61,10 @@ def calculate_diff(qweight_row: int, qweight_col: int):
device=device,
)
if not VLLM_AVAILABLE:
print("⚠️ vLLM not available, skipping comparison")
return
vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
@@ -56,8 +78,13 @@ def calculate_diff(qweight_row: int, qweight_col: int):
print("❌ Implementations differ")
# CI environment uses simplified parameters
if IS_CI:
qweight_row_range = [128] # Single row size for CI
qweight_cols_range = [16] # Single column size for CI
else:
qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]
configs = list(itertools.product(qweight_row_range, qweight_cols_range))
@@ -67,9 +94,9 @@ configs = list(itertools.product(qweight_row_range, qweight_cols_range))
x_names=["qweight_row", "qweight_col"],
x_vals=configs,
line_arg="provider",
line_vals=["vllm", "sglang"] if VLLM_AVAILABLE else ["sglang"],
line_names=["VLLM", "SGL Kernel"] if VLLM_AVAILABLE else ["SGL Kernel"],
styles=[("blue", "-"), ("green", "-")] if VLLM_AVAILABLE else [("green", "-")],
ylabel="us",
plot_name="awq-dequantize-performance",
args={},
@@ -100,6 +127,8 @@ def benchmark(qweight_row, qweight_col, provider):
quantiles = [0.5, 0.2, 0.8]
if provider == "vllm":
if not VLLM_AVAILABLE:
return (0, 0, 0)
fn = lambda: vllm_awq_dequantize(
qweight.clone(), scales.clone(), qzeros.clone()
)
@@ -114,5 +143,11 @@ def benchmark(qweight_row, qweight_col, provider):
if __name__ == "__main__":
# Simplify for CI environment
if IS_CI:
qweight_row, qweight_col = 128, 16 # Smaller values for CI
else:
qweight_row, qweight_col = 3584, 448
calculate_diff(qweight_row=qweight_row, qweight_col=qweight_col)
benchmark.run(print_data=True)
import argparse
import copy
import itertools
import os
import torch
import triton
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
from sglang.srt.utils import get_device_capability
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
# CI environment uses simplified parameters
if IS_CI:
bs_range = [1] # Single batch size for CI
qlen_range = [64] # Single sequence length for CI
else:
bs_range = [1, 8, 32, 64, 128, 256]
qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
configs = list(itertools.product(bs_range, qlen_range))
@@ -131,6 +145,28 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# Skip in CI environment or unsupported architectures
if IS_CI:
major, minor = get_device_capability()
if major is None or major < 10: # Requires compute capability 10.0+
print("Skipping Cutlass MLA benchmark in CI environment")
if major is not None:
print(
f"Cutlass MLA requires compute capability 10.0+, but found {major}.{minor}"
)
else:
print("Could not determine device capability")
else:
for block_size in args.block_sizes:
for kv_split in args.num_kv_splits:
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
benchmark.run(
print_data=True,
block_size=block_size,
num_kv_splits=kv_split,
)
print("Benchmark finished!")
else:
for block_size in args.block_sizes:
for kv_split in args.num_kv_splits:
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
@@ -139,5 +175,4 @@ if __name__ == "__main__":
block_size=block_size,
num_kv_splits=kv_split,
)
print("Benchmark finished!")
import argparse
import os
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
import torch
import torch.nn.functional as F
@@ -6,16 +13,28 @@ import triton
import triton.testing
from sgl_kernel import dsv3_fused_a_gemm
# CI environment uses simplified parameters
if IS_CI:
num_tokens_vals = [1] # Only test 1 value in CI
line_vals = ["sgl-kernel"] # Only test sgl-kernel implementation in CI
else:
num_tokens_vals = [i + 1 for i in range(16)] # Test 1-16 in full mode
line_vals = ["torch", "sgl-kernel"]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=num_tokens_vals,
x_log=False,
line_arg="impl",
line_vals=line_vals,
line_names=(
["torch (bf16)", "dsv3_fused_a_gemm"]
if not IS_CI
else ["dsv3_fused_a_gemm"]
),
styles=[("blue", "-"), ("orange", "-")] if not IS_CI else [("orange", "-")],
ylabel="TFLOPs", ylabel="TFLOPs",
plot_name="bf16 dsv3 fused a GEMM throughput", plot_name="bf16 dsv3 fused a GEMM throughput",
args={}, args={},
......
import argparse
import os
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
import torch
import torch.nn.functional as F
@@ -6,21 +13,37 @@ import triton
import triton.testing
from sgl_kernel import dsv3_router_gemm
# CI environment uses simplified parameters
if IS_CI:
num_tokens_vals = [1] # Only test 1 value in CI
line_vals = ["sgl-kernel-256"] # Only test one implementation in CI
else:
num_tokens_vals = [i + 1 for i in range(16)] # Test 1-16 in full mode
line_vals = ["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=num_tokens_vals,
x_log=False,
line_arg="impl",
line_vals=line_vals,
line_names=(
[
"torch-256", "torch-256",
"dsv3_router_gemm-256", "dsv3_router_gemm-256",
"torch-384", "torch-384",
"dsv3_router_gemm-384", "dsv3_router_gemm-384",
], ]
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")], if not IS_CI
else ["dsv3_router_gemm-256"]
),
styles=(
[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")]
if not IS_CI
else [("orange", "-")]
),
ylabel="TFLOPs", ylabel="TFLOPs",
plot_name="input-bf16-output-bf16 dsv3 router gemm throughput", plot_name="input-bf16-output-bf16 dsv3 router gemm throughput",
args={}, args={},
...@@ -64,17 +87,25 @@ def benchmark_bf16_output(num_tokens, impl): ...@@ -64,17 +87,25 @@ def benchmark_bf16_output(num_tokens, impl):
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["num_tokens"], x_names=["num_tokens"],
x_vals=[i + 1 for i in range(16)], x_vals=num_tokens_vals,
x_log=False, x_log=False,
line_arg="impl", line_arg="impl",
line_vals=["torch-256", "sgl-kernel-256", "torch-384", "sgl-kernel-384"], line_vals=line_vals,
line_names=[ line_names=(
[
"torch-256", "torch-256",
"dsv3_router_gemm-256", "dsv3_router_gemm-256",
"torch-384", "torch-384",
"dsv3_router_gemm-384", "dsv3_router_gemm-384",
], ]
styles=[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")], if not IS_CI
else ["dsv3_router_gemm-256"]
),
styles=(
[("blue", "-"), ("orange", "-"), ("green", "-"), ("red", "-")]
if not IS_CI
else [("orange", "-")]
),
ylabel="TFLOPs", ylabel="TFLOPs",
plot_name="input-bf16-output-fp32 dsv3 router gemm throughput", plot_name="input-bf16-output-fp32 dsv3 router gemm throughput",
args={}, args={},
......
@@ -2,6 +2,7 @@ import argparse
import copy
import csv
import itertools
import os
import pytest
import torch
@@ -9,6 +10,14 @@ import triton
from flashinfer import mm_fp4
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
from sglang.srt.utils import get_device_capability
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
@@ -33,10 +42,11 @@ def get_weight_shapes(args):
]
# CI environment uses simplified parameters
if IS_CI:
batch_sizes = [1, 8] # Simplified for CI
else:
batch_sizes = [
1,
2,
4,
@@ -53,7 +63,13 @@ def get_weight_shapes(args):
4096,
8192,
16384,
]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=batch_sizes,
# x_vals = [64],
x_log=False,
line_arg="provider",
@@ -188,12 +204,30 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# Simplify for CI environment
if IS_CI:
args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size
if args.csv:
with open(args.csv, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["provider", "m", "n", "k", "time_ms"])
# Check architecture compatibility - FP4 operations require sm100a/sm103a
major, minor = get_device_capability()
if major is None or major < 10: # Requires compute capability 10.0+ (sm100a/sm103a)
print("Skipping FP4 GEMM benchmark")
if major is not None:
print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}")
else:
print("Could not determine device capability")
else:
NKs = get_weight_shapes(args)
# Limit iterations in CI
if IS_CI:
NKs = NKs[:2] # Only test first 2 shapes in CI
for N, K in NKs:
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
benchmark.run(
@@ -204,5 +238,4 @@ if __name__ == "__main__":
correctness=args.correctness,
csv_file=args.csv,
)
print("Benchmark finished!")
import argparse
import copy
import itertools
import os
import deep_gemm
import torch
import triton
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
from sgl_kernel import fp8_blockwise_scaled_mm
# Optional vLLM import
try:
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
VLLM_AVAILABLE = True
except ImportError:
vllm_scaled_mm = None
VLLM_AVAILABLE = False
from sglang.srt.layers.quantization.fp8_kernel import (
w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul,
)
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def get_weight_shapes(args):
models_tps = list(itertools.product(args.models, args.tp_sizes))
@@ -80,15 +95,46 @@ def scale_shape(shape, group_shape):
return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
# CI environment uses simplified parameters
if IS_CI:
batch_sizes = [1, 8] # Simplified for CI
else:
batch_sizes = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
# Filter providers based on availability
available_providers = ["sgl-kernel"]
available_names = ["sgl-kernel"]
available_styles = [("orange", "-")]
if VLLM_AVAILABLE:
available_providers.insert(0, "vllm")
available_names.insert(0, "vllm")
available_styles.insert(0, ("blue", "-"))
available_providers.append("triton")
available_names.append("sglang triton")
available_styles.append(("red", "-"))
# Add deepgemm if available
try:
import deep_gemm
available_providers.append("deepgemm")
available_names.append("deepgemm")
available_styles.append(("yellow", "-"))
except ImportError:
pass
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=batch_sizes,
x_log=False,
line_arg="provider",
line_vals=available_providers,
line_names=available_names,
styles=available_styles,
ylabel="GB/s",
plot_name="fp8 blockwise scaled matmul",
args={},
@@ -123,14 +169,16 @@ def benchmark(batch_size, provider, N, K):
),
quantiles=quantiles,
)
elif provider == "vllm":
if not VLLM_AVAILABLE:
return (0, 0, 0)
scale_a = scale_a.t().contiguous().t()
b_fp8, scale_b = b_fp8.t(), scale_b.t()
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
quantiles=quantiles,
)
elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: w8a8_block_fp8_matmul(
a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
@@ -166,7 +214,17 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# Simplify for CI environment
if IS_CI:
args.models = [args.models[0]] # Use only first model
args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size
NK_model_names = get_weight_shapes(args)
# Limit iterations in CI
if IS_CI:
NK_model_names = NK_model_names[:2] # Only test first 2 shapes in CI
for N, K, model_name in NK_model_names:
if N % 128 != 0 or K % 128 != 0:
print(f"Skip {N=}, {K=} now")
...
import argparse
import os
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
import random
from dataclasses import dataclass
from typing import List, Tuple
@@ -290,6 +297,14 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("--num-warmup", type=int, default=3)
parser.add_argument("--num-run", type=int, default=10)
# CI environment uses simplified parameters
if IS_CI:
shape_args = [
# Only test one simple shape in CI
ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256),
]
else:
shape_args = [
# Prefill, DeepSeek-R1, gateup, chunk_size = 4096, TP = 8
ShapeArg(expected_m_per_group=128, n=512, k=7168, num_groups=256),
...
import argparse
import copy
import itertools
import os
from typing import Optional, Tuple
import torch
import triton
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from sgl_kernel import sgl_per_tensor_quant_fp8
# Optional vLLM import
try:
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
VLLM_AVAILABLE = True
except ImportError:
vllm_scaled_mm = None
vllm_scaled_fp8_quant = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
@@ -86,25 +102,48 @@ def sglang_scaled_fp8_quant(
return output, scale
# CI environment uses simplified parameters
if IS_CI:
batch_sizes = [1] # Single batch size for CI
else:
batch_sizes = [1, 16, 64, 128, 256, 512, 1024, 2048]
# Filter line_vals based on vLLM availability
if VLLM_AVAILABLE:
line_vals = [
"vllm-fp8-fp16", "vllm-fp8-fp16",
"vllm-fp8-bf16", "vllm-fp8-bf16",
"sglang-fp8-fp16", "sglang-fp8-fp16",
"sglang-fp8-bf16", "sglang-fp8-bf16",
], ]
line_names=[ line_names = [
"vllm-fp8-fp16", "vllm-fp8-fp16",
"vllm-fp8-bf16", "vllm-fp8-bf16",
"sglang-fp8-fp16", "sglang-fp8-fp16",
"sglang-fp8-bf16", "sglang-fp8-bf16",
], ]
styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")], styles = [("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")]
else:
line_vals = [
"sglang-fp8-fp16",
"sglang-fp8-bf16",
]
line_names = [
"sglang-fp8-fp16",
"sglang-fp8-bf16",
]
styles = [("blue", "-"), ("blue", "--")]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=batch_sizes,
x_log=False,
line_arg="provider",
line_vals=line_vals,
line_names=line_names,
styles=styles,
ylabel="GB/s", ylabel="GB/s",
plot_name="fp8 scaled matmul", plot_name="fp8 scaled matmul",
args={}, args={},
...@@ -115,6 +154,9 @@ def benchmark(batch_size, provider, N, K): ...@@ -115,6 +154,9 @@ def benchmark(batch_size, provider, N, K):
M = batch_size M = batch_size
a = torch.ones((M, K), device="cuda") * 5.0 a = torch.ones((M, K), device="cuda") * 5.0
b = torch.ones((N, K), device="cuda") * 5.0 b = torch.ones((N, K), device="cuda") * 5.0
# vLLM expects scalar scales, while sglang can handle per-token scales
scale_a_scalar = torch.randn(1, device="cuda", dtype=torch.float32)
scale_b_scalar = torch.randn(1, device="cuda", dtype=torch.float32)
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
quantiles = [0.5, 0.2, 0.8]
@@ -122,8 +164,11 @@ def benchmark(batch_size, provider, N, K):
dtype = torch.float16 if "fp16" in provider else torch.bfloat16
if "vllm-fp8" in provider:
if not VLLM_AVAILABLE:
# Return zero if vLLM is not available
return (0, 0, 0)
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_scalar)
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b_scalar)
b_fp8 = b_fp8.t()
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
@@ -174,6 +219,11 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# Simplify for CI environment
if IS_CI:
args.models = [args.models[0]] # Use only first model
args.tp_sizes = [args.tp_sizes[0]] # Use only first TP size
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
...
import argparse
import copy
import itertools
import os
import torch
import triton
from sgl_kernel import int8_scaled_mm
# Optional vLLM import
try:
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
VLLM_AVAILABLE = True
except ImportError:
vllm_scaled_mm = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
@@ -62,15 +77,32 @@ WEIGHT_SHAPES = {
}
# CI environment uses simplified parameters
if IS_CI:
batch_sizes = [1] # Single batch size for CI
else:
batch_sizes = [1, 16, 32, 64, 128, 256, 512, 1024, 2048]
# Filter providers based on vLLM availability
if VLLM_AVAILABLE:
line_vals = ["vllm", "sgl-kernel"]
line_names = ["vllm int8 gemm", "sgl-kernel int8 gemm"]
styles = [("blue", "-"), ("orange", "-")]
else:
line_vals = ["sgl-kernel"]
line_names = ["sgl-kernel int8 gemm"]
styles = [("orange", "-")]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=batch_sizes,
x_log=False,
line_arg="provider",
line_vals=line_vals,
line_names=line_names,
styles=styles,
ylabel="GB/s",
plot_name="int8 scaled matmul",
args={},
@@ -90,7 +122,9 @@ def benchmark(batch_size, provider, N, K):
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
quantiles=quantiles,
)
elif provider == "vllm":
if not VLLM_AVAILABLE:
return (0, 0, 0)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
quantiles=quantiles,
@@ -136,6 +170,13 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# Skip in CI environment due to architecture compatibility issues
if IS_CI:
print(
"Skipping INT8 GEMM benchmark in CI environment due to architecture compatibility issues"
)
print("INT8 operations may not be supported on all GPU architectures")
else:
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
...
import itertools
import math
import os
import torch
import triton
import triton.language as tl
from sgl_kernel import lightning_attention_decode
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def next_power_of_2(n):
return 2 ** (int(math.ceil(math.log(n, 2))))
@@ -207,7 +214,12 @@ def calculate_diff(batch_size):
print("❌ Implementations differ")
# Simplified for CI environment
if IS_CI:
batch_size_range = [1] # Single batch size for CI
else:
batch_size_range = [i for i in range(1, 65)] # 1 to 64
configs = [(bs,) for bs in batch_size_range]
@@ -292,8 +304,9 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# Run correctness test - simplified for CI
test_batch_size = 1 if IS_CI else 4
calculate_diff(batch_size=test_batch_size)
# Run performance benchmark
benchmark.run(print_data=True)
import argparse
import itertools
import os
import torch
import triton
@@ -8,8 +9,17 @@ from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
try:
from vllm import _custom_ops as ops
VLLM_AVAILABLE = True
except ImportError:
ops = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
USE_RANDOM_PERM = False
@@ -197,6 +207,7 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
num_tokens_post_pad_triton,
)
if VLLM_AVAILABLE:
try:
ops.moe_align_block_size(
topk_ids,
@@ -211,6 +222,9 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
except Exception as e:
print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
vllm_works = False
else:
print("⚠️ vLLM not available, skipping vLLM test")
vllm_works = False
if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
num_tokens_post_pad_cuda, num_tokens_post_pad_triton
@@ -394,8 +408,18 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# Simplify for CI environment
if IS_CI:
num_tokens = 256 # Smaller for CI
num_experts = 8 # Smaller for CI
topk = 2 # Smaller for CI
else:
num_tokens = 1024
num_experts = args.num_experts
topk = args.topk
calculate_diff(num_tokens=num_tokens, num_experts=num_experts, topk=topk)
if not args.skip_full_benchmark and not IS_CI: # Skip full benchmark in CI
print(f"\n📊 Running performance benchmark for {args.num_experts} experts...")
benchmark.run(print_data=True)
import os
import torch
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
import triton
from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
# CI environment uses simplified parameters
if IS_CI:
batch_sizes = [64, 128] # Only test 2 values in CI
else:
batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
configs = [(bs,) for bs in batch_sizes]
...
import itertools
import math
import os
import torch
import triton
@@ -8,6 +9,12 @@ from sgl_kernel import moe_fused_gate
from sglang.srt.layers.moe.topk import biased_grouped_topk
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk):
return biased_grouped_topk(
@@ -28,7 +35,12 @@ def biased_grouped_topk_org_fuse_kernel(
return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk)
# CI environment uses simplified parameters
if IS_CI:
seq_length_range = [5000] # Only test one sequence length in CI
else:
seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]
configs = [(sq,) for sq in seq_length_range]
...
import itertools
import os
import pytest
import torch
import triton
from sgl_kernel import topk_softmax
# Optional vLLM import
try:
from vllm import _custom_ops as vllm_custom_ops
VLLM_AVAILABLE = True
except ImportError:
vllm_custom_ops = None
VLLM_AVAILABLE = False
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
def vllm_topk_softmax(gating_output, topk):
if not VLLM_AVAILABLE:
# Fallback to SGLang implementation if vLLM is not available
return sglang_topk_softmax(gating_output, topk)
num_tokens, num_experts = gating_output.shape
topk_weights = torch.empty(
@@ -54,6 +73,10 @@ def calculate_diff(num_tokens, num_experts, topk):
weights_diff = torch.abs(weights_vllm - weights_sglang).mean().item()
indices_match = torch.equal(indices_vllm, indices_sglang)
if not VLLM_AVAILABLE:
print("⚠️ vLLM not available, skipping comparison")
return
if (
torch.allclose(weights_vllm, weights_sglang, atol=1e-3, rtol=1e-3)
and indices_match
@@ -65,21 +88,38 @@ def calculate_diff(num_tokens, num_experts, topk):
)
# CI environment uses simplified parameters
if IS_CI:
num_tokens_range = [128] # Single value for CI
num_experts_range = [32] # Single value for CI
topk_range = [2] # Single value for CI
else:
num_tokens_range = [128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
num_experts_range = [32, 64, 128, 256, 12, 512]
topk_range = [1, 2, 4, 8]
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
# Filter providers based on vLLM availability
if VLLM_AVAILABLE:
line_vals = ["sglang", "vllm"]
line_names = ["SGLang", "VLLM"]
styles = [("blue", "-"), ("green", "-")]
else:
line_vals = ["sglang"]
line_names = ["SGLang"]
styles = [("blue", "-")]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens", "num_experts", "topk"],
x_vals=configs,
line_arg="provider",
line_vals=line_vals,
line_names=line_names,
styles=styles,
ylabel="Latency (us)",
plot_name="topk-softmax-performance",
args={},
@@ -92,6 +132,8 @@ def benchmark(num_tokens, num_experts, topk, provider):
)
if provider == "vllm" or provider == "vllm1":
if not VLLM_AVAILABLE:
return (0, 0, 0)
fn = lambda: vllm_topk_softmax(gating_output, topk)
elif provider == "sglang" or provider == "sglang1":
fn = lambda: sglang_topk_softmax(gating_output, topk)
@@ -103,7 +145,11 @@ if __name__ == "__main__":
if __name__ == "__main__":
# Simplify configs for CI environment
if IS_CI:
test_configs = [(20, 32, 2)] # Single config for CI
else:
test_configs = [
(20, 256, 4),
(20, 256, 8),
(20, 12, 4),
@@ -111,6 +157,7 @@ if __name__ == "__main__":
(20, 512, 4),
(20, 512, 1),
]
for num_tokens, num_experts, topk in test_configs:
calculate_diff(num_tokens, num_experts, topk)
benchmark.run(print_data=True)
import argparse
import copy
import itertools
import os
import torch
import triton
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
from sglang.srt.utils import get_device_capability
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
@@ -162,9 +171,22 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# Check architecture compatibility - FP4 operations require sm100a/sm103a
major, minor = get_device_capability()
if major is None or major < 10: # Requires compute capability 10.0+ (sm100a/sm103a)
print("Skipping NVIDIA FP4 scaled GEMM benchmark")
if major is not None:
print(f"FP4 operations require sm100a/sm103a, but found sm{major}{minor}")
else:
print("Could not determine device capability")
else:
KN_model_names = prepare_shapes(args)
# Limit iterations in CI
if IS_CI:
KN_model_names = KN_model_names[:2] # Only test first 2 shapes in CI
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(print_data=True, N=N, K=K)
print("Benchmark finished!")
import itertools
import math
import os
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
@@ -7,11 +8,26 @@ import torch
import triton
import triton.testing
from sgl_kernel import sgl_per_tensor_quant_fp8
# Optional imports
try:
from vllm import _custom_ops as ops
VLLM_AVAILABLE = True
except ImportError:
ops = None
VLLM_AVAILABLE = False
from sglang.srt.utils import is_hip
_is_hip = is_hip()
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
@@ -19,6 +35,9 @@ def vllm_scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if not VLLM_AVAILABLE:
# Fallback to SGLang implementation
return sglang_scaled_fp8_quant(input, scale)
return ops.scaled_fp8_quant(input, scale)
@@ -42,6 +61,10 @@ def calculate_diff(batch_size: int, seq_len: int):
device = torch.device("cuda")
x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)
if not VLLM_AVAILABLE:
print("⚠️ vLLM not available, skipping comparison")
return
vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)
@@ -56,8 +79,13 @@ def calculate_diff(batch_size: int, seq_len: int):
print("❌ Implementations differ")
# CI environment uses simplified parameters
if IS_CI:
batch_size_range = [16] # Single batch size for CI
seq_len_range = [64] # Single sequence length for CI
else:
batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048]
configs = list(itertools.product(batch_size_range, seq_len_range))
...
import itertools
import os
import time
from functools import partial
from pathlib import Path
@@ -16,15 +17,28 @@ from sglang.srt.layers.quantization.fp8_kernel import (
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
from sglang.srt.utils import is_hip
# CI environment detection
IS_CI = (
os.getenv("CI", "false").lower() == "true"
or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
# CI environment uses simplified parameters
if IS_CI:
num_tokens_range = [64] # Single value for CI
hidden_dim_range = [1536] # Single value for CI
group_size_range = [128] # Keep as is
dst_dtype_range = [fp8_type_] # Keep as is
else:
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1
group_size_range = [128] # For DeepSeek V3/R1
# TODO test int8
dst_dtype_range = [fp8_type_]
flags_range = [
dict(
column_major_scales=False,
@@ -82,7 +96,7 @@ def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
fn, kernel_names = {
"triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_8bit"),
"sglang": (
sglang_per_token_group_quant_8bit,
"per_token_group_quant_8bit_kernel",
...