Unverified commit 21337b22, authored by fzyzcjy and committed by GitHub

Reland [1/2] Optimizations and refactors about quant kernel (#10312)


Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent 129d2992
@@ -43,11 +43,17 @@ _is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
-    from sgl_kernel import (
-        sgl_per_tensor_quant_fp8,
-        sgl_per_token_group_quant_fp8,
-        sgl_per_token_quant_fp8,
-    )
+    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
+
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+
+        enable_sgl_per_token_group_quant_8bit = True
+    except ImportError:
+        from sgl_kernel import sgl_per_token_group_quant_fp8
+
+        enable_sgl_per_token_group_quant_8bit = False

 if _is_hip:
     if _use_aiter:
@@ -477,6 +483,7 @@ def sglang_per_token_group_quant_fp8(
     scale_ue8m0: bool = False,
     fuse_silu_and_mul: bool = False,
     masked_m: Optional[torch.Tensor] = None,
+    enable_v2: Optional[bool] = None,
 ):
     assert (
         x.shape[-1] % group_size == 0
@@ -496,9 +503,26 @@ def sglang_per_token_group_quant_fp8(
     )

     if x.shape[0] > 0:
-        sgl_per_token_group_quant_fp8(
-            x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
-        )
+        # Temporary
+        if enable_sgl_per_token_group_quant_8bit:
+            sgl_per_token_group_quant_8bit(
+                x,
+                x_q,
+                x_s,
+                group_size,
+                eps,
+                fp8_min,
+                fp8_max,
+                scale_ue8m0,
+                fuse_silu_and_mul,
+                masked_m,
+                enable_v2=enable_v2,
+            )
+        else:
+            assert not enable_v2
+            sgl_per_token_group_quant_fp8(
+                x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+            )

     return x_q, x_s
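For context, a hedged usage sketch of the helper changed above; the keyword layout beyond x and group_size is inferred from the call sites elsewhere in this commit, and forcing enable_v2=True assumes a wheel that ships the unified kernel:

import torch
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
# Default: enable_v2=None lets the wrapper (or its env-var default) pick the kernel.
x_q, x_s = sglang_per_token_group_quant_fp8(x=x, group_size=128)
# Explicitly route to the v2 kernel (requires the unified sgl_kernel op to be importable).
x_q, x_s = sglang_per_token_group_quant_fp8(x=x, group_size=128, enable_v2=True)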
@@ -514,6 +538,7 @@ def sglang_per_token_group_quant_8bit(
     scale_ue8m0: bool = False,
     fuse_silu_and_mul: bool = False,
     masked_m: Optional[torch.Tensor] = None,
+    enable_v2: Optional[bool] = None,
 ):
     from sglang.srt.layers.quantization.int8_kernel import (
         sglang_per_token_group_quant_int8,
@@ -529,6 +554,7 @@ def sglang_per_token_group_quant_8bit(
             group_size=group_size,
             eps=eps,
             dtype=dst_dtype,
+            enable_v2=enable_v2,
         )

     return sglang_per_token_group_quant_fp8(
@@ -540,6 +566,7 @@ def sglang_per_token_group_quant_8bit(
         scale_ue8m0=scale_ue8m0,
         fuse_silu_and_mul=fuse_silu_and_mul,
         masked_m=masked_m,
+        enable_v2=enable_v2,
     )
...
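A hedged sketch of driving the dtype-dispatching helper above; the keyword names come from the benchmark and test call sites later in this commit, everything else is illustrative:

import torch
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit

x = torch.randn(8, 512, device="cuda", dtype=torch.bfloat16)
# dst_dtype selects the branch: int8 routes to the int8 kernel path, other dtypes to fp8;
# enable_v2 is forwarded either way.
xq_i8, xs_i8 = sglang_per_token_group_quant_8bit(x=x, group_size=128, dst_dtype=torch.int8)
xq_f8, xs_f8 = sglang_per_token_group_quant_8bit(x=x, group_size=128, dst_dtype=torch.float8_e4m3fn)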
@@ -8,11 +8,17 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.utils import get_device_name, is_cuda
+from sglang.srt.utils import get_bool_env_var, get_device_name, is_cuda

 _is_cuda = is_cuda()
 if _is_cuda:
-    from sgl_kernel import sgl_per_token_group_quant_int8
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+    except ImportError:
+        from sgl_kernel import (
+            sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
+        )

 logger = logging.getLogger(__name__)
@@ -187,6 +193,7 @@ def sglang_per_token_group_quant_int8(
     group_size: int,
     eps: float = 1e-10,
     dtype: torch.dtype = torch.int8,
+    enable_v2: Optional[bool] = None,
 ):
     assert (
         x.shape[-1] % group_size == 0
@@ -204,7 +211,9 @@ def sglang_per_token_group_quant_int8(
         dtype=torch.float32,
     )

-    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    sgl_per_token_group_quant_8bit(
+        x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2
+    )

     return x_q, x_s
...
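The eps/int8_min/int8_max arguments imply the usual per-group absmax scheme. The following reference-semantics sketch is an assumption about what the kernel computes, not code from this diff:

import torch

def per_token_group_quant_int8_ref(x, group_size, eps=1e-10, int8_min=-128, int8_max=127):
    # Each contiguous group of `group_size` values in the last dim shares one
    # float32 scale = absmax / int8_max (clamped by eps); values are rounded.
    xg = x.float().reshape(*x.shape[:-1], -1, group_size)
    scale = xg.abs().amax(dim=-1, keepdim=True).clamp_min(eps) / int8_max
    q = torch.clamp((xg / scale).round(), int8_min, int8_max).to(torch.int8)
    return q.reshape(x.shape), scale.squeeze(-1)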
 import os
+import re
 import sys

 from contextlib import nullcontext
@@ -108,7 +109,8 @@ def bench_kineto(
         if not with_multiple_kernels:
             for name in kernel_names:
                 assert (
-                    sum([name in line for line in prof_lines]) == 1
+                    sum([int(re.search(name, line) is not None) for line in prof_lines])
+                    == 1
                 ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"

     # Save chrome traces
@@ -122,7 +124,7 @@ def bench_kineto(
         total_time = 0
         total_num = 0
         for line in prof_lines:
-            if name in line:
+            if re.search(name, line) is not None:
                 time_str = line.split()[-2]
                 num_str = line.split()[-1]
                 for unit, scale in units.items():
...
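Switching from substring `in` to `re.search` lets entries in kernel_names be regular expressions; the benchmark later in this commit relies on that to match either of two kernels with one pattern. A small illustration (the profiler row below is fabricated):

import re

line = "  _silu_and_mul_post_quant_kernel   12.3us   1"  # made-up profiler table row
pattern = "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel"
print(pattern in line)                       # False: plain substring match fails
print(re.search(pattern, line) is not None)  # True: regex alternation matches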
@@ -287,6 +287,7 @@ set(SOURCES
     "csrc/gemm/nvfp4_scaled_mm_kernels.cu"
     "csrc/gemm/per_tensor_quant_fp8.cu"
     "csrc/gemm/per_token_group_quant_8bit.cu"
+    "csrc/gemm/per_token_group_quant_8bit_v2.cu"
     "csrc/gemm/per_token_quant_fp8.cu"
     "csrc/gemm/qserve_w4a8_per_chn_gemm.cu"
     "csrc/gemm/qserve_w4a8_per_group_gemm.cu"
...
@@ -6,6 +6,7 @@ from pathlib import Path
 import torch
 import triton
+from sgl_kernel.test_utils import create_per_token_group_quant_test_data

 from sglang.srt.layers.quantization.fp8_kernel import (
     create_per_token_group_quant_fp8_output_scale,
@@ -27,84 +28,217 @@ _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-# CI environment uses simplified parameters
-if IS_CI:
-    num_tokens_range = [64]  # Single value for CI
-    hidden_dim_range = [1536]  # Single value for CI
-    group_size_range = [128]  # Keep as is
-    dst_dtype_range = [fp8_type_]  # Keep as is
-else:
-    num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
-    hidden_dim_range = [1536, 7168, 18432]  # For DeepSeek V3/R1
-    group_size_range = [128]  # For DeepSeek V3/R1
-    # TODO test int8
-    dst_dtype_range = [fp8_type_]
-
-flags_range = [
-    dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False),
-    dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False),
-    dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False),
-    dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True),
-]
-
-configs = list(
-    itertools.product(
-        num_tokens_range,
-        hidden_dim_range,
-        group_size_range,
-        dst_dtype_range,
-        flags_range,
-    )
-)
+mode_concentrated = IS_CI or (os.environ.get("SGLANG_BENCH_MODE", "") == "concentrated")
+
+if int(os.environ.get("SGLANG_NSYS_PROFILING", "0")):
+    configs = [
+        [
+            768 * 8,
+            2048,
+            128,
+            48,
+            fp8_type_,
+            dict(
+                column_major_scales=True,
+                scale_tma_aligned=True,
+                scale_ue8m0=True,
+                fuse_silu_and_mul=True,
+                # masked_layout_mode=None,
+                masked_layout_mode="balanced",
+                # masked_layout_mode="extreme",
+            ),
+        ]
+    ]
+elif mode_concentrated:
+    configs = list(
+        itertools.product(
+            [768],
+            [1536, 7168, 16384],
+            [128],
+            [None],
+            [fp8_type_],
+            [
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=False, masked_layout_mode=None),
+            ],
+        )
+    ) + list(
+        itertools.product(
+            [768 * 8],
+            [2048],
+            [128],
+            [48],
+            [fp8_type_],
+            [
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode=None),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="balanced"),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="imbalanced"),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="extreme"),
+            ],
+        )
+    )
+else:
+    configs = list(
+        itertools.product(
+            [1, 4, 16, 64, 256, 768, 2048, 8192, 16384],
+            [1536, 7168, 16384],
+            [128],
+            [None],
+            [fp8_type_],
+            [
+                dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
+                dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=False, masked_layout_mode=None),
+            ],
+        )
+    ) + list(
+        itertools.product(
+            [1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],
+            [2048],
+            [128],
+            [8, 16, 32, 48],
+            [fp8_type_],
+            [
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode=None),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="balanced"),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="imbalanced"),
+                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="extreme"),
+            ],
+        )
+    )
 @triton.testing.perf_report(
     triton.testing.Benchmark(
-        x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"],
+        x_names=["num_tokens", "hidden_dim", "group_size", "num_ranks", "dst_dtype", "flags"],
         x_vals=configs,
         line_arg="provider",
         line_vals=["triton", "sglang"],
-        line_names=["Triton", "SGL Kernel"],
+        # Triton has multi kernels and we only report the time for the core one
+        line_names=["Triton (Inaccurate)", "SGL Kernel"],
         styles=[("blue", "-"), ("green", "-")],
         ylabel="us",
         plot_name="per-token-group-quant-8bit-performance",
         args={},
     )
 )
-def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
-    if flags["scale_ue8m0"] and group_size != 128:
-        return
-
-    device = torch.device("cuda")
-
-    x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
+def benchmark(
+    num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags, provider
+):
+    print(
+        f"Testing: {num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=} {provider=}"
+    )
+
+    x, masked_m = create_per_token_group_quant_test_data(
+        num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
+    )

     fn, kernel_names = {
-        "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_8bit"),
+        "triton": (
+            triton_per_token_group_quant_8bit,
+            "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel",
+        ),
         "sglang": (
-            sglang_per_token_group_quant_8bit,
+            partial(sglang_per_token_group_quant_8bit, enable_v2=True),
             "per_token_group_quant_8bit_kernel",
         ),
     }[provider]

-    bench_fn = lambda: fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags)
+    bench_fn = lambda: fn(
+        x=x,
+        masked_m=masked_m,
+        group_size=group_size,
+        dst_dtype=dst_dtype,
+        **{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
+    )

-    time_s = bench_kineto(bench_fn, kernel_names=kernel_names)
+    time_s = bench_kineto(
+        bench_fn, kernel_names=kernel_names, num_tests=300 if mode_concentrated else 30
+    )
     return time_s * 1e6
...
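The benchmark's mode selection is driven entirely by environment variables read at import time; a hedged launcher sketch (the benchmark's filename below is assumed, not taken from this commit):

import os
import runpy

os.environ["SGLANG_BENCH_MODE"] = "concentrated"  # small, DeepSeek-flavoured config set
# os.environ["SGLANG_NSYS_PROFILING"] = "1"       # single config suited to an nsys capture
runpy.run_path("bench_per_token_group_quant_8bit.py", run_name="__main__")  # assumed filename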
@@ -136,14 +136,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);

   m.def(
-      "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      "sgl_per_token_group_quant_8bit(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
       " float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()");
-  m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
+  m.impl("sgl_per_token_group_quant_8bit", torch::kCUDA, &sgl_per_token_group_quant_8bit);

   m.def(
-      "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float int8_min, float int8_max) -> ()");
-  m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
+      "sgl_per_token_group_quant_8bit_v2(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float fp8_min, float fp8_max, bool scale_ue8m0, bool fuse_silu_and_mul, Tensor? masked_m) -> ()");
+  m.impl("sgl_per_token_group_quant_8bit_v2", torch::kCUDA, &sgl_per_token_group_quant_8bit_v2);

   m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
   m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
...
@@ -121,7 +121,7 @@ void sgl_per_token_group_quant_8bit(
     double eps,
     double min_8bit,
     double max_8bit,
-    bool scale_ue8m0 = false) {
+    bool scale_ue8m0) {
   CHECK_INPUT(input);
   CHECK_INPUT(output_q);
@@ -215,26 +215,3 @@ void sgl_per_token_group_quant_8bit(
 #undef LAUNCH_KERNEL
 }
-
-void sgl_per_token_group_quant_int8(
-    torch::Tensor input,
-    torch::Tensor output_q,
-    torch::Tensor output_s,
-    int64_t group_size,
-    double eps,
-    double int8_min,
-    double int8_max) {
-  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, int8_min, int8_max);
-}
-
-void sgl_per_token_group_quant_fp8(
-    torch::Tensor input,
-    torch::Tensor output_q,
-    torch::Tensor output_s,
-    int64_t group_size,
-    double eps,
-    double fp8_min,
-    double fp8_max,
-    bool scale_ue8m0) {
-  sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0);
-}
This diff is collapsed.
@@ -219,7 +219,7 @@ torch::Tensor fp8_blockwise_scaled_mm(
     const torch::Dtype& out_dtype);
 void scaled_fp4_quant(
     torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);
-void sgl_per_token_group_quant_fp8(
+void sgl_per_token_group_quant_8bit(
     at::Tensor input,
     at::Tensor output_q,
     at::Tensor output_s,
@@ -228,14 +228,17 @@ void sgl_per_token_group_quant_fp8(
     double fp8_min,
     double fp8_max,
     bool scale_ue8m0);
-void sgl_per_token_group_quant_int8(
+void sgl_per_token_group_quant_8bit_v2(
     at::Tensor input,
     at::Tensor output_q,
     at::Tensor output_s,
     int64_t group_size,
     double eps,
-    double int8_min,
-    double int8_max);
+    double min_8bit,
+    double max_8bit,
+    bool scale_ue8m0,
+    bool fuse_silu_and_mul,
+    const std::optional<torch::Tensor>& masked_m);
 void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
 void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
 void bmm_fp8(
...
@@ -263,8 +263,7 @@ from sgl_kernel.gemm import (
     scaled_fp4_grouped_quant,
     scaled_fp4_quant,
     sgl_per_tensor_quant_fp8,
-    sgl_per_token_group_quant_fp8,
-    sgl_per_token_group_quant_int8,
+    sgl_per_token_group_quant_8bit,
     sgl_per_token_quant_fp8,
     shuffle_rows,
     silu_and_mul_scaled_fp4_grouped_quant,
...
@@ -98,7 +98,7 @@ def dsv3_fused_a_gemm(
     return output


-def sgl_per_token_group_quant_fp8(
+def sgl_per_token_group_quant_8bit(
     input: torch.Tensor,
     output_q: torch.Tensor,
     output_s: torch.Tensor,
@@ -106,24 +106,34 @@ def sgl_per_token_group_quant_fp8(
     eps: float,
     fp8_min: float,
     fp8_max: float,
-    scale_ue8m0: bool,
+    scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
+    enable_v2: Optional[bool] = None,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
-        input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
-    )
-
-
-def sgl_per_token_group_quant_int8(
-    input: torch.Tensor,
-    output_q: torch.Tensor,
-    output_s: torch.Tensor,
-    group_size: int,
-    eps: float,
-    int8_min: float,
-    int8_max: float,
-) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
-        input, output_q, output_s, group_size, eps, int8_min, int8_max
-    )
+    if enable_v2 is None:
+        from sglang.srt.utils import get_bool_env_var
+
+        enable_v2 = get_bool_env_var("SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2")
+
+    if enable_v2:
+        return torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit_v2.default(
+            input,
+            output_q,
+            output_s,
+            group_size,
+            eps,
+            fp8_min,
+            fp8_max,
+            scale_ue8m0,
+            fuse_silu_and_mul,
+            masked_m,
+        )
+
+    assert not fuse_silu_and_mul, "only v2 support fuse_silu_and_mul"
+    assert masked_m is None, "only v2 support masked_m"
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
+        input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+    )
...
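A hedged usage sketch of the renamed low-level wrapper: it writes into caller-allocated outputs, and the shapes and dtypes below are illustrative assumptions rather than part of the diff:

import torch
from sgl_kernel import sgl_per_token_group_quant_8bit

num_tokens, hidden_dim, group_size = 16, 512, 128
x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
x_q = torch.empty(num_tokens, hidden_dim, device="cuda", dtype=torch.float8_e4m3fn)
x_s = torch.empty(num_tokens, hidden_dim // group_size, device="cuda", dtype=torch.float32)
finfo = torch.finfo(torch.float8_e4m3fn)

# enable_v2=None would defer to SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2 as shown above.
sgl_per_token_group_quant_8bit(
    x, x_q, x_s, group_size, 1e-10, finfo.min, finfo.max,
    scale_ue8m0=False,
    enable_v2=True,  # route to the sgl_per_token_group_quant_8bit_v2 op
)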
import torch


def create_per_token_group_quant_test_data(num_tokens, hidden_dim, num_ranks, flags):
    device = torch.device("cuda")
    dtype = torch.bfloat16

    seed = num_tokens * 10000 + hidden_dim
    gen_cpu = torch.Generator(device="cpu")
    gen_cpu.manual_seed(seed)
    gen_cuda = torch.Generator(device="cuda")
    gen_cuda.manual_seed(seed)

    if flags["fuse_silu_and_mul"]:
        effective_hidden_dim = hidden_dim * 2
    else:
        effective_hidden_dim = hidden_dim
    del hidden_dim

    if (masked_layout_mode := flags["masked_layout_mode"]) is not None:
        num_max_dispatch_tokens_per_rank = 768
        num_global_experts = 288
        num_local_experts, remainder = divmod(num_global_experts, num_ranks)
        assert remainder == 0

        # mimic DeepEP low_latency_dispatch output
        x = torch.randn(
            num_local_experts,
            num_max_dispatch_tokens_per_rank * num_ranks,
            effective_hidden_dim,
            device=device,
            dtype=dtype,
            generator=gen_cuda,
        )

        if masked_layout_mode == "balanced":
            masked_m = _compute_balanced_split(num_tokens, num_local_experts)
        elif masked_layout_mode == "imbalanced":
            masked_m = _compute_imbalanced_split(
                num_tokens, num_local_experts, gen_cpu=gen_cpu
            )
        elif masked_layout_mode == "extreme":
            masked_m = torch.tensor(
                [num_tokens] + [0] * (num_local_experts - 1), dtype=torch.int
            )
        else:
            raise NotImplementedError
        print(f"{masked_layout_mode=} {masked_m=} {x.shape=}")
        masked_m = masked_m.to(device)

        return x, masked_m
    else:
        x = torch.randn(
            num_tokens,
            effective_hidden_dim,
            device=device,
            dtype=dtype,
            generator=gen_cuda,
        )
        x[torch.randn(x.shape, device=device, generator=gen_cuda) < 0.001] *= 10
        return x, None


def _compute_balanced_split(total: int, arr_len: int):
    base = total // arr_len
    remainder = total % arr_len
    ans = [base + 1 if i < remainder else base for i in range(arr_len)]
    assert sum(ans) == total
    return torch.tensor(ans, dtype=torch.int)


def _compute_imbalanced_split(total: int, arr_len: int, gen_cpu, dtype=torch.int):
    # can use `rand ** 2`, `rand ** 3`, etc, to change how imbalanced it is
    noise_raw = torch.rand(arr_len, generator=gen_cpu) ** 3
    noise = noise_raw / noise_raw.sum()
    ans = (noise * total).round().to(dtype)

    diff = total - ans.sum().item()
    while diff != 0:
        idx = torch.randint(0, arr_len, (1,), generator=gen_cpu).item()
        if diff > 0:
            ans[idx] += 1
            diff -= 1
        elif diff < 0 and ans[idx] > 0:
            ans[idx] -= 1
            diff += 1

    assert sum(ans) == total
    return ans
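A quick worked example of the split helpers above (run inside this module; the numbers follow directly from the definitions):

# 10 tokens over 4 local experts: base 2, remainder 2, so the first two experts
# get one extra token; both helpers always sum exactly to `total`.
print(_compute_balanced_split(10, 4))  # tensor([3, 3, 2, 2], dtype=torch.int32)

gen = torch.Generator(device="cpu").manual_seed(0)
imbalanced = _compute_imbalanced_split(10, 4, gen_cpu=gen)
assert imbalanced.sum().item() == 10   # exact total, arbitrary per-expert skew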
def assert_all_close_or_tiny_diff(a: torch.Tensor, b: torch.Tensor):
    assert (a.shape == b.shape) and (
        a.dtype == b.dtype
    ), f"{a.shape=} {b.shape=} {a.dtype=} {b.dtype=}"

    numel = a.numel()
    if a.dtype == torch.float8_e4m3fn:
        a_u8 = a.view(torch.uint8)
        b_u8 = b.view(torch.uint8)
        diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()
        count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item()
        count_tiny_diff = (diff_u8 == 1).sum().item()
        count_large_diff = (diff_u8 >= 2).sum().item()
    elif a.dtype == torch.int8:
        # compare a against b (not a against itself)
        diff = (a.to(torch.int16) - b.to(torch.int16)).abs()
        count_diff_sign = ((a >= 0) & (b < 0)).sum().item()
        count_tiny_diff = (diff == 1).sum().item()
        count_large_diff = (diff >= 2).sum().item()
    else:
        raise NotImplementedError

    assert (
        (count_diff_sign == 0)
        and (count_large_diff == 0)
        and (
            (count_tiny_diff / numel < 0.005)
            or ((count_tiny_diff / numel < 0.04) and (numel <= 4096))
        )
    ), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=} {a=} {b=}"
 import itertools
+import os
+import time
+from pathlib import Path

 import pytest
 import torch
+from sgl_kernel.test_utils import (
+    assert_all_close_or_tiny_diff,
+    create_per_token_group_quant_test_data,
+)

+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
 )
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
-from sglang.srt.layers.quantization.utils import assert_fp8_all_close
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import get_bool_env_var, is_hip

 _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+configs = list(
+    itertools.product(
+        [1, 4, 16, 64, 127, 128, 512, 1024, 4096, 8192],  # num_tokens
+        [128, 256, 384, 512, 1024, 1536, 1664, 2048, 4096, 7168, 16384],  # hidden_dim
+        [16, 32, 64, 128],  # group_size
+        [None],  # num_ranks
+        [fp8_type_, torch.int8],  # dtype
+        [
+            dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
+            dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False, fuse_silu_and_mul=False, masked_layout_mode=None),
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=False, masked_layout_mode=None),
+        ],
+    )
+) + list(
+    itertools.product(
+        [1, 4, 1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],
+        # TODO support more
+        [2048],
+        [128],
+        [8, 16, 32, 48],
+        [fp8_type_],
+        [
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode=None),
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="balanced"),
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="imbalanced"),
+            dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True, fuse_silu_and_mul=True, masked_layout_mode="extreme"),
+        ],
+    )
+)
+
+
 @pytest.mark.parametrize(
-    "num_tokens, hidden_dim, group_size, dst_dtype, flags",
-    list(
-        itertools.product(
-            [127, 128, 512, 1024, 4096, 8192],  # num_tokens
-            [256, 512, 1024, 2048, 4096],  # hidden_dim
-            [8, 16, 32, 64, 128],  # group_size
-            # TODO test int8
-            [fp8_type_],  # dtype
-            [
-                dict(column_major_scales=False, scale_tma_aligned=False, scale_ue8m0=False),
-                dict(column_major_scales=True, scale_tma_aligned=False, scale_ue8m0=False),
-                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=False),
-                dict(column_major_scales=True, scale_tma_aligned=True, scale_ue8m0=True),
-            ],
-        )
-    ),
+    "num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags", configs
 )
 def test_per_token_group_quant_with_column_major(
     num_tokens,
     hidden_dim,
     group_size,
+    num_ranks,
     dst_dtype,
     flags,
 ):
-    if flags["scale_ue8m0"] and ((group_size != 128) or (hidden_dim % 512 != 0)):
-        pytest.skip()
-        return
-
-    arch_major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
-    if flags["scale_ue8m0"] and (arch_major <= 9):
-        pytest.skip("Only Blackwell need ue8m0 fusion")
-        return
-
-    x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
+    print(
+        f"{num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=}"
+    )
+
+    if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL:
+        pytest.skip("scale_ue8m0 only supported on Blackwell")
+        return
+
+    if (flags["scale_ue8m0"] and (group_size != 128)) or (
+        (dst_dtype == torch.int8) and flags["column_major_scales"]
+    ):
+        pytest.skip()
+        return
+
+    x, masked_m = create_per_token_group_quant_test_data(
+        num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
+    )
+    # print("hack data!!!")
+    # x = torch.full_like(x, fill_value=100)

     execute_kwargs = dict(
         x=x,
+        masked_m=masked_m,
         group_size=group_size,
         eps=1e-10,
         dst_dtype=dst_dtype,
-        **flags,
+        **{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
     )

-    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(**execute_kwargs)
-    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(**execute_kwargs)
-
-    # torch.set_printoptions(profile="full")
-    # print(f"{x_q_triton=}")
-    # print(f"{x_s_triton=}")
-    # print(f"{x_q_sglang=}")
-    # print(f"{x_s_sglang=}")
-    # torch.set_printoptions(profile="default")
-
-    assert_fp8_all_close(x_q_triton, x_q_sglang)
-    torch.testing.assert_close(
-        x_s_triton.contiguous(),
-        x_s_sglang.contiguous(),
-        rtol=1e-3,
-        atol=1e-5,
-        msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}",
-    )
+    def _postprocess(x_q, x_s):
+        if masked_m is not None:
+            print(f"Mask tokens after {masked_m} to be zero")
+            for i in range(len(masked_m)):
+                x_q[i, masked_m[i] :, :] = 0
+                x_s[i, masked_m[i] :, :] = 0
+        return x_q, x_s
+
+    x_q_triton, x_s_triton = _postprocess(
+        *triton_per_token_group_quant_8bit(**execute_kwargs)
+    )
+    x_q_sglang, x_s_sglang = _postprocess(
+        *sglang_per_token_group_quant_8bit(**execute_kwargs, enable_v2=True)
+    )
+
+    try:
+        assert_all_close_or_tiny_diff(x_q_triton, x_q_sglang)
+        torch.testing.assert_close(
+            x_s_triton.contiguous(),
+            x_s_sglang.contiguous(),
+            rtol=1e-3,
+            atol=1e-5,
+            msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}",
+        )
+    except AssertionError:
+        print(
+            f"{x.shape=} {x_q_triton.shape=} {x_s_triton.shape=} {x_q_sglang.shape=} {x_s_sglang.shape=}"
+        )
+        print(f"{x=}")
+        print(f"{masked_m=}")
+        print(f"{x_q_triton=}")
+        print(f"{x_s_triton=}")
+        print(f"{x_q_sglang=}")
+        print(f"{x_s_sglang=}")
+        raise


 if __name__ == "__main__":
     pytest.main([__file__])