Unverified Commit e85cb1ce authored by fzyzcjy, committed by GitHub

Fix quant kernel test errors and benchmark wrong output speeds (#7604)

parent 55d336cb
...@@ -341,6 +341,39 @@ def create_per_token_group_quant_fp8_output_scale(
)
# TODO maybe unify int8 and fp8 code later
def per_token_group_quant_8bit(
x: torch.Tensor,
group_size: int,
dst_dtype: torch.dtype,
eps: float = 1e-10,
column_major_scales: bool = False,
scale_tma_aligned: bool = False,
scale_ue8m0: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8
if dst_dtype == torch.int8:
assert not column_major_scales
assert not scale_tma_aligned
assert not scale_ue8m0
return per_token_group_quant_int8(
x=x,
group_size=group_size,
eps=eps,
dtype=dst_dtype,
)
return per_token_group_quant_fp8(
x=x,
group_size=group_size,
eps=eps,
column_major_scales=column_major_scales,
scale_tma_aligned=scale_tma_aligned,
scale_ue8m0=scale_ue8m0,
)
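For illustration only (not part of this commit), a minimal usage sketch of the dispatcher above, assuming a CUDA (non-HIP) device, the torch.float8_e4m3fn fp8 dtype, and that with the default flags both paths return row-major float32 scales of shape (tokens, hidden_dim // group_size), as the reference implementations later in this diff do:

import torch
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_8bit

x = torch.randn(4, 7168, device="cuda", dtype=torch.bfloat16)
# fp8 path: one float32 scale per 128-wide group
x_q_fp8, x_s_fp8 = per_token_group_quant_8bit(x, group_size=128, dst_dtype=torch.float8_e4m3fn)
# int8 path: the asserts above reject column-major / TMA-aligned / ue8m0 scale layouts
x_q_int8, x_s_int8 = per_token_group_quant_8bit(x, group_size=128, dst_dtype=torch.int8)
print(x_s_fp8.shape)  # expected (4, 7168 // 128) == (4, 56) under the assumptions above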
def sglang_per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
...@@ -372,6 +405,40 @@ def sglang_per_token_group_quant_fp8(
return x_q, x_s
# TODO maybe unify int8 and fp8 code later
def sglang_per_token_group_quant_8bit(
x: torch.Tensor,
group_size: int,
dst_dtype: torch.dtype,
eps: float = 1e-10,
column_major_scales: bool = False,
scale_tma_aligned: bool = False,
scale_ue8m0: bool = False,
):
from sglang.srt.layers.quantization.int8_kernel import (
sglang_per_token_group_quant_int8,
)
if dst_dtype == torch.int8:
assert not column_major_scales
assert not scale_tma_aligned
return sglang_per_token_group_quant_int8(
x=x,
group_size=group_size,
eps=eps,
dtype=dst_dtype,
)
return sglang_per_token_group_quant_fp8(
x=x,
group_size=group_size,
eps=eps,
column_major_scales=column_major_scales,
scale_tma_aligned=scale_tma_aligned,
scale_ue8m0=scale_ue8m0,
)
def sglang_per_token_quant_fp8(
x: torch.Tensor,
dtype: torch.dtype = fp8_dtype,
...
...@@ -176,6 +176,27 @@ def replace_parameter(
mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
def assert_fp8_all_close(a: torch.Tensor, b: torch.Tensor):
assert a.shape == b.shape
assert a.dtype == b.dtype == torch.float8_e4m3fn
a_u8 = a.view(torch.uint8)
b_u8 = b.view(torch.uint8)
diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()
numel = a.numel()
count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item()
count_tiny_diff = (diff_u8 >= 1).sum().item()
count_large_diff = (diff_u8 >= 2).sum().item()
assert (
(count_diff_sign == 0)
and (count_tiny_diff / numel < 0.005)
and (count_large_diff == 0)
), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=}"
# Match dynamic rules with module name (prefix) and override quantize
# config if module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str):
...
import itertools
from typing import Tuple
import time
from functools import partial
from pathlib import Path
import torch
import triton
import triton.language as tl
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8
from sglang.srt.bench_utils import bench_kineto
from sglang.srt.layers.quantization.fp8_kernel import (
create_per_token_group_quant_fp8_output_scale,
)
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
)
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
from sglang.srt.utils import is_hip
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1
@triton.jit
def _per_token_group_quant_8bit(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
# Stride of input
y_stride,
# Columns of input
N,
# Avoid dividing by zero
eps,
# Information for 8bit data type (int8 or fp8_type_)
max_8bit,
min_8bit,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group quantization on a
tensor.
This function converts the tensor values into 8bit values.
"""
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
y_ptr += g_id * y_stride
y_q_ptr += g_id * y_stride
y_s_ptr += g_id
cols = tl.arange(0, BLOCK) # N <= BLOCK
mask = cols < N
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / max_8bit
y_q = tl.clamp(y / y_s, min_8bit, max_8bit).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
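A worked example (illustrative plain PyTorch, not part of this commit) of the per-group scaling rule the kernel above implements: scale = max(|y|) / max_8bit, then y_q = clamp(y / scale, min_8bit, max_8bit):

import torch

y = torch.tensor([0.5, -3.0, 1.0])
max_8bit = torch.finfo(torch.float8_e4m3fn).max  # 448.0
scale = y.abs().max().clamp_min(1e-10) / max_8bit
y_q = (y / scale).clamp(-max_8bit, max_8bit)
print(scale.item())  # ≈ 0.0066964 (= 3.0 / 448)
print(y_q)           # tensor([  74.6667, -448.0000,  149.3333])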
def triton_per_token_group_quant_8bit(
x: torch.Tensor,
group_size: int,
dst_dtype: torch.dtype,
eps: float = 1e-10,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tenosr with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
"""
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
if dst_dtype == torch.int8:
iinfo = torch.iinfo(dst_dtype)
max_8bit = iinfo.max
min_8bit = iinfo.min
else:
finfo = torch.finfo(dst_dtype)
max_8bit = finfo.max
min_8bit = finfo.min
x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
M = x.numel() // group_size
N = group_size
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
_per_token_group_quant_8bit[(M,)](
x,
x_q,
x_s,
group_size,
N,
eps,
max_8bit,
min_8bit,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
return x_q, x_s
def sglang_per_token_group_quant_8bit(
x: torch.Tensor,
group_size: int,
dst_dtype: torch.dtype,
eps: float = 1e-10,
):
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)
if dst_dtype == torch.int8:
iinfo = torch.iinfo(dst_dtype)
int8_max = iinfo.max
int8_min = iinfo.min
sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
else:
f8_info = torch.finfo(dst_dtype)
fp8_max = f8_info.max
fp8_min = f8_info.min
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
return x_q, x_s
def calculate_diff(batch_size, seq_len, group_size, dst_dtype):
device = torch.device("cuda")
hidden_dim = 7168
x = torch.randn(
batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
)
x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(
x.clone(), group_size, dst_dtype
)
x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(
x.clone(), group_size, dst_dtype
)
if torch.allclose(
x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
) and torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5):
print(f"✅ {dst_dtype} implementations match")
else:
print("❌ Implementations differ")
batch_size_range = [1, 2, 4, 8, 16, 32, 64]
seq_len_range = [64, 128, 256, 512, 1024, 2048]
group_size_range = [128] # For DeepSeek V3/R1
dst_dtype_range = [torch.int8, fp8_type_]
# TODO test int8
dst_dtype_range = [fp8_type_]
flags_range = [
dict(
column_major_scales=False,
scale_tma_aligned=False,
scale_ue8m0=False,
),
dict(
column_major_scales=True,
scale_tma_aligned=False,
scale_ue8m0=False,
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=False,
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
),
]
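For context (a sketch under assumptions, not part of the diff): the scale layouts these flag combinations select, following the allocation logic of the reference implementations elsewhere in this diff; the num_tokens/hidden_dim values below are arbitrary, and scale_ue8m0 (which the test below exercises only with Blackwell DeepGEMM) is not modeled here:

import torch

num_tokens, hidden_dim, group_size = 6, 512, 128
groups = hidden_dim // group_size  # 4
# column_major_scales=False: plain row-major (num_tokens, groups) scales
s_row = torch.empty(num_tokens, groups)
# column_major_scales=True: allocated (groups, num_tokens), then viewed transposed
s_col = torch.empty(groups, num_tokens).permute(-1, -2)
# scale_tma_aligned=True: token dim additionally padded to a multiple of 4 (4 * sizeof(float))
aligned = (num_tokens + 3) // 4 * 4  # 8
s_tma = torch.empty(groups, aligned).permute(-1, -2)[:num_tokens, :]
print(s_row.shape, s_col.stride(), s_tma.stride())  # torch.Size([6, 4]) (1, 6) (1, 8)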
configs = list(
itertools.product(
batch_size_range, seq_len_range, group_size_range, dst_dtype_range
num_tokens_range,
hidden_dim_range,
group_size_range,
dst_dtype_range,
flags_range,
)
)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len", "group_size", "dst_dtype"],
x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"],
x_vals=configs,
line_arg="provider",
line_vals=["triton", "sglang"],
...@@ -194,29 +73,26 @@ configs = list(
args={},
)
)
def benchmark(batch_size, seq_len, group_size, dst_dtype, provider):
device = torch.device("cuda")
hidden_dim = 7168
x = torch.randn(
batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
)
quantiles = [0.5, 0.2, 0.8]
if provider == "triton":
fn = lambda: triton_per_token_group_quant_8bit(x, group_size, dst_dtype)
elif provider == "sglang":
fn = lambda: sglang_per_token_group_quant_8bit(x, group_size, dst_dtype)
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
if flags["scale_ue8m0"] and group_size != 128:
return
device = torch.device("cuda")
x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
fn, kernel_names = {
"triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"),
"sglang": (
sglang_per_token_group_quant_8bit,
"per_token_group_quant_8bit_kernel",
),
}[provider]
bench_fn = lambda: fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags)
time_s = bench_kineto(bench_fn, kernel_names=kernel_names)
return time_s * 1e6
if __name__ == "__main__":
calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=torch.int8)
calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=fp8_type_)
benchmark.run(print_data=True)
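The new benchmark times the named kernel from a profiler trace instead of timing a Python lambda with triton.testing.do_bench. As a rough sketch (an assumption about the helper, not the actual sglang.srt.bench_utils.bench_kineto source; the hypothetical time_cuda_kernel below only mirrors how it is called here), a Kineto-based timer profiles the callable with torch.profiler and averages the device time of events matching the kernel name, returning seconds per call, which the benchmark converts to microseconds via time_s * 1e6:

import torch
from torch.profiler import ProfilerActivity, profile

def time_cuda_kernel(fn, kernel_name, warmup=10, iters=30):
    # Hypothetical stand-in for bench_kineto (assumption, for illustration only).
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()
    total_us = sum(
        evt.cuda_time_total  # reported in microseconds
        for evt in prof.key_averages()
        if kernel_name in evt.key
    )
    return total_us / iters * 1e-6  # seconds per call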
import itertools
from typing import Tuple
import pytest
import torch
import triton
import triton.language as tl
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
)
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
from sglang.srt.layers.quantization.utils import assert_fp8_all_close
from sglang.srt.utils import is_hip
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
@triton.jit
def _per_token_group_quant_fp8(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
# Stride of input
y_stride,
# Columns of input
N,
# Avoid dividing by zero
eps,
# Information for float8
fp8_min,
fp8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group quantization on a
tensor.
This function converts the tensor values into float8 values.
"""
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
y_ptr += g_id * y_stride
y_q_ptr += g_id * y_stride
y_s_ptr += g_id
cols = tl.arange(0, BLOCK) # N <= BLOCK
mask = cols < N
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_s_inv = 1.0 / y_s
y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
@triton.jit
def _per_token_group_quant_fp8_colmajor(
# Pointers to inputs and output
y_ptr,
y_q_ptr,
y_s_ptr,
group_size,
# Num columns of y
y_num_columns,
# Stride from one column to the next of y_s
y_s_col_stride,
# Avoid dividing by zero
eps,
# Information for float8
fp8_min,
fp8_max,
# Meta-parameters
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into float8 values.
"""
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
y_ptr += g_id * group_size
y_q_ptr += g_id * group_size
# Convert g_id, the flattened block coordinate, to 2D so we can index
# into the output y_scales matrix
blocks_per_row = y_num_columns // group_size
scale_col = g_id % blocks_per_row
scale_row = g_id // blocks_per_row
y_s_ptr += scale_col * y_s_col_stride + scale_row
cols = tl.arange(0, BLOCK) # group_size <= BLOCK
mask = cols < group_size
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
def triton_per_token_group_quant_8bit(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = fp8_type_,
column_major_scales: bool = False,
scale_tma_aligned: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tenosr with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
"""
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
if dtype == torch.int8:
finfo = torch.iinfo(dtype)
else:
finfo = torch.finfo(dtype)
fp8_max = finfo.max
if _is_hip:
if dtype == torch.int8:
fp8_max = 127.0
else:
fp8_max = 224.0
fp8_min = -fp8_max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
if column_major_scales:
if scale_tma_aligned:
# aligned to 4 * sizeof(float)
aligned_size = (x.shape[-2] + 3) // 4 * 4
x_s = torch.empty(
x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
device=x.device,
dtype=torch.float32,
).permute(-1, -2)[: x.shape[-2], :]
else:
x_s = torch.empty(
(x.shape[-1] // group_size,) + x.shape[:-1],
device=x.device,
dtype=torch.float32,
).permute(-1, -2)
else:
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
if column_major_scales:
_per_token_group_quant_fp8_colmajor[(M,)](
x,
x_q,
x_s,
group_size,
x.shape[1],
x_s.stride(1),
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
else:
_per_token_group_quant_fp8[(M,)](
x,
x_q,
x_s,
group_size,
N,
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
return x_q, x_s
def sglang_per_token_group_quant_8bit(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: torch.dtype = fp8_type_,
column_major_scales: bool = False,
scale_tma_aligned: bool = False,
):
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
if column_major_scales:
if scale_tma_aligned:
# aligned to 4 * sizeof(float)
aligned_size = (x.shape[-2] + 3) // 4 * 4
x_s = torch.empty(
x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
device=x.device,
dtype=torch.float32,
).permute(-1, -2)[: x.shape[-2], :]
else:
x_s = torch.empty(
(x.shape[-1] // group_size,) + x.shape[:-1],
device=x.device,
dtype=torch.float32,
).permute(-1, -2)
else:
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)
if dtype == torch.int8:
iinfo = torch.iinfo(dtype)
int8_max = iinfo.max
int8_min = iinfo.min
sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
else:
f8_info = torch.finfo(dtype)
fp8_max = f8_info.max
fp8_min = f8_info.min
scale_ue8m0 = False # TODO also test true
sgl_per_token_group_quant_fp8(
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
)
return x_q, x_s
@pytest.mark.parametrize(
"num_tokens, hidden_dim, group_size, dst_dtype, column_major_scales, scale_tma_aligned",
"num_tokens, hidden_dim, group_size, dst_dtype, flags",
list(
itertools.product(
[127, 128, 512, 1024, 4096, 8192], # num_tokens
[256, 512, 1024, 2048, 4096], # hidden_dim
[8, 16, 32, 64, 128], # group_size
[torch.int8, fp8_type_], # dtype
# TODO test int8
[fp8_type_], # dtype
[False, True], # column_major_scales
[False, True], # scale_tma_aligned
[
dict(
column_major_scales=False,
scale_tma_aligned=False,
scale_ue8m0=False,
),
dict(
column_major_scales=True,
scale_tma_aligned=False,
scale_ue8m0=False,
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=False,
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
),
],
)
),
)
...@@ -281,37 +54,42 @@ def test_per_token_group_quant_with_column_major(
hidden_dim,
group_size,
dst_dtype,
column_major_scales,
scale_tma_aligned,
flags,
):
if not column_major_scales and scale_tma_aligned:
return
x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.float16)
x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(
x,
group_size,
eps=1e-10,
dtype=dst_dtype,
column_major_scales=column_major_scales,
scale_tma_aligned=scale_tma_aligned,
)
x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(
x,
group_size,
eps=1e-10,
dtype=dst_dtype,
column_major_scales=column_major_scales,
scale_tma_aligned=scale_tma_aligned,
)
torch.testing.assert_close(
x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
)
torch.testing.assert_close(
x_s_triton.contiguous(), x_s_sglang.contiguous(), rtol=1e-3, atol=1e-5
)
if flags["scale_ue8m0"] and ((group_size != 128) or (hidden_dim % 512 != 0)):
pytest.skip()
return
if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL:
pytest.skip("scale_ue8m0 only supported on Blackwell")
return
x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
execute_kwargs = dict(
x=x,
group_size=group_size,
eps=1e-10,
dst_dtype=dst_dtype,
**flags,
)
x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(**execute_kwargs)
x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(**execute_kwargs)
# torch.set_printoptions(profile="full")
# print(f"{x_q_triton=}")
# print(f"{x_s_triton=}")
# print(f"{x_q_sglang=}")
# print(f"{x_s_sglang=}")
# torch.set_printoptions(profile="default")
assert_fp8_all_close(x_q_triton, x_q_sglang)
torch.testing.assert_close(
x_s_triton.contiguous(),
x_s_sglang.contiguous(),
rtol=1e-3,
atol=1e-5,
msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}",
)
...