"examples/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "9078c0b99e642943c193ee22b266f47fc21eda31"
Unverified commit b1b3f0b3, authored by fzyzcjy, committed by GitHub

Partially unify triton per token group quant kernels (#9485)

parent 34e5e11f
@@ -113,7 +113,7 @@ if supports_custom_op():
 @triton.jit
-def _per_token_group_quant_fp8(
+def _per_token_group_quant_8bit(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
@@ -125,8 +125,8 @@ def _per_token_group_quant_fp8(
     # Avoid to divide zero
     eps,
     # Information for float8
-    fp8_min,
-    fp8_max,
+    bit8_min,
+    bit8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
 ):
@@ -147,16 +147,16 @@ def _per_token_group_quant_fp8(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    y_s = _absmax / bit8_max
     y_s_inv = 1.0 / y_s
-    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_q = tl.clamp(y * y_s_inv, bit8_min, bit8_max).to(y_q_ptr.dtype.element_ty)
     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)


 @triton.jit
-def _per_token_group_quant_fp8_colmajor(
+def _per_token_group_quant_8bit_colmajor(
     # Pointers to inputs and output
     y_ptr,
     y_q_ptr,
@@ -169,8 +169,8 @@ def _per_token_group_quant_fp8_colmajor(
     # Avoid to divide zero
     eps,
     # Information for float8
-    fp8_min,
-    fp8_max,
+    bit8_min,
+    bit8_max,
     # Meta-parameters
     BLOCK: tl.constexpr,
     SCALE_UE8M0: tl.constexpr,
@@ -197,19 +197,20 @@ def _per_token_group_quant_fp8_colmajor(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    y_s = _absmax / bit8_max
     if SCALE_UE8M0:
         y_s = tl.exp2(tl.ceil(tl.log2(tl.abs(y_s))))
-    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+    y_q = tl.clamp(y / y_s, bit8_min, bit8_max).to(y_q_ptr.dtype.element_ty)
     tl.store(y_q_ptr + cols, y_q, mask=mask)
     tl.store(y_s_ptr, y_s)


-def per_token_group_quant_fp8(
+def _per_token_group_quant_8bit_raw(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
+    dtype: torch.dtype = fp8_dtype,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
     scale_ue8m0: bool = False,
@@ -223,6 +224,7 @@ def per_token_group_quant_fp8(
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
+        dtype: The dype of output tensor.

     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -232,7 +234,21 @@ def per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    if _is_hip:
+        if dtype == torch.int8:
+            bit8_max = 127.0
+        else:
+            bit8_max = 224.0
+        bit8_min = -bit8_max  # TODO incorrect for int8
+    else:
+        if dtype == torch.int8:
+            info = torch.iinfo(dtype)
+        else:
+            info = torch.finfo(dtype)
+        bit8_max = info.max
+        bit8_min = info.min
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     x_s = create_per_token_group_quant_fp8_output_scale(
         x_shape=x.shape,
         device=x.device,
@@ -250,7 +266,7 @@ def per_token_group_quant_fp8(
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
     if column_major_scales:
-        _per_token_group_quant_fp8_colmajor[(M,)](
+        _per_token_group_quant_8bit_colmajor[(M,)](
             x,
             x_q,
             x_s,
@@ -258,8 +274,8 @@ def per_token_group_quant_fp8(
             x.shape[1],
             x_s.stride(1),
             eps,
-            fp8_min=fp8_min,
-            fp8_max=fp8_max,
+            bit8_min=bit8_min,
+            bit8_max=bit8_max,
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
@@ -267,15 +283,15 @@ def per_token_group_quant_fp8(
         )
     else:
         assert not scale_ue8m0
-        _per_token_group_quant_fp8[(M,)](
+        _per_token_group_quant_8bit[(M,)](
            x,
            x_q,
            x_s,
            group_size,
            N,
            eps,
-            fp8_min=fp8_min,
-            fp8_max=fp8_max,
+            bit8_min=bit8_min,
+            bit8_max=bit8_max,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
@@ -297,6 +313,117 @@ def per_token_group_quant_fp8(
     return x_q, x_s

+
+# backward compatibility
+per_token_group_quant_fp8 = _per_token_group_quant_8bit_raw
+
+
+def _per_token_group_quant_8bit_fuse_silu_and_mul(
+    x: torch.Tensor,
+    group_size: int,
+    dst_dtype: torch.dtype,
+    column_major_scales: bool,
+    scale_tma_aligned: bool,
+    scale_ue8m0: bool,
+    masked_m: Optional[torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Another way to implement (can be used in e.g. comparison tests)
+    # from sgl_kernel import silu_and_mul
+    # x_after_silu_and_mul = silu_and_mul(x)
+    # return per_token_group_quant_fp8(
+    #     x_after_silu_and_mul,
+    #     group_size=group_size,
+    #     eps=eps,
+    #     column_major_scales=column_major_scales,
+    #     scale_tma_aligned=scale_tma_aligned,
+    #     scale_ue8m0=scale_ue8m0,
+    # )
+
+    from deep_gemm.utils.layout import transform_sf_into_required_layout
+    from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
+
+    assert column_major_scales
+    assert scale_tma_aligned
+    assert scale_ue8m0
+
+    needs_unsqueeze = x.dim() == 2
+    if needs_unsqueeze:
+        num_tokens, _ = x.shape
+        x = x.unsqueeze(0)
+        assert masked_m is None
+        masked_m = torch.tensor([num_tokens], device=x.device, dtype=torch.int32)
+
+    # Use `zeros` for easier testing
+    output = torch.zeros(
+        (*x.shape[:-1], x.shape[-1] // 2),
+        device=x.device,
+        dtype=dst_dtype,
+    )
+    # Use `zeros` for easier testing
+    output_scale_for_kernel = torch.zeros(
+        (*x.shape[:-1], x.shape[-1] // 2 // group_size),
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    silu_and_mul_masked_post_quant_fwd(
+        input=x,
+        output=output,
+        output_scale=output_scale_for_kernel,
+        quant_group_size=group_size,
+        masked_m=masked_m,
+        scale_ue8m0=scale_ue8m0,
+    )
+
+    assert group_size == 128
+    output_scale = transform_sf_into_required_layout(
+        output_scale_for_kernel,
+        num_groups=output.shape[0],
+        mn=output.shape[-2],
+        k=output.shape[-1],
+        recipe=(1, group_size, group_size),
+        is_sfa=True,
+    )
+
+    if needs_unsqueeze:
+        output = output.squeeze(0)
+        output_scale = output_scale.squeeze(0)
+
+    return output, output_scale
+
+
+def per_token_group_quant_8bit(
+    x: torch.Tensor,
+    group_size: int,
+    dst_dtype: torch.dtype,
+    eps: float = 1e-10,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
+    scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if fuse_silu_and_mul:
+        return _per_token_group_quant_8bit_fuse_silu_and_mul(
+            x=x,
+            group_size=group_size,
+            dst_dtype=dst_dtype,
+            column_major_scales=column_major_scales,
+            scale_tma_aligned=scale_tma_aligned,
+            scale_ue8m0=scale_ue8m0,
+            masked_m=masked_m,
+        )
+    else:
+        return _per_token_group_quant_8bit_raw(
+            x=x,
+            group_size=group_size,
+            eps=eps,
+            column_major_scales=column_major_scales,
+            scale_tma_aligned=scale_tma_aligned,
+            scale_ue8m0=scale_ue8m0,
+            dtype=dst_dtype,
+        )


 def create_per_token_group_quant_fp8_output_scale(
     x_shape,
     device,
@@ -307,16 +434,16 @@ def create_per_token_group_quant_fp8_output_scale(
 ):
     if scale_ue8m0:
         assert column_major_scales and scale_tma_aligned
-        x_q_mn, x_q_k = x_shape
+        *x_batch, x_q_mn, x_q_k = x_shape
         x_s_mn, x_s_k = x_q_mn, x_q_k // 128
         aligned_mn = align(x_s_mn, 4)
         aligned_k = align(x_s_k, 4)
         # TODO(FIXME): Fix cuda kernel and recover here to empty.
-        return torch.zeros(
-            (aligned_k // 4, aligned_mn),
+        return torch.empty(
+            (*x_batch, aligned_k // 4, aligned_mn),
             device=device,
             dtype=torch.int,
-        ).transpose(0, 1)[:x_s_mn, :]
+        ).transpose(-1, -2)[..., :x_s_mn, :]
     elif column_major_scales:
         if scale_tma_aligned:
             # TODO extract "align" function
@@ -341,39 +468,6 @@ def create_per_token_group_quant_fp8_output_scale(
         )
-
-
-# TODO maybe unify int8 and fp8 code later
-def per_token_group_quant_8bit(
-    x: torch.Tensor,
-    group_size: int,
-    dst_dtype: torch.dtype,
-    eps: float = 1e-10,
-    column_major_scales: bool = False,
-    scale_tma_aligned: bool = False,
-    scale_ue8m0: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8
-
-    if dst_dtype == torch.int8:
-        assert not column_major_scales
-        assert not scale_tma_aligned
-        assert not scale_ue8m0
-        return per_token_group_quant_int8(
-            x=x,
-            group_size=group_size,
-            eps=eps,
-            dtype=dst_dtype,
-        )
-
-    return per_token_group_quant_fp8(
-        x=x,
-        group_size=group_size,
-        eps=eps,
-        column_major_scales=column_major_scales,
-        scale_tma_aligned=scale_tma_aligned,
-        scale_ue8m0=scale_ue8m0,
-    )


 def sglang_per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
@@ -381,15 +475,19 @@ def sglang_per_token_group_quant_fp8(
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
     scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
+    out_shape = (*x.shape[:-1], x.shape[-1] // (2 if fuse_silu_and_mul else 1))
+    x_q = torch.empty(out_shape, device=x.device, dtype=fp8_dtype)
     x_s = create_per_token_group_quant_fp8_output_scale(
-        x_shape=x.shape,
+        x_shape=out_shape,
         device=x.device,
         group_size=group_size,
         column_major_scales=column_major_scales,
@@ -414,6 +512,8 @@ def sglang_per_token_group_quant_8bit(
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
     scale_ue8m0: bool = False,
+    fuse_silu_and_mul: bool = False,
+    masked_m: Optional[torch.Tensor] = None,
 ):
     from sglang.srt.layers.quantization.int8_kernel import (
         sglang_per_token_group_quant_int8,
@@ -422,6 +522,8 @@ def sglang_per_token_group_quant_8bit(
     if dst_dtype == torch.int8:
         assert not column_major_scales
         assert not scale_tma_aligned
+        assert not fuse_silu_and_mul
+        assert masked_m is None
         return sglang_per_token_group_quant_int8(
             x=x,
             group_size=group_size,
@@ -436,6 +538,8 @@ def sglang_per_token_group_quant_8bit(
         column_major_scales=column_major_scales,
         scale_tma_aligned=scale_tma_aligned,
         scale_ue8m0=scale_ue8m0,
+        fuse_silu_and_mul=fuse_silu_and_mul,
+        masked_m=masked_m,
     )
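
The arithmetic the renamed kernels perform is unchanged by this commit: each group of group_size values gets one scale y_s = max(|y|, eps) / bit8_max, is divided by it, clamped to [bit8_min, bit8_max], and cast to the 8-bit target dtype. The snippet below is a minimal pure-PyTorch sketch of that per-token-group scheme, intended only as a mental model or comparison baseline; it is not the Triton code above, the helper name reference_per_token_group_quant_8bit is made up for illustration, and rounding details (especially for int8) may differ from the kernels.

import torch


def reference_per_token_group_quant_8bit(
    x: torch.Tensor,          # [..., N] with N divisible by group_size
    group_size: int,
    dst_dtype: torch.dtype,   # e.g. torch.float8_e4m3fn or torch.int8
    eps: float = 1e-10,
):
    # Mirror bit8_max from the kernels: the representable maximum of the target dtype.
    if dst_dtype == torch.int8:
        bit8_max = float(torch.iinfo(dst_dtype).max)   # 127
    else:
        bit8_max = float(torch.finfo(dst_dtype).max)   # 448 for float8_e4m3fn
    bit8_min = -bit8_max  # symmetric clamp; the non-HIP path in the diff uses info.min

    groups = x.float().reshape(*x.shape[:-1], -1, group_size)
    # One scale per (token, group): absmax / bit8_max, with eps guarding all-zero groups.
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps) / bit8_max
    q = (groups / scales).clamp(bit8_min, bit8_max).to(dst_dtype)
    return q.reshape(x.shape), scales.squeeze(-1)

On the fused path (fuse_silu_and_mul=True), quantization runs after silu_and_mul, so the quantized output has half the last dimension of the input, which is why the new code sizes output as x.shape[-1] // 2.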