Unverified Commit 78fe7753 authored by bnellnm's avatar bnellnm Committed by GitHub
Browse files

[Kernel] Enable fp8 support for pplx and BatchedTritonExperts. (#18864)


Signed-off-by: default avatarBill Nell <bnell@redhat.com>
parent 2f2fcb31
...@@ -24,6 +24,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -24,6 +24,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def topk_indices_dtype(self) -> Optional[torch.dtype]: def topk_indices_dtype(self) -> Optional[torch.dtype]:
return None return None
def num_dispatchers(self) -> int:
return 1
def prepare( def prepare(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
......
...@@ -99,9 +99,20 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: ...@@ -99,9 +99,20 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
return m[idx, ...] return m[idx, ...]
# TODO(bnell): better name def normalize_scales_shape(
def maybe_fix_scales(scales: Optional[torch.Tensor], scales: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
num_experts: int) -> Optional[torch.Tensor]: if scales is not None:
if scales.numel() == 1:
scales = scales.view(1, 1)
else:
scales = scales.view(-1, scales.size(-1))
return scales
def normalize_batched_scales_shape(
scales: Optional[torch.Tensor],
num_experts: int,
) -> Optional[torch.Tensor]:
if scales is not None and scales.ndim < 3: if scales is not None and scales.ndim < 3:
if scales.numel() == 1: if scales.numel() == 1:
scales = scales.view(1) scales = scales.view(1)
...@@ -111,3 +122,23 @@ def maybe_fix_scales(scales: Optional[torch.Tensor], ...@@ -111,3 +122,23 @@ def maybe_fix_scales(scales: Optional[torch.Tensor],
scales = scales.view(num_experts, -1, scales.size(-1)) scales = scales.view(num_experts, -1, scales.size(-1))
return scales return scales
def _validate_scale_shape(
a: torch.Tensor,
a_scale: Optional[torch.Tensor],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
) -> None:
if a_scale is None:
return
if not per_act_token_quant and block_shape is None:
assert a_scale.numel() == 1, f"{a_scale.shape}"
elif per_act_token_quant:
assert a_scale.shape[0] == a.shape[0] and a_scale.shape[1] == 1, (
f"{a_scale.shape[0]} == {a.shape[0]} and {a_scale.shape[1]} == 1")
else:
assert block_shape is not None
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
...@@ -573,6 +573,41 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -573,6 +573,41 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
self.fused_experts_func = fused_experts self.fused_experts_func = fused_experts
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
assert not self.rocm_aiter_moe_enabled and not self.use_marlin
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank(
)
assert max_num_tokens_per_rank is not None
return BatchedTritonExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
per_act_token_quant=(
self.input_quant.strategy == QuantizationStrategy.TOKEN),
)
else:
return TritonExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
per_act_token_quant=(
self.input_quant.strategy == QuantizationStrategy.TOKEN),
)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -610,7 +645,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -610,7 +645,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
return self.rocm_aiter_fused_experts_func( return self.rocm_aiter_fused_experts_func(
...@@ -832,18 +869,25 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ...@@ -832,18 +869,25 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
use_batched_format = (prepare_finalize.activation_format == use_batched_format = (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts) FusedMoEActivationFormat.BatchedExperts)
num_dispatchers = prepare_finalize.num_dispatchers()
num_experts = (moe.num_local_experts num_experts = (moe.num_local_experts
if use_batched_format else moe.num_experts) if use_batched_format else moe.num_experts)
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassExpertsFp8( experts = CutlassExpertsFp8(
num_experts, num_experts,
moe.in_dtype, moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN, self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL, self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
num_dispatchers=num_dispatchers,
use_batched_format=use_batched_format, use_batched_format=use_batched_format,
) )
self.disable_expert_map = not experts.supports_expert_map() self.disable_expert_map = (num_dispatchers > 1
or not experts.supports_expert_map())
return experts return experts
def apply( def apply(
......
...@@ -802,10 +802,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -802,10 +802,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.quant_config.weight_block_size, False) self.quant_config.weight_block_size, False)
return BatchedTritonOrDeepGemmExperts( return BatchedTritonOrDeepGemmExperts(
max_num_tokens=max_num_tokens_per_rank, max_num_tokens=max_num_tokens_per_rank,
world_size=prepare_finalize. num_dispatchers=prepare_finalize.num_dispatchers(),
world_size, # type: ignore [attr-defined]
dp_size=prepare_finalize.
dp_size, # type: ignore [attr-defined]
use_fp8_w8a8=True, use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size, block_shape=self.quant_config.weight_block_size,
per_act_token_quant=False, per_act_token_quant=False,
......
...@@ -135,7 +135,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): ...@@ -135,7 +135,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
final_hidden_states = final_hidden_states
if self.tp_size > 1: if self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states) final_hidden_states)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment