Unverified Commit 32fa1e9c authored by Cheng Wan, committed by GitHub

[4/N] MoE Refactor: Unified Triton Kernel for FusedMoE and EPMoE (#8515)

parent e7dc163f
@@ -413,18 +413,37 @@ def fused_moe_kernel(
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
-   offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+   offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    offs_token = offs_token.to(tl.int64)
    token_mask = offs_token < num_valid_tokens
-   offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+   off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+
+   if off_experts == -1:
+       # -----------------------------------------------------------
+       # Write back zeros to the output when the expert is not
+       # in the current expert parallel rank.
+       write_zeros_to_output(
+           c_ptr,
+           stride_cm,
+           stride_cn,
+           pid_n,
+           N,
+           offs_token,
+           token_mask,
+           BLOCK_SIZE_M,
+           BLOCK_SIZE_N,
+           compute_type,
+       )
+       return
+
+   offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )
-   off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = (
        b_ptr
        + off_experts * stride_be
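With expert parallelism, `expert_ids_ptr` can now contain `-1` for blocks whose expert is not hosted on the current EP rank, so the kernel zero-fills that output tile and returns instead of running the GEMM. The helper `write_zeros_to_output` is not shown in this hunk; below is a minimal sketch consistent with its call site, assuming it only stores a zero tile for the valid (non-padded) tokens.

import triton
import triton.language as tl

@triton.jit
def write_zeros_to_output(
    c_ptr, stride_cm, stride_cn, pid_n, N,
    offs_token, token_mask,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
    compute_type: tl.constexpr,
):
    # Zero tile in fp32, matching the accumulator dtype of the main path.
    zeros = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    # Mask out padded tokens and out-of-range output columns before storing.
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, zeros.to(compute_type), mask=c_mask)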
@@ -497,7 +516,6 @@ def fused_moe_kernel(
                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
            else:
-               # fix out of shared memory issue
                if use_fp8_w8a8:
                    accumulator = tl.dot(a, b, acc=accumulator)
                else:
...
@@ -12,7 +12,7 @@ from sglang.srt.distributed import (
    tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
-from sglang.srt.layers.moe.topk import TopKOutput
+from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
@@ -79,7 +79,6 @@ class FusedMoE(torch.nn.Module):
        routed_scaling_factor: Optional[float] = None,
        enable_flashinfer_cutlass_moe: Optional[bool] = False,
        enable_ep_moe: Optional[bool] = False,
-       skip_quant: Optional[bool] = False,
    ):
        super().__init__()
@@ -95,7 +94,8 @@ class FusedMoE(torch.nn.Module):
        self.tp_rank = get_tensor_model_parallel_rank()
        self.num_experts = num_experts
        self.num_fused_shared_experts = num_fused_shared_experts
-       self.expert_map = None
+       self.expert_map_cpu = None
+       self.expert_map_gpu = None

        if enable_flashinfer_cutlass_moe and quant_config is None:
            logger.warning("Disable flashinfer MoE when quantization config is None.")
@@ -104,20 +104,22 @@ class FusedMoE(torch.nn.Module):
        self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe

        if enable_ep_moe:
+           # TODO(ch-wan): support shared experts fusion
            self.ep_size = self.tp_size
            self.ep_rank = self.tp_rank
            self.tp_size = 1
            self.tp_rank = 0
            # Create a tensor of size num_experts filled with -1
-           self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
+           self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
            # Create a expert map for the local experts
            assert num_experts % self.ep_size == 0
            self.num_local_experts = num_experts // self.ep_size
-           self.expert_map[
+           self.expert_map_cpu[
                self.ep_rank
                * self.num_local_experts : (self.ep_rank + 1)
                * self.num_local_experts
            ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
+           self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
        else:
            self.ep_size = 1
            self.ep_rank = 0
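`expert_map_cpu` stays on the host for weight loading, while `expert_map_gpu` is the runtime copy used to translate routed expert ids. A small worked example with hypothetical sizes (8 experts split across 4 EP ranks, viewed from rank 1):

import torch

# Hypothetical sizes for illustration only.
num_experts, ep_size, ep_rank = 8, 4, 1
num_local_experts = num_experts // ep_size  # 2

expert_map_cpu = torch.full((num_experts,), -1, dtype=torch.int32)
expert_map_cpu[
    ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts
] = torch.arange(0, num_local_experts, dtype=torch.int32)

print(expert_map_cpu.tolist())  # [-1, -1, 0, 1, -1, -1, -1, -1]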
@@ -136,9 +138,6 @@ class FusedMoE(torch.nn.Module):
            not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
        )

-       if skip_quant:
-           return
-
        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
                self.use_triton_kernels
@@ -367,9 +366,9 @@ class FusedMoE(torch.nn.Module):
        expert_data.copy_(loaded_weight)

    def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
-       if self.expert_map is None:
+       if self.expert_map_cpu is None:
            return expert_id
-       return self.expert_map[expert_id].item()
+       return self.expert_map_cpu[expert_id].item()

    def weight_loader(
        self,
@@ -421,7 +420,6 @@ class FusedMoE(torch.nn.Module):
        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
        if expert_id == -1:
            return
-
        self._weight_loader_impl(
            param=param,
            loaded_weight=loaded_weight,
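On the weight-loading path the map is consulted on CPU: a `-1` entry means the expert belongs to another EP rank, and `weight_loader` returns before copying anything. A self-contained sketch of that decision, reusing the hypothetical rank-1 map from the earlier example:

import torch

# Hypothetical map for EP rank 1 of 4 with 8 experts (see the sketch above).
expert_map_cpu = torch.tensor([-1, -1, 0, 1, -1, -1, -1, -1], dtype=torch.int32)

def local_expert_id_or_skip(global_expert_id: int) -> int:
    # Mirrors _map_global_expert_id_to_local_expert_id: -1 means "not owned here",
    # and weight_loader returns early without copying the tensor.
    return expert_map_cpu[global_expert_id].item()

print(local_expert_id_or_skip(3))  # 1  -> loaded into local slot 1
print(local_expert_id_or_skip(5))  # -1 -> this rank skips the weight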
@@ -614,9 +612,14 @@ class FusedMoE(torch.nn.Module):
        )
        return

-   def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
+   def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
        assert self.quant_method is not None

+       if self.expert_map_gpu is not None:
+           topk_output = topk_output._replace(
+               topk_ids=self.expert_map_gpu[topk_output.topk_ids]
+           )
+
        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
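Routing produces global expert ids, so `forward` now gathers them through `expert_map_gpu` before handing them to the quant method; experts owned by other ranks become `-1`, which the unified kernel handles via the zero-fill path shown earlier. A minimal sketch of that gather with hypothetical top-2 routing output (CPU tensors for illustration; the real map lives on CUDA):

import torch

# Hypothetical expert map for EP rank 1 of 4 (8 experts) and top-2 routing ids.
expert_map_gpu = torch.tensor([-1, -1, 0, 1, -1, -1, -1, -1], dtype=torch.int32)
topk_ids = torch.tensor([[2, 7],
                         [3, 0]])          # global expert ids per token
local_topk_ids = expert_map_gpu[topk_ids]  # same gather as in forward()

print(local_topk_ids.tolist())  # [[0, -1], [1, -1]]: -1 entries hit the zero-fill path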
@@ -670,3 +673,20 @@ class FusedMoE(torch.nn.Module):
                ("w3", ckpt_up_proj_name),
            ]
        ]
+
+   @classmethod
+   def make_expert_input_scale_params_mapping(
+       cls,
+       num_experts: int,
+   ) -> List[Tuple[str, str, int, str]]:
+       # (param_name, weight_name, expert_id, shard_id)
+       return [
+           (
+               "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
+               f"experts.{expert_id}.{shard_id}.",
+               expert_id,
+               shard_id,
+           )
+           for expert_id in range(num_experts)
+           for shard_id in ["w1", "w2", "w3"]
+       ]
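The new `make_expert_input_scale_params_mapping` helper enumerates, per expert and shard, which checkpoint tensor name feeds which fused parameter prefix. A hypothetical call with two experts shows the shape of the result:

from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

# Hypothetical invocation with two experts, just to show the mapping layout.
for entry in FusedMoE.make_expert_input_scale_params_mapping(num_experts=2):
    print(entry)
# ("experts.w13_", "experts.0.w1.", 0, "w1")
# ("experts.w2_",  "experts.0.w2.", 0, "w2")
# ("experts.w13_", "experts.0.w3.", 0, "w3")
# ("experts.w13_", "experts.1.w1.", 1, "w1")
# ... and so on for expert 1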
@@ -172,7 +172,6 @@ class Fp8Config(QuantizationConfig):
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        from sglang.srt.layers.linear import LinearBase
-       from sglang.srt.layers.moe.ep_moe.layer import EPMoE
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
@@ -181,8 +180,6 @@ class Fp8Config(QuantizationConfig):
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return Fp8MoEMethod(self)
-       elif isinstance(layer, EPMoE):
-           return Fp8EPMoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
@@ -984,23 +981,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
-       from sglang.srt.layers.moe.ep_moe.layer import EPMoE
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

-       if isinstance(layer, EPMoE):
-           layer.w13_weight_scale = (
-               layer.w13_weight_scale_inv
-               if self.block_quant
-               else layer.w13_weight_scale
-           )
-           layer.w2_weight_scale = (
-               layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
-           )
-           return layer.run_moe(
-               hidden_states=x,
-               topk_output=topk_output,
-           )
-
        if use_intel_amx_backend(layer):
            from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
...
@@ -204,14 +204,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
-       from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-
-       if isinstance(layer, EPMoE):
-           return layer.run_moe(
-               hidden_states=x,
-               topk_output=topk_output,
-           )
-
        return self.forward(
            x=x,
            layer=layer,
...
@@ -276,6 +276,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
        layer: EPMoE,
        hidden_states: torch.Tensor,
        topk_output: TopKOutput,
+       **kwargs,
    ) -> torch.Tensor:
        # TODO(ch-wan): move it out of this class
...