Unverified commit 339f8eef, authored by fzyzcjy and committed by GitHub

[1/2] Optimizations and refactors about quant kernel (#9534)

parent afd9f2f5
 import os
+import re
 import sys
 from contextlib import nullcontext
@@ -108,7 +109,8 @@ def bench_kineto(
     if not with_multiple_kernels:
         for name in kernel_names:
             assert (
-                sum([name in line for line in prof_lines]) == 1
+                sum([int(re.search(name, line) is not None) for line in prof_lines])
+                == 1
             ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"

     # Save chrome traces
@@ -122,7 +124,7 @@ def bench_kineto(
         total_time = 0
         total_num = 0
        for line in prof_lines:
-            if name in line:
+            if re.search(name, line) is not None:
                 time_str = line.split()[-2]
                 num_str = line.split()[-1]
                 for unit, scale in units.items():
...
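Note: `bench_kineto` now treats each entry of `kernel_names` as a regular expression instead of a plain substring, which is what lets the benchmark further down pass a single alternation pattern covering two Triton kernels. A minimal sketch of the matching logic; the `prof_lines` rows are made up for illustration and are not real profiler output:

```python
import re

# Made-up stand-ins for rows of the kineto profiler table (illustrative only).
prof_lines = [
    "_per_token_group_quant_8bit           12.3us    30",
    "_silu_and_mul_post_quant_kernel        4.5us    30",
]

# A plain substring can only ever target one kernel name per entry; with
# re.search, one entry such as this alternation matches rows for either kernel.
# A plain name (e.g. "per_token_group_quant_8bit_kernel") still works, since
# any literal string is also a valid regex.
pattern = "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel"
print([line for line in prof_lines if re.search(pattern, line) is not None])
```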
@@ -43,11 +43,17 @@ _is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
-    from sgl_kernel import (
-        sgl_per_tensor_quant_fp8,
-        sgl_per_token_group_quant_fp8,
-        sgl_per_token_quant_fp8,
-    )
+    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
+
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+
+        enable_sgl_per_token_group_quant_8bit = True
+    except ImportError:
+        from sgl_kernel import sgl_per_token_group_quant_fp8
+
+        enable_sgl_per_token_group_quant_8bit = False

 if _is_hip:
     if _use_aiter:
@@ -496,6 +502,21 @@ def sglang_per_token_group_quant_fp8(
     )

     if x.shape[0] > 0:
-        sgl_per_token_group_quant_fp8(
-            x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
-        )
+        # Temporary
+        if enable_sgl_per_token_group_quant_8bit:
+            sgl_per_token_group_quant_8bit(
+                x,
+                x_q,
+                x_s,
+                group_size,
+                eps,
+                fp8_min,
+                fp8_max,
+                scale_ue8m0,
+                fuse_silu_and_mul,
+                masked_m,
+            )
+        else:
+            sgl_per_token_group_quant_fp8(
+                x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+            )
...
@@ -12,7 +12,13 @@ from sglang.srt.utils import get_device_name, is_cuda
 _is_cuda = is_cuda()

 if _is_cuda:
-    from sgl_kernel import sgl_per_token_group_quant_int8
+    # Temporary
+    try:
+        from sgl_kernel import sgl_per_token_group_quant_8bit
+    except ImportError:
+        from sgl_kernel import (
+            sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
+        )

 logger = logging.getLogger(__name__)
@@ -204,7 +210,7 @@ def sglang_per_token_group_quant_int8(
         dtype=torch.float32,
     )

-    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    sgl_per_token_group_quant_8bit(x, x_q, x_s, group_size, eps, int8_min, int8_max)

     return x_q, x_s
...
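Note: both the fp8 and int8 paths above implement the same per-token-group scheme: every `group_size` contiguous elements of a token share one scale derived from the group's absolute maximum. The sketch below is a rough pure-PyTorch reference written only from the parameter names in this diff (`eps`, destination min/max, `group_size`); it ignores the layout flags (`column_major_scales`, `scale_tma_aligned`, `scale_ue8m0`) and the fused silu-and-mul / masked paths, so treat it as an illustration rather than the repository's implementation.

```python
import torch


def per_token_group_quant_ref(
    x: torch.Tensor,
    group_size: int,
    dst_dtype: torch.dtype = torch.float8_e4m3fn,
    dst_min: float = -448.0,
    dst_max: float = 448.0,
    eps: float = 1e-10,
):
    # x: [num_tokens, hidden_dim]; each contiguous block of `group_size`
    # elements in a row gets its own float32 scale.
    num_tokens, hidden_dim = x.shape
    assert hidden_dim % group_size == 0
    groups = x.float().reshape(num_tokens, hidden_dim // group_size, group_size)

    # Scale = group amax / dst_max, with eps keeping all-zero groups finite.
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / dst_max
    q = (groups / scale).clamp(dst_min, dst_max).to(dst_dtype)
    return q.reshape(num_tokens, hidden_dim), scale.squeeze(-1)


x = torch.randn(4, 256, dtype=torch.bfloat16)
x_q, x_s = per_token_group_quant_ref(x, group_size=128)
print(x_q.shape, x_s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])
```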
 import itertools
+import os
 import time
 from functools import partial
 from pathlib import Path

 import torch
 import triton
+from sgl_kernel.test_utils import create_per_token_group_quant_test_data

 from sglang.srt.bench_utils import bench_kineto
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -19,78 +21,231 @@ from sglang.srt.utils import is_hip
 _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

-num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
-hidden_dim_range = [1536, 7168, 18432]  # For DeepSeek V3/R1
-group_size_range = [128]  # For DeepSeek V3/R1
-# TODO test int8
-dst_dtype_range = [fp8_type_]
-flags_range = [
-    dict(
-        column_major_scales=False,
-        scale_tma_aligned=False,
-        scale_ue8m0=False,
-    ),
-    dict(
-        column_major_scales=True,
-        scale_tma_aligned=False,
-        scale_ue8m0=False,
-    ),
-    dict(
-        column_major_scales=True,
-        scale_tma_aligned=True,
-        scale_ue8m0=False,
-    ),
-    dict(
-        column_major_scales=True,
-        scale_tma_aligned=True,
-        scale_ue8m0=True,
-    ),
-]
-
-configs = list(
-    itertools.product(
-        num_tokens_range,
-        hidden_dim_range,
-        group_size_range,
-        dst_dtype_range,
-        flags_range,
-    )
-)
+mode_concentrated = os.environ.get("SGLANG_BENCH_MODE", "") == "concentrated"
+
+if int(os.environ.get("SGLANG_NSYS_PROFILING", "0")):
+    # configs = [[
+    #     768,
+    #     16384,
+    #     128,
+    #     None,
+    #     fp8_type_,
+    #     dict(
+    #         column_major_scales=True,
+    #         scale_tma_aligned=True,
+    #         scale_ue8m0=True,
+    #         fuse_silu_and_mul=False,
+    #         masked_layout_mode=None,
+    #     ),
+    # ]]
+    configs = [
+        [
+            768 * 8,
+            2048,
+            128,
+            48,
+            fp8_type_,
+            dict(
+                column_major_scales=True,
+                scale_tma_aligned=True,
+                scale_ue8m0=True,
+                fuse_silu_and_mul=True,
+                # masked_layout_mode=None,
+                masked_layout_mode="balanced",
+                # masked_layout_mode="extreme",
+            ),
+        ]
+    ]
+elif mode_concentrated:
+    configs = list(
+        itertools.product(
+            [768],
+            [1536, 7168, 16384],
+            [128],
+            [None],
+            [fp8_type_],
+            [
+                dict(
+                    column_major_scales=True,
+                    scale_tma_aligned=True,
+                    scale_ue8m0=True,
+                    fuse_silu_and_mul=False,
+                    masked_layout_mode=None,
+                ),
+            ],
+        )
+    ) + list(
+        itertools.product(
+            [768 * 8],
+            [2048],
+            [128],
+            [48],
+            [fp8_type_],
+            [
+                dict(
+                    column_major_scales=True,
+                    scale_tma_aligned=True,
+                    scale_ue8m0=True,
+                    fuse_silu_and_mul=True,
+                    masked_layout_mode=None,
+                ),
+                dict(
+                    column_major_scales=True,
+                    scale_tma_aligned=True,
+                    scale_ue8m0=True,
+                    fuse_silu_and_mul=True,
+                    masked_layout_mode="balanced",
+                ),
+                dict(
+                    column_major_scales=True,
+                    scale_tma_aligned=True,
+                    scale_ue8m0=True,
+                    fuse_silu_and_mul=True,
+                    masked_layout_mode="imbalanced",
+                ),
+                dict(
+                    column_major_scales=True,
+                    scale_tma_aligned=True,
+                    scale_ue8m0=True,
+                    fuse_silu_and_mul=True,
+                    masked_layout_mode="extreme",
+                ),
+            ],
+        )
+    )
+else:
+    configs = list(
+        itertools.product(
+            [1, 4, 16, 64, 256, 768, 2048, 8192, 16384],
+            [1536, 7168, 16384],
+            [128],
+            [None],
+            [fp8_type_],
+            [
+                dict(
+                    column_major_scales=False,
+                    scale_tma_aligned=False,
+                    scale_ue8m0=False,
+                    fuse_silu_and_mul=False,
+                    masked_layout_mode=None,
+                ),
+                dict(
+                    column_major_scales=True,
+                    scale_tma_aligned=False,
+                    scale_ue8m0=False,
+                    fuse_silu_and_mul=False,
+                    masked_layout_mode=None,
+                ),
+                dict(
+                    column_major_scales=True,
+                    scale_tma_aligned=True,
+                    scale_ue8m0=False,
+                    fuse_silu_and_mul=False,
+                    masked_layout_mode=None,
+                ),
+                dict(
+                    column_major_scales=True,
+                    scale_tma_aligned=True,
+                    scale_ue8m0=True,
+                    fuse_silu_and_mul=False,
+                    masked_layout_mode=None,
+                ),
+            ],
+        )
+    ) + list(
+        itertools.product(
+            [1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],
+            [2048],
+            [128],
+            [8, 16, 32, 48],
+            [fp8_type_],
+            [
+                dict(
+                    column_major_scales=True,
+                    scale_tma_aligned=True,
+                    scale_ue8m0=True,
+                    fuse_silu_and_mul=True,
+                    masked_layout_mode=None,
+                ),
+                dict(
+                    column_major_scales=True,
+                    scale_tma_aligned=True,
+                    scale_ue8m0=True,
+                    fuse_silu_and_mul=True,
+                    masked_layout_mode="balanced",
+                ),
+                dict(
+                    column_major_scales=True,
+                    scale_tma_aligned=True,
+                    scale_ue8m0=True,
+                    fuse_silu_and_mul=True,
+                    masked_layout_mode="imbalanced",
+                ),
+                dict(
+                    column_major_scales=True,
+                    scale_tma_aligned=True,
+                    scale_ue8m0=True,
+                    fuse_silu_and_mul=True,
+                    masked_layout_mode="extreme",
+                ),
+            ],
+        )
+    )


 @triton.testing.perf_report(
     triton.testing.Benchmark(
-        x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"],
+        x_names=[
+            "num_tokens",
+            "hidden_dim",
+            "group_size",
+            "num_ranks",
+            "dst_dtype",
+            "flags",
+        ],
         x_vals=configs,
         line_arg="provider",
         line_vals=["triton", "sglang"],
-        line_names=["Triton", "SGL Kernel"],
+        # Triton has multiple kernels and we only report the time for the core one
+        line_names=["Triton (Inaccurate)", "SGL Kernel"],
         styles=[("blue", "-"), ("green", "-")],
         ylabel="us",
         plot_name="per-token-group-quant-8bit-performance",
         args={},
     )
 )
-def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
-    if flags["scale_ue8m0"] and group_size != 128:
-        return
-
-    device = torch.device("cuda")
-
-    x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
+def benchmark(
+    num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags, provider
+):
+    print(
+        f"Testing: {num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=} {provider=}"
+    )
+    x, masked_m = create_per_token_group_quant_test_data(
+        num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
+    )

     fn, kernel_names = {
-        "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"),
+        "triton": (
+            triton_per_token_group_quant_8bit,
+            "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel",
+        ),
         "sglang": (
             sglang_per_token_group_quant_8bit,
             "per_token_group_quant_8bit_kernel",
         ),
     }[provider]

-    bench_fn = lambda: fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags)
+    bench_fn = lambda: fn(
+        x=x,
+        masked_m=masked_m,
+        group_size=group_size,
+        dst_dtype=dst_dtype,
+        **{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
+    )

-    time_s = bench_kineto(bench_fn, kernel_names=kernel_names)
+    time_s = bench_kineto(
+        bench_fn, kernel_names=kernel_names, num_tests=300 if mode_concentrated else 30
+    )
     return time_s * 1e6
...
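Note: the benchmark now builds `configs` at import time from two environment variables: `SGLANG_NSYS_PROFILING=1` pins a single large fused, masked-layout case for profiling runs, `SGLANG_BENCH_MODE=concentrated` selects a small DeepSeek-V3-shaped sweep with `num_tests=300`, and anything else falls back to the full sweep of plain and fused/masked cases. A possible driver is sketched below; it is hypothetical, since the file's own `__main__` block is not part of this diff, and it relies on the standard `triton.testing` runner for `@perf_report` benchmarks.

```python
# Hypothetical driver; set SGLANG_BENCH_MODE / SGLANG_NSYS_PROFILING in the
# environment before launching, since `configs` is evaluated at import time.
# Reported values are microseconds (the benchmark returns time_s * 1e6).
if __name__ == "__main__":
    benchmark.run(print_data=True)
```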
@@ -121,14 +121,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);

   m.def(
-      "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()");
-  m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
-
-  m.def(
-      "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float int8_min, float int8_max) -> ()");
-  m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
+      "sgl_per_token_group_quant_8bit(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float fp8_min, float fp8_max, bool scale_ue8m0, bool fuse_silu_and_mul, Tensor? masked_m) -> ()");
+  m.impl("sgl_per_token_group_quant_8bit", torch::kCUDA, &sgl_per_token_group_quant_8bit);

   m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
   m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
...
@@ -207,23 +207,17 @@ torch::Tensor fp8_blockwise_scaled_mm(
     const torch::Dtype& out_dtype);

 void scaled_fp4_quant(
     torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);

-void sgl_per_token_group_quant_fp8(
+void sgl_per_token_group_quant_8bit(
     at::Tensor input,
     at::Tensor output_q,
     at::Tensor output_s,
     int64_t group_size,
     double eps,
-    double fp8_min,
-    double fp8_max,
-    bool scale_ue8m0);
-
-void sgl_per_token_group_quant_int8(
-    at::Tensor input,
-    at::Tensor output_q,
-    at::Tensor output_s,
-    int64_t group_size,
-    double eps,
-    double int8_min,
-    double int8_max);
+    double min_8bit,
+    double max_8bit,
+    bool scale_ue8m0,
+    bool fuse_silu_and_mul,
+    const std::optional<torch::Tensor>& masked_m);

 void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
 void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);

 void bmm_fp8(
...
@@ -55,8 +55,7 @@ from sgl_kernel.gemm import (
     scaled_fp4_grouped_quant,
     scaled_fp4_quant,
     sgl_per_tensor_quant_fp8,
-    sgl_per_token_group_quant_fp8,
-    sgl_per_token_group_quant_int8,
+    sgl_per_token_group_quant_8bit,
     sgl_per_token_quant_fp8,
     shuffle_rows,
     silu_and_mul_scaled_fp4_grouped_quant,
...
@@ -98,7 +98,7 @@ def dsv3_fused_a_gemm(
     return output


-def sgl_per_token_group_quant_fp8(
+def sgl_per_token_group_quant_8bit(
     input: torch.Tensor,
     output_q: torch.Tensor,
     output_s: torch.Tensor,
@@ -106,24 +106,21 @@ def sgl_per_token_group_quant_fp8(
     eps: float,
     fp8_min: float,
     fp8_max: float,
-    scale_ue8m0: bool,
+    scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
 ) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
-        input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
-    )
-
-
-def sgl_per_token_group_quant_int8(
-    input: torch.Tensor,
-    output_q: torch.Tensor,
-    output_s: torch.Tensor,
-    group_size: int,
-    eps: float,
-    int8_min: float,
-    int8_max: float,
-) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
-        input, output_q, output_s, group_size, eps, int8_min, int8_max
-    )
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
+        input,
+        output_q,
+        output_s,
+        group_size,
+        eps,
+        fp8_min,
+        fp8_max,
+        scale_ue8m0,
+        fuse_silu_and_mul,
+        masked_m,
+    )
...
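Note: a hypothetical direct call of the unified wrapper for a plain int8 case, to show why the legacy seven-argument call sites (such as `sglang_per_token_group_quant_int8` above) keep working: `scale_ue8m0`, `fuse_silu_and_mul` and `masked_m` now default to `False` / `False` / `None`. The row-major scale layout and the int8 bounds used here are assumptions of this sketch, not values taken from this diff.

```python
import torch
from sgl_kernel import sgl_per_token_group_quant_8bit

num_tokens, hidden_dim, group_size = 16, 1024, 128
x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)

# Caller-allocated outputs: quantized codes plus one float32 scale per group.
x_q = torch.empty(num_tokens, hidden_dim, device="cuda", dtype=torch.int8)
x_s = torch.empty(
    num_tokens, hidden_dim // group_size, device="cuda", dtype=torch.float32
)

iinfo = torch.iinfo(torch.int8)  # assumed bounds for this illustration
sgl_per_token_group_quant_8bit(
    x, x_q, x_s, group_size, 1e-10, float(iinfo.min), float(iinfo.max)
)
```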
import torch


def create_per_token_group_quant_test_data(num_tokens, hidden_dim, num_ranks, flags):
    device = torch.device("cuda")
    dtype = torch.bfloat16

    seed = num_tokens * 10000 + hidden_dim
    gen_cpu = torch.Generator(device="cpu")
    gen_cpu.manual_seed(seed)
    gen_cuda = torch.Generator(device="cuda")
    gen_cuda.manual_seed(seed)

    if flags["fuse_silu_and_mul"]:
        effective_hidden_dim = hidden_dim * 2
    else:
        effective_hidden_dim = hidden_dim
    del hidden_dim

    if (masked_layout_mode := flags["masked_layout_mode"]) is not None:
        num_max_dispatch_tokens_per_rank = 768
        num_global_experts = 288
        num_local_experts, remainder = divmod(num_global_experts, num_ranks)
        assert remainder == 0

        # mimic DeepEP low_latency_dispatch output
        x = torch.randn(
            num_local_experts,
            num_max_dispatch_tokens_per_rank * num_ranks,
            effective_hidden_dim,
            device=device,
            dtype=dtype,
            generator=gen_cuda,
        )

        if masked_layout_mode == "balanced":
            masked_m = _compute_balanced_split(num_tokens, num_local_experts)
        elif masked_layout_mode == "imbalanced":
            masked_m = _compute_imbalanced_split(
                num_tokens, num_local_experts, gen_cpu=gen_cpu
            )
        elif masked_layout_mode == "extreme":
            masked_m = torch.tensor(
                [num_tokens] + [0] * (num_local_experts - 1), dtype=torch.int
            )
        else:
            raise NotImplementedError
        print(f"{masked_layout_mode=} {masked_m=} {x.shape=}")
        masked_m = masked_m.to(device)

        return x, masked_m
    else:
        x = torch.randn(
            num_tokens,
            effective_hidden_dim,
            device=device,
            dtype=dtype,
            generator=gen_cuda,
        )
        x[torch.randn(x.shape, device=device, generator=gen_cuda) < 0.001] *= 10
        return x, None


def _compute_balanced_split(total: int, arr_len: int):
    base = total // arr_len
    remainder = total % arr_len
    ans = [base + 1 if i < remainder else base for i in range(arr_len)]
    assert sum(ans) == total
    return torch.tensor(ans, dtype=torch.int)


def _compute_imbalanced_split(
    total: int, arr_len: int, gen_cpu, dtype=torch.int
) -> torch.Tensor:
    # can use `rand ** 2`, `rand ** 3`, etc, to change how imbalanced it is
    noise_raw = torch.rand(arr_len, generator=gen_cpu) ** 3
    noise = noise_raw / noise_raw.sum()
    ans = (noise * total).round().to(dtype)

    diff = total - ans.sum().item()
    while diff != 0:
        idx = torch.randint(0, arr_len, (1,), generator=gen_cpu).item()
        if diff > 0:
            ans[idx] += 1
            diff -= 1
        elif diff < 0 and ans[idx] > 0:
            ans[idx] -= 1
            diff += 1

    assert sum(ans) == total
    return ans


def assert_all_close_or_tiny_diff(a: torch.Tensor, b: torch.Tensor):
    assert (a.shape == b.shape) and (
        a.dtype == b.dtype
    ), f"{a.shape=} {b.shape=} {a.dtype=} {b.dtype=}"

    numel = a.numel()
    if a.dtype == torch.float8_e4m3fn:
        # Compare the raw 8-bit codes of the two fp8 tensors.
        a_u8 = a.view(torch.uint8)
        b_u8 = b.view(torch.uint8)
        diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()
        count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item()
        count_tiny_diff = (diff_u8 == 1).sum().item()
        count_large_diff = (diff_u8 >= 2).sum().item()
    elif a.dtype == torch.int8:
        diff = (a.to(torch.int16) - b.to(torch.int16)).abs()
        count_diff_sign = ((a >= 0) & (b < 0)).sum().item()
        count_tiny_diff = (diff == 1).sum().item()
        count_large_diff = (diff >= 2).sum().item()
    else:
        raise NotImplementedError

    assert (
        (count_diff_sign == 0)
        and (count_large_diff == 0)
        and (
            (count_tiny_diff / numel < 0.005)
            or ((count_tiny_diff / numel < 0.04) and (numel <= 4096))
        )
    ), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=} {a=} {b=}"
 import itertools
+import os
+import time
+from pathlib import Path

 import pytest
 import torch
+from sgl_kernel.test_utils import (
+    assert_all_close_or_tiny_diff,
+    create_per_token_group_quant_test_data,
+)

-from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
 )
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
-from sglang.srt.layers.quantization.utils import assert_fp8_all_close
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import get_bool_env_var, is_hip

 _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

-@pytest.mark.parametrize(
-    "num_tokens, hidden_dim, group_size, dst_dtype, flags",
-    list(
-        itertools.product(
-            [127, 128, 512, 1024, 4096, 8192],  # num_tokens
-            [256, 512, 1024, 2048, 4096],  # hidden_dim
-            [8, 16, 32, 64, 128],  # group_size
-            # TODO test int8
-            [fp8_type_],  # dtype
-            [
-                dict(
-                    column_major_scales=False,
-                    scale_tma_aligned=False,
-                    scale_ue8m0=False,
-                ),
-                dict(
-                    column_major_scales=True,
-                    scale_tma_aligned=False,
-                    scale_ue8m0=False,
-                ),
-                dict(
-                    column_major_scales=True,
-                    scale_tma_aligned=True,
-                    scale_ue8m0=False,
-                ),
-                dict(
-                    column_major_scales=True,
-                    scale_tma_aligned=True,
-                    scale_ue8m0=True,
-                ),
-            ],
-        )
-    ),
-)
+configs = list(
+    itertools.product(
+        [1, 4, 16, 64, 127, 128, 512, 1024, 4096, 8192],  # num_tokens
+        [128, 256, 384, 512, 1024, 1536, 1664, 2048, 4096, 7168, 16384],  # hidden_dim
+        [16, 32, 64, 128],  # group_size
+        [None],  # num_ranks
+        [fp8_type_, torch.int8],  # dtype
+        [
+            dict(
+                column_major_scales=False,
+                scale_tma_aligned=False,
+                scale_ue8m0=False,
+                fuse_silu_and_mul=False,
+                masked_layout_mode=None,
+            ),
+            dict(
+                column_major_scales=True,
+                scale_tma_aligned=False,
+                scale_ue8m0=False,
+                fuse_silu_and_mul=False,
+                masked_layout_mode=None,
+            ),
+            dict(
+                column_major_scales=True,
+                scale_tma_aligned=True,
+                scale_ue8m0=False,
+                fuse_silu_and_mul=False,
+                masked_layout_mode=None,
+            ),
+            dict(
+                column_major_scales=True,
+                scale_tma_aligned=True,
+                scale_ue8m0=True,
+                fuse_silu_and_mul=False,
+                masked_layout_mode=None,
+            ),
+        ],
+    )
+) + list(
+    itertools.product(
+        [1, 4, 1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],
+        # TODO support more
+        [2048],
+        [128],
+        [8, 16, 32, 48],
+        [fp8_type_],
+        [
+            dict(
+                column_major_scales=True,
+                scale_tma_aligned=True,
+                scale_ue8m0=True,
+                fuse_silu_and_mul=True,
+                masked_layout_mode=None,
+            ),
+            dict(
+                column_major_scales=True,
+                scale_tma_aligned=True,
+                scale_ue8m0=True,
+                fuse_silu_and_mul=True,
+                masked_layout_mode="balanced",
+            ),
+            dict(
+                column_major_scales=True,
+                scale_tma_aligned=True,
+                scale_ue8m0=True,
+                fuse_silu_and_mul=True,
+                masked_layout_mode="imbalanced",
+            ),
+            dict(
+                column_major_scales=True,
+                scale_tma_aligned=True,
+                scale_ue8m0=True,
+                fuse_silu_and_mul=True,
+                masked_layout_mode="extreme",
+            ),
+        ],
+    )
+)
+
+
+@pytest.mark.parametrize(
+    "num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags", configs
+)
 def test_per_token_group_quant_with_column_major(
     num_tokens,
     hidden_dim,
     group_size,
+    num_ranks,
     dst_dtype,
     flags,
 ):
-    if flags["scale_ue8m0"] and ((group_size != 128) or (hidden_dim % 512 != 0)):
-        pytest.skip()
+    print(
+        f"{num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=}"
+    )
+
+    arch_major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
+    if flags["scale_ue8m0"] and (arch_major <= 9):
+        pytest.skip("Only Blackwell needs ue8m0 fusion")
         return

-    if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL:
-        pytest.skip("scale_ue8m0 only supported on Blackwell")
+    if (flags["scale_ue8m0"] and (group_size != 128)) or (
+        (dst_dtype == torch.int8) and flags["column_major_scales"]
+    ):
+        pytest.skip()
         return

-    x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
+    x, masked_m = create_per_token_group_quant_test_data(
+        num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
+    )
+    # print("hack data!!!")
+    # x = torch.full_like(x, fill_value=100)

     execute_kwargs = dict(
         x=x,
+        masked_m=masked_m,
         group_size=group_size,
         eps=1e-10,
         dst_dtype=dst_dtype,
-        **flags,
+        **{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
     )

-    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(**execute_kwargs)
-    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(**execute_kwargs)
-
-    # torch.set_printoptions(profile="full")
-    # print(f"{x_q_triton=}")
-    # print(f"{x_s_triton=}")
-    # print(f"{x_q_sglang=}")
-    # print(f"{x_s_sglang=}")
-    # torch.set_printoptions(profile="default")
-
-    assert_fp8_all_close(x_q_triton, x_q_sglang)
-    torch.testing.assert_close(
-        x_s_triton.contiguous(),
-        x_s_sglang.contiguous(),
+    def _postprocess(x_q, x_s):
+        if masked_m is not None:
+            print(f"Mask tokens after {masked_m} to be zero")
+            for i in range(len(masked_m)):
+                x_q[i, masked_m[i] :, :] = 0
+                x_s[i, masked_m[i] :, :] = 0
+        return x_q, x_s
+
+    x_q_triton, x_s_triton = _postprocess(
+        *triton_per_token_group_quant_8bit(**execute_kwargs)
+    )
+    x_q_sglang, x_s_sglang = _postprocess(
+        *sglang_per_token_group_quant_8bit(**execute_kwargs)
+    )
+
+    try:
+        assert_all_close_or_tiny_diff(x_q_triton, x_q_sglang)
+        torch.testing.assert_close(
+            x_s_triton.contiguous(),
+            x_s_sglang.contiguous(),
@@ -91,6 +165,35 @@ def test_per_token_group_quant_with_column_major(
-        atol=1e-5,
-        msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}",
-    )
+            atol=1e-5,
+            msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}",
+        )
+    except AssertionError:
+        # torch.set_printoptions(profile="full")
+        print(
+            f"{x.shape=} {x_q_triton.shape=} {x_s_triton.shape=} {x_q_sglang.shape=} {x_s_sglang.shape=}"
+        )
+        print(f"{x=}")
+        print(f"{masked_m=}")
+        print(f"{x_q_triton=}")
+        print(f"{x_s_triton=}")
+        print(f"{x_q_sglang=}")
+        print(f"{x_s_sglang=}")
+        # torch.set_printoptions(profile="default")
+        # if (d := os.environ.get("SGLANG_DUMP_TEST_ERROR_DIR", "")) != "":
+        #     import matplotlib.pyplot as plt
+        #
+        #     base_stem = time.time()
+        #     for name, value in [
+        #         ("x_q", x_q_triton != x_q_sglang),
+        #         ("x_s", x_s_triton != x_s_sglang),
+        #     ]:
+        #         value = value.reshape((-1, value.shape[-1]))
+        #         plt.figure(figsize=(20, 20))
+        #         plt.imshow((value * 1.0).cpu().numpy())
+        #         p = Path(d) / f"{base_stem}_{name}.png"
+        #         print(f"Write diff to {p}", flush=True)
+        #         plt.savefig(p)
+        raise


 if __name__ == "__main__":
...