Unverified Commit 1d24db83 authored by Cheng Wan, committed by GitHub

Expert Parallelism for GPT-OSS (#8944)

parent 44401358
......@@ -76,6 +76,9 @@ class EPMoE(FusedMoE):
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
with_bias: bool = False,
):
super().__init__(
num_experts=num_experts,
......@@ -91,6 +94,9 @@ class EPMoE(FusedMoE):
activation=activation,
# apply_router_weight_on_input=apply_router_weight_on_input,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
with_bias=with_bias,
)
self.start_expert_id = self.moe_ep_rank * self.num_local_experts
......
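For orientation, here is a hypothetical set of constructor keyword arguments an EPMoE layer for GPT-OSS might receive; the numeric values are made up, and only the last three names (activation_alpha, swiglu_limit, with_bias) are introduced by this commit — the rest are assumptions about the surrounding FusedMoE signature.

# --- illustrative sketch, not part of the diff ---
gpt_oss_moe_kwargs = dict(
    num_experts=128,
    activation="silu",
    activation_alpha=1.702,  # alpha of the clamped SwiGLU: gate * sigmoid(alpha * gate)
    swiglu_limit=7.0,        # clamp bound applied to the gate/up projections
    with_bias=True,          # asks the quant method to allocate w13 / w2 bias tensors
)
print(sorted(gpt_oss_moe_kwargs))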
......@@ -319,6 +319,7 @@ def fused_moe_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
bias_ptr,
c_ptr,
a_scale_ptr,
b_scale_ptr,
......@@ -340,6 +341,8 @@ def fused_moe_kernel(
stride_be,
stride_bk,
stride_bn,
stride_bias_e,
stride_bias_n,
stride_cm,
stride_cn,
stride_asm,
......@@ -449,6 +452,10 @@ def fused_moe_kernel(
+ off_experts * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
if bias_ptr is not None:
bias = tl.load(
bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
)
if use_int8_w8a16:
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
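A rough NumPy restatement of what the new bias load computes for one tile: bias is assumed to be laid out as (num_experts, N), and the row for off_experts is gathered at the tile's output columns, to be broadcast later over the tile's token rows.

# --- illustrative sketch, not part of the diff ---
import numpy as np

num_experts, N, BLOCK_SIZE_N = 4, 16, 8
bias = np.arange(num_experts * N, dtype=np.float32).reshape(num_experts, N)

off_experts = 2                                    # expert this tile works on
pid_n = 1                                          # tile index along the N dimension
offs_bn = (pid_n * BLOCK_SIZE_N + np.arange(BLOCK_SIZE_N)) % N

# Equivalent of: tl.load(bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n)
tile_bias = bias[off_experts, offs_bn]             # shape (BLOCK_SIZE_N,)
print(tile_bias)                                   # added to every row of the accumulator tile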
......@@ -526,18 +533,20 @@ def fused_moe_kernel(
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type)
accumulator *= b_scale
elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type)
else:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
if group_k == 0 or group_n == 0:
accumulator *= a_scale * b_scale
if bias_ptr is not None:
accumulator += bias
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator *= moe_weight[:, None]
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
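Taken together, the reordered epilogue dequantizes first, then adds the per-expert bias, then (when MUL_ROUTED_WEIGHT is set) multiplies by the routing weight, so the bias is scaled by the routing weight too. A NumPy sketch of one output tile with made-up shapes and scales:

# --- illustrative sketch, not part of the diff ---
import numpy as np

BLOCK_M, BLOCK_N = 4, 8
accumulator = np.random.rand(BLOCK_M, BLOCK_N).astype(np.float32)   # partial A @ B result
a_scale, b_scale = 0.5, 0.25                                        # per-tensor dequant scales
bias = np.random.rand(BLOCK_N).astype(np.float32)                   # bias row of this expert
moe_weight = np.random.rand(BLOCK_M).astype(np.float32)             # top-k routing weights

out = accumulator * (a_scale * b_scale)   # dequantize (fp8 / int8 paths without group scales)
out = out + bias[None, :]                 # new: add the per-expert bias
out = out * moe_weight[:, None]           # routed weight now applied after the bias
out = out.astype(np.float16)              # cast to compute_type before the store
print(out.shape)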
......@@ -622,6 +631,7 @@ def moe_align_block_size(
def invoke_fused_moe_kernel(
A: torch.Tensor,
B: torch.Tensor,
bias: Optional[torch.Tensor],
C: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
......@@ -711,6 +721,7 @@ def invoke_fused_moe_kernel(
):
assert B_scale is not None and B_scale.ndim == 3
assert B_zp is None or B_zp.ndim == 3
assert bias is None
fused_moe_kernel_gptq_awq[grid](
A,
B,
......@@ -754,6 +765,7 @@ def invoke_fused_moe_kernel(
fused_moe_kernel[grid](
A,
B,
bias,
C,
A_scale,
B_scale,
......@@ -770,6 +782,8 @@ def invoke_fused_moe_kernel(
B.stride(0),
B.stride(2),
B.stride(1),
bias.stride(0) if bias is not None else 0,
bias.stride(1) if bias is not None else 0,
C.stride(1),
C.stride(2),
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
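Triton kernel arguments are plain scalars, so when no bias is supplied the strides are passed as zeros; the values are never dereferenced because the kernel guards every bias access with `bias_ptr is not None`. A minimal illustration of the pattern, using a hypothetical helper name:

# --- illustrative sketch, not part of the diff ---
from typing import Optional, Tuple
import torch

def bias_strides(bias: Optional[torch.Tensor]) -> Tuple[int, int]:
    # Mirrors `bias.stride(0) if bias is not None else 0` for both stride arguments.
    if bias is None:
        return 0, 0
    return bias.stride(0), bias.stride(1)

print(bias_strides(None))                 # (0, 0)
print(bias_strides(torch.zeros(8, 16)))   # (16, 1)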
......@@ -994,6 +1008,8 @@ def inplace_fused_experts(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
......@@ -1009,6 +1025,8 @@ def inplace_fused_experts(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> None:
fused_experts_impl(
hidden_states,
......@@ -1016,6 +1034,8 @@ def inplace_fused_experts(
w2,
topk_weights,
topk_ids,
b1,
b2,
True,
activation,
apply_router_weight_on_input,
......@@ -1033,6 +1053,8 @@ def inplace_fused_experts(
block_shape,
False,
routed_scaling_factor,
activation_alpha,
swiglu_limit,
)
......@@ -1042,6 +1064,8 @@ def inplace_fused_experts_fake(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
......@@ -1057,6 +1081,8 @@ def inplace_fused_experts_fake(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> None:
pass
......@@ -1075,6 +1101,8 @@ def outplace_fused_experts(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
......@@ -1091,6 +1119,8 @@ def outplace_fused_experts(
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> torch.Tensor:
return fused_experts_impl(
hidden_states,
......@@ -1098,6 +1128,8 @@ def outplace_fused_experts(
w2,
topk_weights,
topk_ids,
b1,
b2,
False,
activation,
apply_router_weight_on_input,
......@@ -1115,6 +1147,8 @@ def outplace_fused_experts(
block_shape,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
)
......@@ -1124,6 +1158,8 @@ def outplace_fused_experts_fake(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
......@@ -1140,6 +1176,8 @@ def outplace_fused_experts_fake(
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -1157,6 +1195,8 @@ def fused_experts(
w1: torch.Tensor,
w2: torch.Tensor,
topk_output: TopKOutput,
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
......@@ -1174,6 +1214,8 @@ def fused_experts(
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
):
topk_weights, topk_ids, _ = topk_output
if inplace:
......@@ -1184,6 +1226,8 @@ def fused_experts(
w2,
topk_weights,
topk_ids,
b1,
b2,
activation,
apply_router_weight_on_input,
use_fp8_w8a8,
......@@ -1199,6 +1243,8 @@ def fused_experts(
a2_scale,
block_shape,
routed_scaling_factor,
activation_alpha,
swiglu_limit,
)
return hidden_states
else:
......@@ -1208,6 +1254,8 @@ def fused_experts(
w2,
topk_weights,
topk_ids,
b1,
b2,
activation,
apply_router_weight_on_input,
use_fp8_w8a8,
......@@ -1224,6 +1272,8 @@ def fused_experts(
block_shape,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
)
......@@ -1319,12 +1369,22 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
out.mul_(routed_scaling_factor)
@torch.compile
def swiglu_with_alpha_and_limit(x, alpha, limit):
gate, up = x[..., ::2], x[..., 1::2]
gate = gate.clamp(min=None, max=limit)
up = up.clamp(min=-limit, max=limit)
return gate * torch.sigmoid(gate * alpha) * (up + 1)
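swiglu_with_alpha_and_limit implements GPT-OSS's clamped SwiGLU on an interleaved gate/up layout (even columns are gates, odd columns are ups) and halves the last dimension. A small eager restatement of the same math, with arbitrary tensor sizes and example alpha/limit values:

# --- illustrative sketch, not part of the diff ---
import torch

def clamped_swiglu_reference(x: torch.Tensor, alpha: float, limit: float) -> torch.Tensor:
    gate = x[..., ::2].clamp(max=limit)             # gate: clamped from above only
    up = x[..., 1::2].clamp(min=-limit, max=limit)  # up: clamped on both sides
    return gate * torch.sigmoid(gate * alpha) * (up + 1)

x = torch.randn(3, 8)
y = clamped_swiglu_reference(x, alpha=1.702, limit=7.0)   # example values
print(y.shape)   # torch.Size([3, 4]) -- output width is half the interleaved input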
def fused_experts_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
......@@ -1342,6 +1402,8 @@ def fused_experts_impl(
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
):
padded_size = padding_size
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
......@@ -1353,7 +1415,7 @@ def fused_experts_impl(
else:
assert (
hidden_states.shape[1] == w1.shape[2] - padded_size
), "Hidden size mismatch"
), f"Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
......@@ -1449,6 +1511,7 @@ def fused_experts_impl(
invoke_fused_moe_kernel(
curr_hidden_states,
w1,
b1,
intermediate_cache1,
a1_scale,
w1_scale,
......@@ -1470,13 +1533,24 @@ def fused_experts_impl(
block_shape=block_shape,
)
if activation == "silu":
if _is_cuda:
if activation_alpha is not None:
assert swiglu_limit is not None
intermediate_cache2 = swiglu_with_alpha_and_limit(
intermediate_cache1.view(-1, N),
activation_alpha,
swiglu_limit,
)
elif _is_cuda:
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
else:
vllm_ops.silu_and_mul(
intermediate_cache2, intermediate_cache1.view(-1, N)
)
elif activation == "gelu":
assert (
activation_alpha is None
), "activation_alpha is not supported for gelu"
assert swiglu_limit is None, "swiglu_limit is not supported for gelu"
if _is_cuda:
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
else:
......@@ -1489,6 +1563,7 @@ def fused_experts_impl(
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
b2,
(
intermediate_cache3
if not no_combine and topk_ids.shape[1] != 1
......@@ -1567,6 +1642,8 @@ def fused_moe(
w1: torch.Tensor,
w2: torch.Tensor,
topk_output: TopKOutput,
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
......@@ -1584,6 +1661,8 @@ def fused_moe(
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
......@@ -1594,6 +1673,8 @@ def fused_moe(
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_output (TopKOutput): The top-k output of the experts.
- b1 (Optional[torch.Tensor]): Optional bias for w1.
- b2 (Optional[torch.Tensor]): Optional bias for w2.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
......@@ -1615,6 +1696,10 @@ def fused_moe(
a2.
- block_shape: (Optional[List[int]]): Optional block size for block-wise
quantization.
- activation_alpha (Optional[float]): Optional alpha for the activation
function.
- swiglu_limit (Optional[float]): Optional limit for the swiglu activation
function.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
......@@ -1625,6 +1710,8 @@ def fused_moe(
w1,
w2,
topk_output,
b1=b1,
b2=b2,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
......@@ -1642,4 +1729,6 @@ def fused_moe(
block_shape=block_shape,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
)
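From how the kernels index them, the new bias arguments appear to carry one row per expert, matching each GEMM's output dimension: b1 shaped (E, 2 * intermediate_size) for the fused gate_up projection and b2 shaped (E, hidden_size) for the down projection. A shape-level sketch with toy sizes (an assumption drawn from the indexing above, not a documented contract):

# --- illustrative sketch, not part of the diff ---
import torch

E, H, I = 8, 64, 128                 # experts, hidden size, intermediate size (toy numbers)

w1 = torch.randn(E, 2 * I, H)        # fused gate_up_proj weights
b1 = torch.randn(E, 2 * I)           # per-expert bias for the first GEMM
w2 = torch.randn(E, H, I)            # down_proj weights
b2 = torch.randn(E, H)               # per-expert bias for the second GEMM

tokens = torch.randn(5, H)           # tokens routed to expert e
e = 3
out1 = tokens @ w1[e].t() + b1[e]    # conceptually what the first fused GEMM + bias computes
print(out1.shape)                    # torch.Size([5, 256])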
......@@ -199,7 +199,7 @@ class FusedMoE(torch.nn.Module):
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
self.use_triton_kernels, with_bias=with_bias
self.use_triton_kernels
)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
......@@ -809,7 +809,9 @@ class FusedMoE(torch.nn.Module):
# If we are in EP mode, we need to move the expert map to GPU.
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
if self.expert_map_gpu is not None:
if self.expert_map_gpu is not None and isinstance(
topk_output, StandardTopKOutput
):
topk_output = topk_output._replace(
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
)
......
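With expert parallelism the router emits global expert ids, and expert_map_gpu translates them into local slots before the kernel runs; ids owned by other ranks map to a sentinel. A toy version of that remapping (the -1 sentinel is an assumption of this illustration):

# --- illustrative sketch, not part of the diff ---
import torch

num_experts, ep_size, ep_rank = 8, 2, 1
num_local = num_experts // ep_size                       # 4 experts live on each EP rank

expert_map = torch.full((num_experts,), -1, dtype=torch.int64)
expert_map[ep_rank * num_local:(ep_rank + 1) * num_local] = torch.arange(num_local)

topk_ids = torch.tensor([[0, 5], [6, 3]])                # router output in global ids
local_topk_ids = expert_map[topk_ids]                    # what topk_output._replace(...) produces
print(local_topk_ids)                                    # tensor([[-1,  1], [ 2, -1]])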
......@@ -8,6 +8,7 @@ import logging
from typing import TYPE_CHECKING, List, Optional
import torch
import triton.language as tl
from torch.nn.parameter import Parameter
from sglang.srt.layers.quantization.base_config import (
......@@ -24,6 +25,7 @@ from sglang.srt.utils import (
is_cuda,
is_flashinfer_available,
is_hip,
is_triton_kernels_available,
log_info_on_rank0,
next_power_of_2,
round_up,
......@@ -31,7 +33,7 @@ from sglang.srt.utils import (
)
_is_sm100_supported = is_cuda() and is_sm100_supported()
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
has_triton_kernels = is_triton_kernels_available()
if is_flashinfer_available():
......@@ -188,12 +190,7 @@ class Mxfp4Config(QuantizationConfig):
):
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
use_flashinfer = global_server_args_dict.get(
"enable_flashinfer_mxfp4_moe", False
)
return Mxfp4MoEMethod(
use_triton_kernels=True, with_bias=True, use_flashinfer=use_flashinfer
)
return Mxfp4MoEMethod(prefix)
else:
raise NotImplementedError("Mxfp4 attention layer is not implemented")
return None
......@@ -206,15 +203,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(
self,
use_triton_kernels: bool = True,
with_bias: bool = True,
use_flashinfer: bool = False,
prefix: str,
):
from sglang.srt.managers.schedule_batch import global_server_args_dict
super().__init__()
self.topk_indices_dtype = None
self.use_triton_kernels = use_triton_kernels
self.with_bias = with_bias
self.use_flashinfer = use_flashinfer
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
self.with_bias = False
self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]
self.triton_kernel_moe_forward = None
self.triton_kernel_moe_with_bias_forward = None
......@@ -236,12 +234,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
with_bias: bool = False,
**extra_weight_attrs,
):
# print(f"hi {self=} create_weights {layer=}")
self.num_experts = num_experts
weight_dtype = torch.uint8
scale_dtype = torch.uint8
self.with_bias = with_bias
mxfp4_block = 32
# pad the intermediate size to be a multiple of 2 * mxfp4_block
......@@ -264,7 +263,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.zeros(
num_experts,
layer.num_local_experts,
2 * intermediate_size_per_partition_after_pad,
hidden_size // 2,
dtype=weight_dtype,
......@@ -276,7 +275,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w13_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
layer.num_local_experts,
2 * intermediate_size_per_partition_after_pad,
hidden_size // mxfp4_block,
dtype=scale_dtype,
......@@ -288,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w13_weight_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
layer.num_local_experts,
2 * intermediate_size_per_partition_after_pad,
dtype=torch.bfloat16,
),
......@@ -300,7 +299,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(
torch.zeros(
num_experts,
layer.num_local_experts,
hidden_size,
intermediate_size_per_partition_after_pad // 2,
dtype=weight_dtype,
......@@ -312,7 +311,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w2_weight_scale = torch.nn.Parameter(
torch.zeros(
num_experts,
layer.num_local_experts,
hidden_size,
intermediate_size_per_partition_after_pad // mxfp4_block,
dtype=scale_dtype,
......@@ -323,7 +322,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w2_weight_bias = torch.nn.Parameter(
torch.zeros(num_experts, hidden_size, dtype=torch.bfloat16),
torch.zeros(layer.num_local_experts, hidden_size, dtype=torch.bfloat16),
requires_grad=False,
)
layer.register_parameter("w2_weight_bias", w2_weight_bias)
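The mxfp4 parameters are now sized by layer.num_local_experts, i.e. the per-EP-rank shard of the expert dimension. A quick sketch of the resulting shapes with toy numbers (mxfp4 packs two 4-bit values per uint8 and uses one scale per 32-value block, as in the code above):

# --- illustrative sketch, not part of the diff ---
num_experts, moe_ep_size = 128, 4
hidden_size = 2880
intermediate_size_per_partition_after_pad = 2880
mxfp4_block = 32

num_local_experts = num_experts // moe_ep_size            # 32 experts stored per EP rank

w13_weight_shape = (
    num_local_experts,
    2 * intermediate_size_per_partition_after_pad,
    hidden_size // 2,                                     # two mxfp4 values packed per byte
)
w13_scale_shape = (
    num_local_experts,
    2 * intermediate_size_per_partition_after_pad,
    hidden_size // mxfp4_block,                           # one uint8 scale per 32-value block
)
w2_bias_shape = (num_local_experts, hidden_size)
print(w13_weight_shape, w13_scale_shape, w2_bias_shape)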
......@@ -484,38 +483,51 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
return
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
if self.use_triton_kernels:
w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
w2_weight_bias = layer.w2_weight_bias.to(torch.float32)
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)
w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
w2_weight_bias = layer.w2_weight_bias.to(torch.float32)
num_warps = 8
layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps
)
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
layer.w2_weight, layer.w2_weight_scale, num_warps
)
num_warps = 8
self.w13_precision_config = PrecisionConfig(
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
)
self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
)
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps
)
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
layer.w2_weight, layer.w2_weight_scale, num_warps
)
self.w13_precision_config = PrecisionConfig(
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
)
self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
)
self.w13_weight_triton_tensor = w13_weight
self.w2_weight_triton_tensor = w2_weight
self.w13_weight_triton_tensor = w13_weight
self.w2_weight_triton_tensor = w2_weight
del layer.w13_weight
del layer.w2_weight
else:
from triton_kernels.numerics_details.mxfp import upcast_from_mxfp
# need to delete the original weights to save memory on single GPU
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = None
layer.w2_weight = None
w13_weight = upcast_from_mxfp(
layer.w13_weight, layer.w13_weight_scale, dtype=torch.bfloat16, axis=-1
)
w2_weight = upcast_from_mxfp(
layer.w2_weight, layer.w2_weight_scale, dtype=torch.bfloat16, axis=-1
)
del layer.w13_weight
del layer.w2_weight
del layer.w13_weight_scale
del layer.w2_weight_scale
layer.w13_weight = Parameter(w13_weight.data, requires_grad=False)
layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
torch.cuda.empty_cache()
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
......@@ -580,13 +592,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
None, # output1_scale_scalar
None, # output1_scale_gate_scalar
None, # output2_scale_scalar
self.num_experts,
layer.num_experts,
top_k,
None, # n_group
None, # topk_group
self.intermediate_size, # padded to multiple of 256
0, # local_expert_offset
self.num_experts, # local num experts
layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
layer.num_local_experts, # local num experts
None,
self._get_tile_tokens_dim(x, top_k),
1, # routing_method_type, renormalize
......@@ -595,10 +607,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return trtllm_gen_output
if self.use_triton_kernels:
assert (
layer.moe_ep_size == 1
), "Expert parallel is not supported when using triton kernels"
if self.with_bias:
# TODO: why do we not put the weights on the layer?
assert layer.w13_weight is None
assert layer.w2_weight is None
return self.triton_kernel_moe_with_bias_forward(
hidden_states=x,
w1=self.w13_weight_triton_tensor,
......@@ -620,4 +632,20 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
topk_output=topk_output,
)
else:
raise NotImplementedError()
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
)
......@@ -126,10 +126,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def __init__(self, use_triton_kernels: bool = False, with_bias: bool = False):
def __init__(self, use_triton_kernels: bool = False):
super().__init__()
self.use_triton_kernels = use_triton_kernels
self.with_bias = with_bias
self.with_bias = False
self.triton_kernel_moe_forward = None
self.triton_kernel_moe_with_bias_forward = None
......@@ -151,8 +151,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
with_bias: bool = False,
**extra_weight_attrs,
):
self.with_bias = with_bias
# Fused gate_up_proj (column parallel)
w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
if self.use_triton_kernels:
......@@ -319,12 +322,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
b1=getattr(layer, "w13_weight_bias", None),
b2=getattr(layer, "w2_weight_bias", None),
topk_output=topk_output,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
)
def forward_cpu(
......
......@@ -28,6 +28,7 @@ from sglang.srt.distributed import (
get_moe_expert_parallel_rank,
get_moe_expert_parallel_world_size,
get_moe_tensor_parallel_rank,
get_moe_tensor_parallel_world_size,
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
......@@ -96,11 +97,6 @@ class GptOssSparseMoeBlock(nn.Module):
self.activation = config.hidden_act
self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
self.swiglu_limit = config.swiglu_limit
if self.tp_size > config.num_local_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_local_experts}."
)
if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
self.topk = None
......@@ -708,22 +704,26 @@ class GptOssForCausalLM(nn.Module):
loaded_params: set[str] = set()
mxfp4_block = 32
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
moe_tp_rank = get_moe_tensor_parallel_rank()
moe_tp_size = get_moe_tensor_parallel_world_size()
moe_ep_rank = get_moe_expert_parallel_rank()
moe_ep_size = get_moe_expert_parallel_world_size()
intermediate_size = self.config.intermediate_size
intermediate_size_block = intermediate_size // mxfp4_block
per_rank_intermediate_size_block = intermediate_size_block // tp_size
per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
# Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
# Attention heads per rank
heads_per_rank = self.config.num_attention_heads // tp_size
head_start = tp_rank * heads_per_rank
num_experts = self.config.num_local_experts
assert self.config.num_local_experts % moe_ep_size == 0
moe_num_global_experts = self.config.num_local_experts
moe_num_local_experts = self.config.num_local_experts // moe_ep_size
moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size
moe_tp_rank_end = min(
(moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size
)
moe_ep_rank_start = moe_ep_rank * moe_num_local_experts
moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts
for name, weight in weights:
weight = weight.cuda()
......@@ -735,10 +735,14 @@ class GptOssForCausalLM(nn.Module):
# flat weight from (E, 2 * N, block_size, entry_per_block)
# to (E, 2 * N, -1), shouldn't trigger copy for contiguous
weight = weight.view(
num_experts, 2 * intermediate_size, -1
moe_num_global_experts, 2 * intermediate_size, -1
).contiguous()
narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
narrow_weight = weight[
moe_ep_rank_start:moe_ep_rank_end,
2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
...,
]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
......@@ -757,9 +761,13 @@ class GptOssForCausalLM(nn.Module):
# same flatten here, but since 2 mx4 value are packed in 1
# uint8, divide by 2
weight = weight.view(
num_experts, -1, intermediate_size // 2
moe_num_global_experts, -1, intermediate_size // 2
).contiguous()
narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2]
narrow_weight = weight[
moe_ep_rank_start:moe_ep_rank_end,
...,
moe_tp_rank_start // 2 : moe_tp_rank_end // 2,
]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
......@@ -775,7 +783,11 @@ class GptOssForCausalLM(nn.Module):
elif "gate_up_proj_scales" in name:
# Handle MLP gate and up projection weights scale
new_name = name.replace("gate_up_proj_scales", "w13_weight_scale")
narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
narrow_weight = weight[
moe_ep_rank_start:moe_ep_rank_end,
2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
...,
]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
......@@ -792,7 +804,9 @@ class GptOssForCausalLM(nn.Module):
# Handle MLP down projection weights
new_name = name.replace("down_proj_scales", "w2_weight_scale")
narrow_weight = weight[
..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block
moe_ep_rank_start:moe_ep_rank_end,
...,
moe_tp_rank_start // mxfp4_block : moe_tp_rank_end // mxfp4_block,
]
param = params_dict[new_name]
......@@ -809,7 +823,10 @@ class GptOssForCausalLM(nn.Module):
# Handle MLP gate and up projection biases
new_name = name.replace("gate_up_proj_bias", "w13_weight_bias")
narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
narrow_weight = weight[
moe_ep_rank_start:moe_ep_rank_end,
2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
]
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
......@@ -823,15 +840,20 @@ class GptOssForCausalLM(nn.Module):
loaded_params.add(new_name)
elif "down_proj_bias" in name:
if get_moe_tensor_parallel_rank() != 0:
weight = torch.zeros_like(weight)
narrow_weight = weight[moe_ep_rank_start:moe_ep_rank_end, ...]
if moe_tp_rank != 0:
narrow_weight = torch.zeros_like(narrow_weight)
# Handle MLP down projection bias
new_name = name.replace("down_proj_bias", "w2_weight_bias")
param = params_dict[new_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(
param, weight, weight_name=new_name, shard_id=None, expert_id=None
param,
narrow_weight,
weight_name=new_name,
shard_id=None,
expert_id=None,
)
loaded_params.add(new_name)
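down_proj is row-parallel, so the partial outputs of the MoE-TP ranks are summed by an all-reduce; loading the real bias only on moe_tp_rank 0 (and zeros elsewhere) keeps it from being added tp_size times. A toy check of that invariant, with the all-reduce stood in by a plain sum:

# --- illustrative sketch, not part of the diff ---
import torch

moe_tp_size, hidden = 4, 8
bias = torch.randn(hidden)
partial_outputs = [torch.randn(hidden) for _ in range(moe_tp_size)]

# Per the diff: only moe_tp_rank == 0 keeps the real bias, the other ranks load zeros.
rank_bias = [bias if rank == 0 else torch.zeros_like(bias) for rank in range(moe_tp_size)]

reduced = sum(out + b for out, b in zip(partial_outputs, rank_bias))   # stand-in for all-reduce
expected = sum(partial_outputs) + bias
print(torch.allclose(reduced, expected))                               # True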
......@@ -910,27 +932,12 @@ class GptOssForCausalLM(nn.Module):
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
if self.quant_config is not None and (self.quant_config.get_name() == "mxfp4"):
expert_params_mapping = (
get_moe_impl_class().make_expert_params_mapping_fused_mxfp4(
ckpt_gate_up_proj_name="gate_up_proj_blocks",
ckpt_down_proj_name="down_proj_blocks",
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
ckpt_down_proj_bias_name="down_proj_bias",
ckpt_gate_up_proj_scale_name="gate_up_proj_scales",
ckpt_down_proj_scale_name="down_proj_scales",
)
)
else:
expert_params_mapping = (
get_moe_impl_class().make_expert_params_mapping_fused(
ckpt_gate_up_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj",
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
ckpt_down_proj_bias_name="down_proj_bias",
)
)
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
ckpt_gate_up_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj",
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
ckpt_down_proj_bias_name="down_proj_bias",
)
params_dict = dict(self.named_parameters())
params_checker = {k: False for k, v in params_dict.items()}
......
......@@ -37,6 +37,7 @@ from sglang.srt.utils import (
is_hip,
is_port_available,
is_remote_url,
is_triton_kernels_available,
is_valid_ipv6_address,
nullable_str,
)
......@@ -492,10 +493,15 @@ class ServerArgs:
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
)
else:
self.enable_triton_kernel_moe = True
logger.info(
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
)
if self.enable_triton_kernel_moe:
assert (
self.ep_size == 1
), "Triton kernel MoE is only supported when ep_size == 1"
if not self.enable_triton_kernel_moe and self.ep_size == 1:
self.enable_triton_kernel_moe = True
logger.info(
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
)
self.disable_hybrid_swa_memory = True
......
......@@ -2961,3 +2961,8 @@ class ConcurrentCounter:
other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
"""
self.wait_for(lambda count: count == 0)
@lru_cache(maxsize=1)
def is_triton_kernels_available() -> bool:
return importlib.util.find_spec("triton_kernels") is not None