Unverified commit 32fa1e9c authored by Cheng Wan, committed by GitHub

[4/N] MoE Refactor: Unified Triton Kernel for FusedMoE and EPMoE (#8515)
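The core of the unification, as this diff reads, is that expert parallelism is expressed through an expert-id map that FusedMoE builds once and the shared Triton kernel consumes: experts owned by other ranks map to -1, and the kernel writes zeros for those slots. A small, self-contained sketch of that mechanism (illustrative values only, not sglang code):

import torch

# Hypothetical sizes for illustration only.
num_experts, ep_size, ep_rank = 8, 2, 1
num_local = num_experts // ep_size

# Same construction as FusedMoE.__init__ below: -1 marks experts owned by other ranks.
expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
expert_map[ep_rank * num_local : (ep_rank + 1) * num_local] = torch.arange(
    num_local, dtype=torch.int32
)

topk_ids = torch.tensor([[0, 5], [6, 3]])  # global expert ids chosen by the router
local_ids = expert_map[topk_ids]  # -> [[-1, 1], [2, -1]] (int32)
# The unified fused_moe_kernel writes zeros for token slots whose expert id is -1,
# so experts that live on other EP ranks contribute nothing on this rank.

Most of the remaining changes delete the EP-specific GroupedGemmRunner path from EPMoE so that its forward call falls through to FusedMoE.forward.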

parent e7dc163f
......@@ -86,79 +86,6 @@ if use_flashinfer_trtllm_moe:
logger = logging.getLogger(__name__)
class GroupedGemmRunner(torch.nn.Module):
flashinfer_gemm_warpper = None
def __init__(
self,
device,
use_flashinfer: bool = False,
use_per_token_if_dynamic: bool = True,
):
super().__init__()
self.device = device
self.use_flashinfer = use_flashinfer
self.use_per_token_if_dynamic = use_per_token_if_dynamic
if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:
GroupedGemmRunner._init_flashinfer_wrapper(device)
@classmethod
def _init_flashinfer_wrapper(cls, device):
from flashinfer import SegmentGEMMWrapper
workspace_buffer = torch.empty(
128 * 1024 * 1024, dtype=torch.int8, device=device
)
cls.flashinfer_gemm_warpper = SegmentGEMMWrapper(workspace_buffer)
# c = a * b
def forward(
self,
a: torch.Tensor,
b: torch.Tensor,
c: torch.Tensor,
batch_size: int,
weight_column_major: bool,
seg_indptr: Optional[torch.Tensor] = None,
weight_indices: Optional[torch.Tensor] = None,
use_fp8_w8a8: bool = False,
scale_a: torch.Tensor = None,
scale_b: torch.Tensor = None,
block_shape: Optional[List[int]] = None,
c_dtype=None,
):
if self.use_flashinfer:
# TODO: flashinfer
assert False
assert GroupedGemmRunner.flashinfer_gemm_warpper is not None
c = GroupedGemmRunner.flashinfer_gemm_warpper.run(
x=a,
weights=b,
batch_size=batch_size,
weight_column_major=weight_column_major,
seg_indptr=seg_indptr,
weight_indices=weight_indices,
)
else:
assert weight_column_major == True
c = grouped_gemm_triton(
a,
b,
c,
batch_size,
weight_column_major,
seg_indptr,
weight_indices,
use_fp8_w8a8,
scale_a,
scale_b,
block_shape=block_shape,
c_dtype=c_dtype,
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
)
return c
def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
# Guess tokens per expert assuming perfect expert distribution first.
num_tokens_per_expert = (num_tokens * top_k) // num_experts
......@@ -190,135 +117,50 @@ class EPMoE(FusedMoE):
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
use_per_token_if_dynamic: bool = True,
):
super().__init__(
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_fused_shared_experts=num_fused_shared_experts,
layer_id=layer_id,
top_k=top_k,
params_dtype=params_dtype,
quant_config=quant_config,
tp_size=tp_size,
prefix=prefix,
activation=activation,
# apply_router_weight_on_input=apply_router_weight_on_input,
routed_scaling_factor=routed_scaling_factor,
enable_ep_moe=True,
skip_quant=True,
)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.layer_id = layer_id
self.num_local_experts, self.expert_map = self.determine_expert_map()
self.start_expert_id = self.ep_rank * self.num_local_experts
self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
self.intermediate_size = intermediate_size
self.use_per_token_if_dynamic = use_per_token_if_dynamic
# TODO(ch-wan): move quant preparation to FusedMoE
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
UnquantizedFusedMoEMethod()
)
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.block_shape = None
self.activation_scheme = None
self.w13_input_scale = None
self.w2_input_scale = None
self.w13_weight_scale = None
self.w2_weight_scale = None
elif isinstance(quant_config, W4AFp8Config):
self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
quant_config
)
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.fp8_dtype = torch.float8_e4m3fn
self.w13_input_scale = None
self.w2_input_scale = None
self.w13_weight_scale = None
self.w2_weight_scale = None
self.activation_scheme = quant_config.moe_activation_scheme
elif isinstance(quant_config, Fp8Config):
self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(quant_config)
self.use_fp8_w8a8 = True
if isinstance(quant_config, Fp8Config):
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
self.quant_method.quant_config.weight_block_size
if self.use_block_quant
else None
)
self.use_fp8_w8a8 = True
self.fp8_dtype = torch.float8_e4m3fn
self.activation_scheme = quant_config.activation_scheme
else:
raise ValueError(f"Unsupported quant_config: {quant_config}")
self.quant_config = quant_config
self.quant_method.create_weights(
layer=self,
num_experts=self.num_local_experts,
hidden_size=hidden_size,
intermediate_size=self.intermediate_size,
params_dtype=params_dtype,
weight_loader=self.weight_loader,
)
self.grouped_gemm_runner = None
# Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
# Modifications: use determine_expert_map as a class-internal method, and set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
"""
Calculates how many experts should be assigned to each rank for EP and
creates a mapping from global to local expert index. Experts are
distributed evenly across ranks. Any remaining experts are assigned to the
last rank.
Returns:
Tuple[int, Optional[torch.Tensor]]: A tuple containing:
- local_num_experts (int): The number of experts assigned
to the current rank.
- expert_map (Optional[torch.Tensor]): A tensor of shape
(global_num_experts,) mapping from global to local index.
Contains global_num_experts for experts not assigned to the current rank.
Returns None if ep_size is 1.
"""
ep_size = self.ep_size
ep_rank = self.ep_rank
global_num_experts = self.num_experts
assert ep_size > 0
if ep_size == 1:
return (global_num_experts, None)
local_num_experts = global_num_experts // ep_size
expert_map = torch.full(
(global_num_experts,), global_num_experts, dtype=torch.int32
)
if ep_rank < (ep_size - 1):
expert_map[
ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
] = torch.arange(0, local_num_experts, dtype=torch.int32)
else:
local_num_experts = global_num_experts - ep_rank * local_num_experts
expert_map[-local_num_experts:] = torch.arange(
0, local_num_experts, dtype=torch.int32
)
return (local_num_experts, expert_map)
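# Example (illustrative): ep_size=2, global_num_experts=5 -> local_num_experts=2, so
#   rank 0: expert_map = [0, 1, 5, 5, 5]
#   rank 1 (last rank takes the remainder, 3 experts): expert_map = [5, 5, 0, 1, 2]
# Entries equal to global_num_experts (here 5) mark experts owned by another rank.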
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.block_shape = None
self.activation_scheme = None
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
return self.forward_deepgemm(hidden_states, topk_output)
else:
return self.forward_normal(hidden_states, topk_output)
return super().forward(hidden_states, topk_output)
def forward_deepgemm(
self,
......@@ -477,303 +319,6 @@ class EPMoE(FusedMoE):
)
return output
def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
return self.quant_method.apply(self, hidden_states, topk_output)
def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
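# Overview of the steps below: preprocess topk ids into per-expert segments,
# pre-reorder tokens into expert-contiguous order (optionally quantizing the input
# to fp8), run GroupGemm-0 (w13), apply the silu/gelu-and-mul activation,
# run GroupGemm-1 (w2), then post-reorder back to the original token order with
# the routing weights applied.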
topk_weights, topk_ids, _ = topk_output
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
if self.grouped_gemm_runner is None:
self.grouped_gemm_runner = GroupedGemmRunner(
hidden_states.device,
use_flashinfer=False, # TODO: use flashinfer
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
)
num_experts = self.num_experts
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
topk_ids,
num_experts,
)
gateup_input = torch.empty(
(int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
device=hidden_states.device,
dtype=(
self.fp8_dtype
if self.use_fp8_w8a8 and not self.use_block_quant
else hidden_states.dtype
),
)
if self.activation_scheme == "dynamic" and not self.use_block_quant:
if self.use_per_token_if_dynamic:
max_value = torch.max(hidden_states, dim=1).values.to(torch.float32)
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
else:
max_value = (
torch.max(hidden_states)
.repeat(self.num_local_experts)
.to(torch.float32)
)
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
# PreReorder
pre_reorder_triton_kernel[(hidden_states.shape[0],)](
hidden_states,
gateup_input,
src2dst,
topk_ids,
self.w13_input_scale,
self.start_expert_id,
self.end_expert_id,
self.top_k,
hidden_states.shape[1],
BLOCK_SIZE=512,
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
)
dispose_tensor(hidden_states)
if (
self.activation_scheme == "dynamic"
and not self.use_block_quant
and self.use_per_token_if_dynamic
):
scale = torch.empty(
hidden_states_shape[0] * self.top_k,
device=hidden_states_device,
dtype=torch.float32,
)
scale[src2dst] = (
self.w13_input_scale.unsqueeze(1)
.expand(hidden_states_shape[0], self.top_k)
.reshape(-1)
)
self.w13_input_scale = scale
seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
weight_indices_cur_rank = torch.arange(
0,
self.num_local_experts,
device=hidden_states_device,
dtype=torch.int64,
)
# GroupGemm-0
gateup_output = self.grouped_gemm_runner(
a=gateup_input,
b=self.w13_weight,
c=None,
c_dtype=hidden_states_dtype,
batch_size=self.num_local_experts,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w13_input_scale,
scale_b=self.w13_weight_scale,
block_shape=self.block_shape,
)
del gateup_input
# Act
if self.activation_scheme == "dynamic" and not self.use_block_quant:
self.w2_input_scale = None
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=hidden_states_dtype,
)
else:
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=(
self.fp8_dtype
if (self.use_fp8_w8a8 and not self.use_block_quant)
else hidden_states_dtype
),
)
if self.activation == "silu":
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
gateup_output.shape[1],
reorder_topk_ids,
self.w2_input_scale,
self.start_expert_id,
self.end_expert_id,
BLOCK_SIZE=512,
)
elif self.activation == "gelu":
gelu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
gateup_output.shape[1],
reorder_topk_ids,
self.w2_input_scale,
self.start_expert_id,
self.end_expert_id,
BLOCK_SIZE=512,
)
else:
raise ValueError(f"Unsupported activation: {self.activation=}")
del gateup_output
if self.activation_scheme == "dynamic" and not self.use_block_quant:
if self.use_per_token_if_dynamic:
down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
else:
self.w2_input_scale = torch.ones(
self.num_local_experts,
dtype=torch.float32,
device=hidden_states_device,
)
# GroupGemm-1
down_output = torch.empty(
down_input.shape[0],
self.w2_weight.shape[1],
device=hidden_states_device,
dtype=hidden_states_dtype,
)
down_output = self.grouped_gemm_runner(
a=down_input,
b=self.w2_weight,
c=down_output,
batch_size=self.num_local_experts,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w2_input_scale,
scale_b=self.w2_weight_scale,
block_shape=self.block_shape,
)
del down_input
# PostReorder
output = torch.empty(
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
)
post_reorder_triton_kernel[(hidden_states_shape[0],)](
down_output,
output,
src2dst,
topk_ids,
topk_weights,
self.start_expert_id,
self.end_expert_id,
self.top_k,
hidden_states_shape[1],
0,
BLOCK_SIZE=512,
)
return output
@classmethod
def make_expert_params_mapping(
cls,
ckpt_gate_proj_name: str,
ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int,
) -> List[Tuple[str, str, int, str]]:
return [
# (param_name, weight_name, expert_id, shard_id)
(
(
"experts.w13_"
if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
else "experts.w2_"
),
f"experts.{expert_id}.{weight_name}.",
expert_id,
shard_id,
)
for expert_id in range(num_experts)
for shard_id, weight_name in [
("w1", ckpt_gate_proj_name),
("w2", ckpt_down_proj_name),
("w3", ckpt_up_proj_name),
]
]
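# Example (illustrative): make_expert_params_mapping("gate_proj", "down_proj", "up_proj", num_experts=1)
# returns:
#   ("experts.w13_", "experts.0.gate_proj.", 0, "w1")
#   ("experts.w2_", "experts.0.down_proj.", 0, "w2")
#   ("experts.w13_", "experts.0.up_proj.", 0, "w3")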
@classmethod
def make_expert_input_scale_params_mapping(
cls,
num_experts: int,
) -> List[Tuple[str, str, int, str]]:
# (param_name, weight_name, expert_id, shard_id)
return [
(
"experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
f"experts.{expert_id}.{shard_id}.",
expert_id,
shard_id,
)
for expert_id in range(num_experts)
for shard_id in ["w1", "w2", "w3"]
]
def weight_loader(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:
global_expert_location_metadata = get_global_expert_location_metadata()
if global_expert_location_metadata is None:
self._weight_loader_impl(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
)
return
physical_expert_ids = global_expert_location_metadata.logical_to_all_physical(
self.layer_id, expert_id
)
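# When global expert-location metadata is present, one logical expert may map to
# several physical replicas, so the same checkpoint weight is loaded once per
# physical expert id returned here.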
for physical_expert_id in physical_expert_ids:
self._weight_loader_physical(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=physical_expert_id,
)
def _weight_loader_physical(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
) -> None:
if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
return
expert_id = expert_id - self.start_expert_id
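# Example (illustrative): with start_expert_id=4 and end_expert_id=7, a load for
# global expert 6 becomes local expert 2; experts outside [4, 7] were skipped above.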
self._weight_loader_impl(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
)
return
class DeepEPMoE(EPMoE):
"""
......@@ -905,14 +450,15 @@ class DeepEPMoE(EPMoE):
# In forward_aiter, we skip token permutation and unpermutation, which are fused inside the aiter kernel.
return self.forward_aiter(dispatch_output)
if dispatch_output.format.is_deepep_normal():
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
return self.forward_deepgemm_contiguous(dispatch_output)
else:
return self.forward_normal(dispatch_output)
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_contiguous(dispatch_output)
elif dispatch_output.format.is_deepep_ll():
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_masked(dispatch_output)
else:
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
raise ValueError(
f"Dispatch output format {dispatch_output.format} is not supported"
)
def combine(
self,
......@@ -928,185 +474,6 @@ class DeepEPMoE(EPMoE):
forward_batch=forward_batch,
)
def _prepare_for_normal(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
):
from sglang.srt.layers.moe.ep_moe.kernels import (
deepep_permute_triton_kernel,
deepep_run_moe_deep_preprocess,
)
if hidden_states.shape[0] == 0:
reorder_topk_ids = torch.empty(
(0,), device=hidden_states.device, dtype=torch.int64
)
seg_indptr = torch.zeros(
(self.num_experts + 1,),
device=hidden_states.device,
dtype=torch.int64,
)
return reorder_topk_ids, seg_indptr, hidden_states
else:
if _use_aiter:
# Skip permutation here: aiter's fused_moe kernel fuses it internally.
reorder_topk_ids = torch.empty(
(0,), device=hidden_states.device, dtype=torch.int64
)
seg_indptr = torch.zeros(
(self.num_experts + 1,),
device=hidden_states.device,
dtype=torch.int64,
)
return reorder_topk_ids, seg_indptr, hidden_states
reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
topk_idx, self.num_experts
)
num_total_tokens = reorder_topk_ids.numel()
gateup_input = torch.empty(
(int(num_total_tokens), hidden_states.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
# PreReorder
deepep_permute_triton_kernel[(hidden_states.shape[0],)](
hidden_states,
gateup_input,
self.src2dst,
topk_idx,
None,
self.router_topk,
hidden_states.shape[1],
BLOCK_SIZE=512,
)
return reorder_topk_ids, seg_indptr, gateup_input
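# Returns (reorder_topk_ids, seg_indptr, gateup_input). When there are no tokens,
# or when aiter is used (permutation is fused into its kernel), hidden_states is
# returned unchanged instead of a permuted gateup_input.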
def forward_normal(
self,
dispatch_output: DeepEPNormalOutput,
):
hidden_states, topk_idx = (
dispatch_output.hidden_states,
dispatch_output.topk_idx,
)
reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal(
hidden_states, topk_idx
)
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
assert self.quant_method is not None
assert self.activation == "silu"
if self.grouped_gemm_runner is None:
self.grouped_gemm_runner = GroupedGemmRunner(
hidden_states.device, use_flashinfer=False # TODO: use flashinfer
)
if self.activation_scheme == "dynamic" and not self.use_block_quant:
max_value = (
torch.max(hidden_states)
.repeat(self.num_local_experts)
.to(torch.float32)
)
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
weight_indices_cur_rank = torch.arange(
0,
self.num_local_experts,
device=hidden_states.device,
dtype=torch.int64,
)
# GroupGemm-0
if hidden_states.shape[0] > 0:
gateup_output = self.grouped_gemm_runner(
a=hidden_states,
b=self.w13_weight,
c=None,
c_dtype=hidden_states.dtype,
batch_size=self.num_local_experts,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w13_input_scale,
scale_b=(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
block_shape=self.block_shape,
)
else:
gateup_output = torch.empty(
hidden_states.shape[0],
self.w13_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
# Act
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=(
self.fp8_dtype
if (self.use_fp8_w8a8 and not self.use_block_quant)
else hidden_states_dtype
),
)
if self.w2_input_scale is None and not self.use_block_quant:
self.w2_input_scale = torch.ones(
self.num_local_experts,
dtype=torch.float32,
device=hidden_states_device,
)
if self.activation == "silu":
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
gateup_output.shape[1],
reorder_topk_ids,
self.w2_input_scale,
0,
self.num_local_experts - 1,
BLOCK_SIZE=512,
)
else:
raise ValueError(f"Unsupported activation: {self.activation=}")
del gateup_output
# GroupGemm-1
down_output = torch.empty(
down_input.shape[0],
self.w2_weight.shape[1],
device=hidden_states_device,
dtype=hidden_states_dtype,
)
if down_input.shape[0] > 0:
down_output = self.grouped_gemm_runner(
a=down_input,
b=self.w2_weight,
c=down_output,
batch_size=self.num_local_experts,
weight_column_major=True,
seg_indptr=seg_indptr,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w2_input_scale,
scale_b=(
self.w2_weight_scale_inv
if self.use_block_quant
else self.w2_weight_scale
),
block_shape=self.block_shape,
)
return down_output
def forward_aiter(
self,
dispatch_output: DeepEPNormalOutput,
......
......@@ -413,18 +413,37 @@ def fused_moe_kernel(
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
offs_token = offs_token.to(tl.int64)
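# Casting the token offsets (and the expert id below) to int64 keeps the pointer
# arithmetic in 64-bit, so indexing cannot overflow for very large token counts.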
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
if off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output(
c_ptr,
stride_cm,
stride_cn,
pid_n,
N,
offs_token,
token_mask,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
compute_type,
)
return
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = (
b_ptr
+ off_experts * stride_be
......@@ -497,7 +516,6 @@ def fused_moe_kernel(
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
else:
# Fix an out-of-shared-memory issue.
if use_fp8_w8a8:
accumulator = tl.dot(a, b, acc=accumulator)
else:
......
......@@ -12,7 +12,7 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
......@@ -79,7 +79,6 @@ class FusedMoE(torch.nn.Module):
routed_scaling_factor: Optional[float] = None,
enable_flashinfer_cutlass_moe: Optional[bool] = False,
enable_ep_moe: Optional[bool] = False,
skip_quant: Optional[bool] = False,
):
super().__init__()
......@@ -95,7 +94,8 @@ class FusedMoE(torch.nn.Module):
self.tp_rank = get_tensor_model_parallel_rank()
self.num_experts = num_experts
self.num_fused_shared_experts = num_fused_shared_experts
self.expert_map = None
self.expert_map_cpu = None
self.expert_map_gpu = None
if enable_flashinfer_cutlass_moe and quant_config is None:
logger.warning("Disable flashinfer MoE when quantization config is None.")
......@@ -104,20 +104,22 @@ class FusedMoE(torch.nn.Module):
self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
if enable_ep_moe:
# TODO(ch-wan): support shared experts fusion
self.ep_size = self.tp_size
self.ep_rank = self.tp_rank
self.tp_size = 1
self.tp_rank = 0
# Create a tensor of size num_experts filled with -1
self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
# Create an expert map for the local experts
assert num_experts % self.ep_size == 0
self.num_local_experts = num_experts // self.ep_size
self.expert_map[
self.expert_map_cpu[
self.ep_rank
* self.num_local_experts : (self.ep_rank + 1)
* self.num_local_experts
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
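# Example (illustrative): num_experts=8, ep_size=2, ep_rank=1 gives
# expert_map_cpu = [-1, -1, -1, -1, 0, 1, 2, 3]; -1 marks experts owned by other
# EP ranks, and expert_map_gpu is the same map kept on device for remapping
# topk_ids in forward().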
else:
self.ep_size = 1
self.ep_rank = 0
......@@ -136,9 +138,6 @@ class FusedMoE(torch.nn.Module):
not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
)
if skip_quant:
return
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
self.use_triton_kernels
......@@ -367,9 +366,9 @@ class FusedMoE(torch.nn.Module):
expert_data.copy_(loaded_weight)
def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
if self.expert_map is None:
if self.expert_map_cpu is None:
return expert_id
return self.expert_map[expert_id].item()
return self.expert_map_cpu[expert_id].item()
def weight_loader(
self,
......@@ -421,7 +420,6 @@ class FusedMoE(torch.nn.Module):
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
if expert_id == -1:
return
self._weight_loader_impl(
param=param,
loaded_weight=loaded_weight,
......@@ -614,9 +612,14 @@ class FusedMoE(torch.nn.Module):
)
return
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
assert self.quant_method is not None
if self.expert_map_gpu is not None:
topk_output = topk_output._replace(
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
)
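# topk_ids arrive as global expert ids; after this gather they are local ids,
# with -1 for experts that live on other EP ranks (the fused kernel writes
# zeros for those slots).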
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
......@@ -670,3 +673,20 @@ class FusedMoE(torch.nn.Module):
("w3", ckpt_up_proj_name),
]
]
@classmethod
def make_expert_input_scale_params_mapping(
cls,
num_experts: int,
) -> List[Tuple[str, str, int, str]]:
# (param_name, weight_name, expert_id, shard_id)
return [
(
"experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
f"experts.{expert_id}.{shard_id}.",
expert_id,
shard_id,
)
for expert_id in range(num_experts)
for shard_id in ["w1", "w2", "w3"]
]
......@@ -172,7 +172,6 @@ class Fp8Config(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase):
......@@ -181,8 +180,6 @@ class Fp8Config(QuantizationConfig):
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return Fp8MoEMethod(self)
elif isinstance(layer, EPMoE):
return Fp8EPMoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
......@@ -984,23 +981,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
if isinstance(layer, EPMoE):
layer.w13_weight_scale = (
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
)
layer.w2_weight_scale = (
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
)
return layer.run_moe(
hidden_states=x,
topk_output=topk_output,
)
if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
......
......@@ -204,14 +204,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
if isinstance(layer, EPMoE):
return layer.run_moe(
hidden_states=x,
topk_output=topk_output,
)
return self.forward(
x=x,
layer=layer,
......
......@@ -276,6 +276,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
layer: EPMoE,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
**kwargs,
) -> torch.Tensor:
# TODO(ch-wan): move it out of this class
......