Unverified Commit d389bedf authored by jianan-gu's avatar jianan-gu Committed by GitHub
Browse files

[CPU][Qwen3 MoE] Enable fused_topk CPU fusion and enhance FP8 TP padding (#7838)

parent ac80f4da
...@@ -83,13 +83,18 @@ def fused_topk_cpu( ...@@ -83,13 +83,18 @@ def fused_topk_cpu(
gating_output: torch.Tensor, gating_output: torch.Tensor,
topk: int, topk: int,
renormalize: bool, renormalize: bool,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
): ):
return torch.ops.sgl_kernel.topk_softmax_cpu( topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
hidden_states=hidden_states, hidden_states=hidden_states,
gating_output=gating_output, gating_output=gating_output,
topk=topk, topk=topk,
renormalize=renormalize, renormalize=renormalize,
) )
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
_mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
return topk_weights, topk_ids
def fused_topk( def fused_topk(
...@@ -411,6 +416,7 @@ if _is_cpu and _is_cpu_amx_available: ...@@ -411,6 +416,7 @@ if _is_cpu and _is_cpu_amx_available:
biased_grouped_topk = biased_grouped_topk_cpu biased_grouped_topk = biased_grouped_topk_cpu
grouped_topk = grouped_topk_cpu grouped_topk = grouped_topk_cpu
fused_topk_native = fused_topk_cpu fused_topk_native = fused_topk_cpu
fused_topk = fused_topk_cpu
else: else:
biased_grouped_topk = biased_grouped_topk_gpu biased_grouped_topk = biased_grouped_topk_gpu
grouped_topk = grouped_topk_gpu grouped_topk = grouped_topk_gpu
......
...@@ -187,10 +187,26 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -187,10 +187,26 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data = self.data param_data = self.data
shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
if not use_presharded_weights:
loaded_weight = loaded_weight.narrow( if _is_cpu:
self.output_dim, shard_id * shard_size, shard_size from sglang.srt.model_loader.weight_utils import (
narrow_padded_param_and_loaded_weight,
)
param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
param_data,
loaded_weight,
0, # param_data_start
shard_id * shard_size,
self.output_dim,
shard_size,
not use_presharded_weights,
) )
else:
if not use_presharded_weights:
loaded_weight = loaded_weight.narrow(
self.output_dim, shard_id * shard_size, shard_size
)
assert ( assert (
param_data.shape == loaded_weight.shape param_data.shape == loaded_weight.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment