Unverified commit 0d4891cd, authored by Wentao Ye, committed by GitHub
Browse files

[Bug] Fix DeepGemm for EP low latency case (#20833)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent f56d2996
...@@ -11,7 +11,8 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( ...@@ -11,7 +11,8 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate) TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
is_blackwell_deep_gemm_used)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -50,6 +51,7 @@ def _silu_mul_fp8_quant_deep_gemm( ...@@ -50,6 +51,7 @@ def _silu_mul_fp8_quant_deep_gemm(
eps: tl.constexpr, eps: tl.constexpr,
fp8_min: tl.constexpr, fp8_min: tl.constexpr,
fp8_max: tl.constexpr, fp8_max: tl.constexpr,
use_ue8m0: tl.constexpr,
# Meta --------------------------------------------------------------- # Meta ---------------------------------------------------------------
BLOCK: tl.constexpr, BLOCK: tl.constexpr,
...@@ -92,7 +94,9 @@ def _silu_mul_fp8_quant_deep_gemm( ...@@ -92,7 +94,9 @@ def _silu_mul_fp8_quant_deep_gemm(
y = x * y2 y = x * y2
_absmax = tl.maximum(tl.max(tl.abs(y)), eps) _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / fp8_max scale_raw = _absmax / fp8_max
y_s = tl.math.exp2(tl.ceil(
tl.log2(scale_raw))) if use_ue8m0 else scale_raw
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask) tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
...@@ -174,6 +178,7 @@ def silu_mul_fp8_quant_deep_gemm( ...@@ -174,6 +178,7 @@ def silu_mul_fp8_quant_deep_gemm(
eps, eps,
fp8_min, fp8_min,
fp8_max, fp8_max,
is_blackwell_deep_gemm_used(),
BLOCK=group_size, BLOCK=group_size,
num_warps=4, num_warps=4,
) )
...@@ -290,14 +295,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -290,14 +295,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# may lead to better performance. # may lead to better performance.
expected_m = max_num_tokens expected_m = max_num_tokens
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale), fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
out=workspace1, workspace1, expert_num_tokens, expected_m)
masked_m=expert_num_tokens,
expected_m=expected_m)
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
expert_num_tokens) expert_num_tokens)
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output,
out=output, expert_num_tokens, expected_m)
masked_m=expert_num_tokens,
expected_m=expected_m)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment