Unverified commit eec9e471 authored by jiahanc, committed by GitHub

[NVIDIA] Update to leverage flashinfer trtllm FP4 MOE throughput kernel (#11563)


Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
parent 6d535b71
@@ -49,7 +49,6 @@ from sglang.srt.utils import (
     is_cpu,
     is_flashinfer_available,
     is_hip,
-    next_power_of_2,
     round_up,
 )
@@ -72,16 +71,6 @@ if should_use_flashinfer_trtllm_moe():
 logger = logging.getLogger(__name__)
 
-
-def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
 def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher:
     a2a_backend = get_moe_a2a_backend()
     if a2a_backend.is_none():
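Note: for context on the removal above, the deleted `_get_tile_tokens_dim` helper estimated tokens per expert, padded that estimate to a power of two, and clamped it to the range the kernel supports. A self-contained sketch of that heuristic, with `next_power_of_2` inlined as a stand-in for the sglang utility whose import this change also drops:

    def next_power_of_2(n: int) -> int:
        # Stand-in for sglang.srt.utils.next_power_of_2: smallest power of two >= n.
        return 1 if n <= 1 else 1 << (n - 1).bit_length()

    def get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
        # Guess tokens per expert assuming perfectly even routing first.
        num_tokens_per_expert = (num_tokens * top_k) // num_experts
        # Pad to the next power of 2, then clamp to the 8-64 CTA tile
        # range the kernel supports.
        return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)

    # Example: 256 tokens routed top-8 over 32 experts -> 64 tokens/expert -> tile 64.
    assert get_tile_tokens_dim(256, 8, 32) == 64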
@@ -1080,9 +1069,7 @@ class FlashInferFP4MoE(FusedMoE):
             local_expert_offset=self.moe_ep_rank * self.num_local_experts,
             local_num_experts=self.num_local_experts,
             routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
-            tile_tokens_dim=_get_tile_tokens_dim(
-                hidden_states.shape[0], topk_config.top_k, self.num_local_experts
-            ),
+            tile_tokens_dim=None,
             routing_method_type=RoutingMethodType.DeepSeekV3,
             do_finalize=True,
         )[0]
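With the flashinfer trtllm FP4 MoE throughput kernel, callers no longer guess a tile size: passing tile_tokens_dim=None defers the choice to the kernel. A minimal sketch of that contract (the entry point below is a hypothetical stand-in, not the real flashinfer signature):

    def trtllm_fp4_moe_stub(*, tile_tokens_dim=None, **kwargs):
        # Hypothetical stand-in: when tile_tokens_dim is None, the real kernel
        # selects the CTA tile size internally instead of trusting the caller.
        if tile_tokens_dim is None:
            tile_tokens_dim = 32  # placeholder for the kernel's internal heuristic
        return tile_tokens_dim

    # Old callers computed _get_tile_tokens_dim(...); new callers just pass None.
    assert trtllm_fp4_moe_stub(tile_tokens_dim=None) == 32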
...
@@ -41,7 +41,6 @@ from sglang.srt.utils import (
     is_triton_kernels_available,
     log_info_on_rank0,
     mxfp_supported,
-    next_power_of_2,
     round_up,
     set_weight_attrs,
 )
@@ -597,30 +596,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
         torch.cuda.empty_cache()
 
-    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
-        # Number of tokens in the input tensor.
-        num_tokens = x.shape[0]
-        # Factor to account for the imbalance of the experts:
-        # the factor equals
-        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert.
-        # - 1.0 means a perfect expert distribution.
-        # - > 1.0 means some experts have more
-        #   tokens than the perfect distribution.
-        # - < 1.0 does not make sense.
-        imbalance_factor = 1.3
-        # Calculate the number of tokens per expert,
-        # assuming a perfect distribution.
-        num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
-        # Apply the imbalance factor.
-        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
-        # And pad the number to the next power of 2.
-        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-        # Cap to 8-64 tokens per CTA tile,
-        # as it's the range supported by the kernel.
-        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-        return tile_tokens_dim
-
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
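To make the removed arithmetic concrete, the same computation worked through with assumed example shapes (128 tokens, top_k=4, 32 experts; none of these values appear in the diff):

    num_tokens, top_k, num_experts = 128, 4, 32    # assumed example values
    perfect = (num_tokens * top_k) // num_experts  # 16 tokens/expert if perfectly balanced
    skewed = int(perfect * 1.3)                    # 20 after the 1.3 imbalance factor
    tile = 1 << (skewed - 1).bit_length()          # 32, the next power of two
    tile = min(max(tile, 8), 64)                   # unchanged by the 8-64 clamp
    assert tile == 32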
@@ -696,7 +671,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
             layer.num_local_experts,  # local num experts
             None,
-            self._get_tile_tokens_dim(x, top_k),
+            None,  # tile_tokens_dim
             1,  # routing_method_type, renormalize
             True,  # do finalize
         )[0]
...
@@ -45,8 +45,8 @@ else
     # Install the main package without deps
     $PIP_CMD install -e "python[dev]" --no-deps $PIP_INSTALL_SUFFIX --force-reinstall
-    # Install flashinfer-python 0.4.0 dependency that requires prerelease (This should be removed when flashinfer fixes this issue)
-    $PIP_CMD install flashinfer-python==0.4.0 --prerelease=allow $PIP_INSTALL_SUFFIX
+    # Install flashinfer-python 0.4.1 dependency that requires prerelease (This should be removed when flashinfer fixes this issue)
+    $PIP_CMD install flashinfer-python==0.4.1 --prerelease=allow $PIP_INSTALL_SUFFIX
     # Install the main package
     $PIP_CMD install -e "python[dev]" --extra-index-url https://download.pytorch.org/whl/${CU_VERSION} $PIP_INSTALL_SUFFIX --upgrade
...
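Not part of the diff: a quick runtime check that the new pin took effect, assuming the installed distribution is named flashinfer-python as in the script above:

    from importlib.metadata import PackageNotFoundError, version

    try:
        # Distribution name assumed to match the pip package pinned above.
        installed = version("flashinfer-python")
    except PackageNotFoundError:
        raise SystemExit("flashinfer-python is not installed")
    assert installed == "0.4.1", f"expected 0.4.1, found {installed}"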