Unverified commit b58c3c28, authored by fzyzcjy, committed by GitHub

Support ue8m0 for triton quant kernel (#7603)

parent df906455
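
Note on the format: UE8M0 stores a scale factor as an unsigned 8-bit exponent with no sign bit and no mantissa, so a quantization scale must be rounded to a power of two before it can be represented. A minimal sketch of that rounding (illustration only, not part of this commit; the 127 exponent bias is an assumption based on the usual E8M0 convention):

import math

def round_scale_to_ue8m0(scale: float) -> tuple[float, int]:
    # Round the scale up to the next power of two and report the biased exponent byte.
    exp = math.ceil(math.log2(abs(scale)))
    return 2.0 ** exp, exp + 127

print(round_scale_to_ue8m0(0.013))  # (0.015625, 121): 0.013 rounds up to 2**-6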
@@ -173,6 +173,7 @@ def _per_token_group_quant_fp8_colmajor(
     fp8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
+    SCALE_UE8M0: tl.constexpr,
 ):
     """A Triton-accelerated function to perform per-token-group
     quantization on a tensor.
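
Because SCALE_UE8M0 is declared tl.constexpr, Triton treats it as a compile-time value and specializes the kernel per value, so the added branch costs nothing when the flag is off. A standalone toy kernel showing the same pattern (illustration only, requires a CUDA device; all names here are made up):

import torch
import triton
import triton.language as tl

@triton.jit
def _scale_demo(x_ptr, out_ptr, BLOCK: tl.constexpr, ROUND_POW2: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    s = tl.load(x_ptr + offs)
    if ROUND_POW2:  # resolved at compile time; the dead branch is eliminated
        s = tl.exp2(tl.ceil(tl.log2(tl.abs(s))))
    tl.store(out_ptr + offs, s)

x = torch.rand(8, device="cuda") + 0.01
out = torch.empty_like(x)
_scale_demo[(1,)](x, out, BLOCK=8, ROUND_POW2=True)  # one compiled variant per constexpr value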
@@ -197,6 +198,8 @@ def _per_token_group_quant_fp8_colmajor(
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
     y_s = _absmax / fp8_max
+    if SCALE_UE8M0:
+        y_s = tl.exp2(tl.ceil(tl.log2(tl.abs(y_s))))
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
     tl.store(y_q_ptr + cols, y_q, mask=mask)
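
The two added kernel lines round each group scale y_s up to the nearest power of two so it is exactly representable in UE8M0; because the scale only grows, y / y_s still lands inside [fp8_min, fp8_max] before the clamp. The same rounding in plain PyTorch, as a sketch for checking the kernel's behavior (not part of the diff):

import torch

y_s = torch.tensor([0.0009, 0.013, 0.25])                  # per-group scales, absmax / fp8_max
y_s_ue8m0 = torch.exp2(torch.ceil(torch.log2(y_s.abs())))  # round each scale up to a power of two
print(y_s_ue8m0)                                           # tensor([0.0010, 0.0156, 0.2500])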
@@ -209,6 +212,7 @@ def per_token_group_quant_fp8(
     eps: float = 1e-10,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
@@ -229,30 +233,18 @@ def per_token_group_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"

     x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
-    M = x.numel() // group_size
-    N = group_size
-    if column_major_scales:
-        if scale_tma_aligned:
-            # aligned to 4 * sizeof(float)
-            aligned_size = (x.shape[-2] + 3) // 4 * 4
-            x_s = torch.empty(
-                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)[: x.shape[-2], :]
-        else:
-            x_s = torch.empty(
-                (x.shape[-1] // group_size,) + x.shape[:-1],
-                device=x.device,
-                dtype=torch.float32,
-            ).permute(-1, -2)
-    else:
-        x_s = torch.empty(
-            x.shape[:-1] + (x.shape[-1] // group_size,),
-            device=x.device,
-            dtype=torch.float32,
-        )
+    x_s = create_per_token_group_quant_fp8_output_scale(
+        x_shape=x.shape,
+        device=x.device,
+        group_size=group_size,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+        scale_ue8m0=False,
+    )

+    M = x.numel() // group_size
+    N = group_size
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
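
For concreteness, a worked shape example for the scale buffers the new helper returns (illustration only; the numbers assume x of shape (16, 512) with group_size=128, and the int32 buffer in the ue8m0 branch presumably packs four 8-bit exponents per element):

import torch

M, K, group_size = 16, 512, 128
# Column-major, TMA-aligned fp32 scales: a (K // group_size, aligned_M) buffer viewed as (M, K // group_size).
aligned_M = (M + 3) // 4 * 4                                          # 16
fp32_scales = torch.empty((K // group_size, aligned_M), dtype=torch.float32)
fp32_scales = fp32_scales.permute(-1, -2)[:M, :]
# ue8m0 scales: K // 128 = 4 exponent bytes per row, stored as one int32 per row.
aligned_k = (K // group_size + 3) // 4 * 4                            # 4
packed = torch.zeros((aligned_k // 4, aligned_M), dtype=torch.int32).transpose(0, 1)[:M, :]
print(fp32_scales.shape, packed.shape)  # torch.Size([16, 4]) torch.Size([16, 1])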
@@ -271,8 +263,10 @@ def per_token_group_quant_fp8(
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
+            SCALE_UE8M0=scale_ue8m0,
         )
     else:
+        assert not scale_ue8m0
         _per_token_group_quant_fp8[(M,)](
             x,
             x_q,
@@ -287,57 +281,93 @@ def per_token_group_quant_fp8(
             num_stages=num_stages,
         )

+    if scale_ue8m0:
+        from deep_gemm.utils.layout import transform_sf_into_required_layout
+
+        assert group_size == 128
+        x_s = transform_sf_into_required_layout(
+            x_s,
+            num_groups=None,
+            mn=x_q.shape[0],
+            k=x_q.shape[1],
+            recipe=(1, group_size, group_size),
+            is_sfa=True,
+        )
+
     return x_q, x_s


-def sglang_per_token_group_quant_fp8(
-    x: torch.Tensor,
-    group_size: int,
-    eps: float = 1e-10,
-    column_major_scales: bool = False,
-    scale_tma_aligned: bool = False,
-    scale_ue8m0: bool = False,
+def create_per_token_group_quant_fp8_output_scale(
+    x_shape,
+    device,
+    group_size,
+    column_major_scales: bool,
+    scale_tma_aligned: bool,
+    scale_ue8m0: bool,
 ):
-    assert (
-        x.shape[-1] % group_size == 0
-    ), "the last dimension of `x` cannot be divisible by `group_size`"
-    assert x.is_contiguous(), "`x` is not contiguous"
-
-    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
     if scale_ue8m0:
         assert column_major_scales and scale_tma_aligned
-        x_q_mn, x_q_k = x.shape
+        x_q_mn, x_q_k = x_shape
         x_s_mn, x_s_k = x_q_mn, x_q_k // 128
         aligned_mn = align(x_s_mn, 4)
         aligned_k = align(x_s_k, 4)
         # TODO(FIXME): Fix cuda kernel and recover here to empty.
-        x_s = torch.zeros(
+        return torch.zeros(
             (aligned_k // 4, aligned_mn),
-            device=x.device,
+            device=device,
             dtype=torch.int,
         ).transpose(0, 1)[:x_s_mn, :]
     elif column_major_scales:
         if scale_tma_aligned:
             # TODO extract "align" function
             # aligned to 4 * sizeof(float)
-            aligned_size = (x.shape[-2] + 3) // 4 * 4
-            x_s = torch.empty(
-                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
-                device=x.device,
+            aligned_size = (x_shape[-2] + 3) // 4 * 4
+            return torch.empty(
+                x_shape[:-2] + (x_shape[-1] // group_size, aligned_size),
+                device=device,
                 dtype=torch.float32,
-            ).permute(-1, -2)[: x.shape[-2], :]
+            ).permute(-1, -2)[: x_shape[-2], :]
         else:
-            x_s = torch.empty(
-                (x.shape[-1] // group_size,) + x.shape[:-1],
-                device=x.device,
+            return torch.empty(
+                (x_shape[-1] // group_size,) + x_shape[:-1],
+                device=device,
                 dtype=torch.float32,
             ).permute(-1, -2)
     else:
-        x_s = torch.empty(
-            x.shape[:-1] + (x.shape[-1] // group_size,),
-            device=x.device,
+        return torch.empty(
+            x_shape[:-1] + (x_shape[-1] // group_size,),
+            device=device,
             dtype=torch.float32,
         )

+
+def sglang_per_token_group_quant_fp8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
+):
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` cannot be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+    if scale_ue8m0:
+        # TODO: handle this case by fixing the (token=4, dim=256, group_size=128) UT case
+        assert x.shape[-1] % (group_size * 4) == 0
+
+    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    x_s = create_per_token_group_quant_fp8_output_scale(
+        x_shape=x.shape,
+        device=x.device,
+        group_size=group_size,
+        column_major_scales=column_major_scales,
+        scale_tma_aligned=scale_tma_aligned,
+        scale_ue8m0=scale_ue8m0,
+    )
+
     if x.shape[0] > 0:
         sgl_per_token_group_quant_fp8(
             x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
...
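
Two closing notes, sketched under assumptions since neither piece appears in this hunk. First, when scale_ue8m0 is set, the Triton path produces fp32 scales and then hands them to deep_gemm's transform_sf_into_required_layout (imported above) to obtain the packed layout. Second, the `align` helper used by the ue8m0 branch is not shown here; a hypothetical reconstruction that mirrors the inline "(x + 3) // 4 * 4" pattern:

def align(value: int, alignment: int) -> int:
    # Hypothetical helper, not from this diff: round up to the next multiple of `alignment`.
    return (value + alignment - 1) // alignment * alignment

assert align(5, 4) == 8 and align(8, 4) == 8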