"docs/vscode:/vscode.git/clone" did not exist on "be8168ff889aa8981d4e8a158fc1b4d0a4deb18b"
Unverified commit ac45c44d, authored by Varun Sundar Rabindranath, committed by GitHub
Browse files

[Bugfix] [Performance] DeepEPHighThroughput + DeepSeek : Quant before Dispatch (#21837)


Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
parent d6664664
...@@ -144,12 +144,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -144,12 +144,13 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"apply_router_weight_on_input is only implemented for topk=1") "apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * topk_weights.to(a1.dtype) a1 = a1 * topk_weights.to(a1.dtype)
if quant_config.per_act_token_quant: if quant_config.is_block_quantized:
# Quant and Dispatch
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1,
a1_scale, a1_scale,
quant_dtype=quant_config.quant_dtype, quant_dtype=quant_config.quant_dtype,
per_act_token_quant=True, per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape, block_shape=quant_config.block_shape,
) )
if a1q_scale is not None and a1q_scale.numel() == 1: if a1q_scale is not None and a1q_scale.numel() == 1:
...@@ -162,8 +163,10 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -162,8 +163,10 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
rank_topk_weights=topk_weights, rank_topk_weights=topk_weights,
num_experts=num_experts) num_experts=num_experts)
else: else:
# DeepEP kernels only support dispatching per-token-quant # Dispatch and Quant
# quantization. dispatch in bfloat16. # DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16
(expert_x, _, expert_tokens_meta, expert_topk_ids, (expert_x, _, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) = self._do_dispatch( expert_topk_weights) = self._do_dispatch(
tokens=a1, tokens=a1,
...@@ -171,7 +174,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -171,7 +174,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
rank_topk_ids=topk_ids, rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights, rank_topk_weights=topk_weights,
num_experts=num_experts) num_experts=num_experts)
# quantize now # Quantize after dispatch.
expert_x_scale = None expert_x_scale = None
if expert_x.numel() != 0: if expert_x.numel() != 0:
expert_x, expert_x_scale = moe_kernel_quantize_input( expert_x, expert_x_scale = moe_kernel_quantize_input(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment