Unverified commit 8b8f2e74, authored by fzyzcjy and committed by GitHub
Browse files

Support new DeepGEMM input format in silu_and_mul_masked_post_quant_fwd (#7153)

parent 0fc3d992
...@@ -278,6 +278,7 @@ def _silu_and_mul_post_quant_kernel( ...@@ -278,6 +278,7 @@ def _silu_and_mul_post_quant_kernel(
fp8_min, fp8_min,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
NUM_STAGE: tl.constexpr, NUM_STAGE: tl.constexpr,
SCALE_UE8M0: tl.constexpr,
): ):
expert_id = tl.program_id(2) expert_id = tl.program_id(2)
token_id = tl.program_id(1) token_id = tl.program_id(1)
...@@ -319,6 +320,8 @@ def _silu_and_mul_post_quant_kernel( ...@@ -319,6 +320,8 @@ def _silu_and_mul_post_quant_kernel(
gate_up = up * gate gate_up = up * gate
_absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10) _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
output_s = _absmax / fp8_max output_s = _absmax / fp8_max
if SCALE_UE8M0:
output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to( output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
output_ptr.dtype.element_ty output_ptr.dtype.element_ty
) )
...@@ -339,6 +342,7 @@ def silu_and_mul_masked_post_quant_fwd( ...@@ -339,6 +342,7 @@ def silu_and_mul_masked_post_quant_fwd(
output_scale: torch.Tensor, output_scale: torch.Tensor,
quant_group_size: int, quant_group_size: int,
masked_m: torch.Tensor, masked_m: torch.Tensor,
scale_ue8m0: bool = False,
): ):
""" """
input shape [expert_num, token_num_padded, hidden_dim] input shape [expert_num, token_num_padded, hidden_dim]
...@@ -395,6 +399,7 @@ def silu_and_mul_masked_post_quant_fwd( ...@@ -395,6 +399,7 @@ def silu_and_mul_masked_post_quant_fwd(
BLOCK_N=BLOCK_N, BLOCK_N=BLOCK_N,
NUM_STAGE=NUM_STAGES, NUM_STAGE=NUM_STAGES,
num_warps=num_warps, num_warps=num_warps,
SCALE_UE8M0=scale_ue8m0,
) )
return return
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment