Unverified commit 1d24db83, authored by Cheng Wan and committed by GitHub

Expert Parallelism for GPT-OSS (#8944)

parent 44401358
@@ -76,6 +76,9 @@ class EPMoE(FusedMoE):
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
+        with_bias: bool = False,
     ):
         super().__init__(
             num_experts=num_experts,
@@ -91,6 +94,9 @@ class EPMoE(FusedMoE):
             activation=activation,
             # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
+            activation_alpha=activation_alpha,
+            swiglu_limit=swiglu_limit,
+            with_bias=with_bias,
         )
         self.start_expert_id = self.moe_ep_rank * self.num_local_experts
...
@@ -319,6 +319,7 @@ def fused_moe_kernel(
     # Pointers to matrices
     a_ptr,
     b_ptr,
+    bias_ptr,
    c_ptr,
     a_scale_ptr,
     b_scale_ptr,
@@ -340,6 +341,8 @@ def fused_moe_kernel(
     stride_be,
     stride_bk,
     stride_bn,
+    stride_bias_e,
+    stride_bias_n,
     stride_cm,
     stride_cn,
     stride_asm,
@@ -449,6 +452,10 @@ def fused_moe_kernel(
         + off_experts * stride_be
         + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
     )
+    if bias_ptr is not None:
+        bias = tl.load(
+            bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
+        )
     if use_int8_w8a16:
         b_scale_ptrs = (
             b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
@@ -526,18 +533,20 @@ def fused_moe_kernel(
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk

-    if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
-        accumulator = accumulator * moe_weight[:, None]
     if use_int8_w8a16:
-        accumulator = (accumulator * b_scale).to(compute_type)
+        accumulator *= b_scale
     elif use_fp8_w8a8 or use_int8_w8a8:
-        if group_k > 0 and group_n > 0:
-            accumulator = accumulator.to(compute_type)
-        else:
-            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
-    else:
-        accumulator = accumulator.to(compute_type)
+        if group_k == 0 or group_n == 0:
+            accumulator *= a_scale * b_scale
+
+    if bias_ptr is not None:
+        accumulator += bias
+
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+        accumulator *= moe_weight[:, None]
+
+    accumulator = accumulator.to(compute_type)
+
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
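Note on the reordered epilogue above: dequantization scaling is applied first, the optional per-expert bias is added next, and the routed weight is multiplied last, so the bias is scaled by the router weight together with the GEMM result. A minimal PyTorch sketch of that ordering (variable names are illustrative, not the kernel's):

```python
import torch

def moe_epilogue_reference(acc, a_scale, b_scale, bias, moe_weight,
                           use_int8_w8a16=False, use_fp8_w8a8=False, blockwise=False):
    # acc: [BLOCK_M, BLOCK_N] fp32 accumulator of one output tile
    if use_int8_w8a16:
        acc = acc * b_scale                      # weight-only dequant
    elif use_fp8_w8a8 and not blockwise:
        acc = acc * a_scale * b_scale            # per-tensor / per-token dequant
    if bias is not None:
        acc = acc + bias                         # per-expert bias, broadcast over rows
    if moe_weight is not None:
        acc = acc * moe_weight[:, None]          # routed weight applied after the bias
    return acc.to(torch.bfloat16)
```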
@@ -622,6 +631,7 @@ def moe_align_block_size(
 def invoke_fused_moe_kernel(
     A: torch.Tensor,
     B: torch.Tensor,
+    bias: Optional[torch.Tensor],
     C: torch.Tensor,
     A_scale: Optional[torch.Tensor],
     B_scale: Optional[torch.Tensor],
@@ -711,6 +721,7 @@ def invoke_fused_moe_kernel(
     ):
         assert B_scale is not None and B_scale.ndim == 3
         assert B_zp is None or B_zp.ndim == 3
+        assert bias is None
         fused_moe_kernel_gptq_awq[grid](
             A,
             B,
@@ -754,6 +765,7 @@ def invoke_fused_moe_kernel(
         fused_moe_kernel[grid](
             A,
             B,
+            bias,
             C,
             A_scale,
             B_scale,
@@ -770,6 +782,8 @@ def invoke_fused_moe_kernel(
             B.stride(0),
             B.stride(2),
             B.stride(1),
+            bias.stride(0) if bias is not None else 0,
+            bias.stride(1) if bias is not None else 0,
             C.stride(1),
             C.stride(2),
             A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
@@ -994,6 +1008,8 @@ def inplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
@@ -1009,6 +1025,8 @@ def inplace_fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> None:
     fused_experts_impl(
         hidden_states,
@@ -1016,6 +1034,8 @@ def inplace_fused_experts(
         w2,
         topk_weights,
         topk_ids,
+        b1,
+        b2,
         True,
         activation,
         apply_router_weight_on_input,
@@ -1033,6 +1053,8 @@ def inplace_fused_experts(
         block_shape,
         False,
         routed_scaling_factor,
+        activation_alpha,
+        swiglu_limit,
     )
@@ -1042,6 +1064,8 @@ def inplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
@@ -1057,6 +1081,8 @@ def inplace_fused_experts_fake(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> None:
     pass
@@ -1075,6 +1101,8 @@ def outplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
@@ -1091,6 +1119,8 @@ def outplace_fused_experts(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -1098,6 +1128,8 @@ def outplace_fused_experts(
         w2,
         topk_weights,
         topk_ids,
+        b1,
+        b2,
         False,
         activation,
         apply_router_weight_on_input,
@@ -1115,6 +1147,8 @@ def outplace_fused_experts(
         block_shape,
         no_combine=no_combine,
         routed_scaling_factor=routed_scaling_factor,
+        activation_alpha=activation_alpha,
+        swiglu_limit=swiglu_limit,
     )
@@ -1124,6 +1158,8 @@ def outplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
@@ -1140,6 +1176,8 @@ def outplace_fused_experts_fake(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)
@@ -1157,6 +1195,8 @@ def fused_experts(
     w1: torch.Tensor,
     w2: torch.Tensor,
     topk_output: TopKOutput,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -1174,6 +1214,8 @@ def fused_experts(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ):
     topk_weights, topk_ids, _ = topk_output
     if inplace:
@@ -1184,6 +1226,8 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            b1,
+            b2,
             activation,
             apply_router_weight_on_input,
             use_fp8_w8a8,
@@ -1199,6 +1243,8 @@ def fused_experts(
             a2_scale,
             block_shape,
             routed_scaling_factor,
+            activation_alpha,
+            swiglu_limit,
         )
         return hidden_states
     else:
@@ -1208,6 +1254,8 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            b1,
+            b2,
             activation,
             apply_router_weight_on_input,
             use_fp8_w8a8,
@@ -1224,6 +1272,8 @@ def fused_experts(
             block_shape,
             no_combine=no_combine,
             routed_scaling_factor=routed_scaling_factor,
+            activation_alpha=activation_alpha,
+            swiglu_limit=swiglu_limit,
         )
@@ -1319,12 +1369,22 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
     out.mul_(routed_scaling_factor)


+@torch.compile
+def swiglu_with_alpha_and_limit(x, alpha, limit):
+    gate, up = x[..., ::2], x[..., 1::2]
+    gate = gate.clamp(min=None, max=limit)
+    up = up.clamp(min=-limit, max=limit)
+    return gate * torch.sigmoid(gate * alpha) * (up + 1)
+
+
 def fused_experts_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
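The new swiglu_with_alpha_and_limit helper assumes the gate and up projections are interleaved along the last dimension (even columns gate, odd columns up), clamps the gate from above and the up projection on both sides, and computes gate * sigmoid(alpha * gate) * (up + 1). A small usage sketch with illustrative sizes:

```python
import torch

x = torch.randn(4, 8)          # 4 tokens, gate/up interleaved over 8 columns
alpha, limit = 1.702, 7.0      # illustrative GPT-OSS-style values

gate, up = x[..., ::2], x[..., 1::2]
gate = gate.clamp(max=limit)               # gate is clamped from above only
up = up.clamp(min=-limit, max=limit)       # up is clamped on both sides
out = gate * torch.sigmoid(gate * alpha) * (up + 1)
assert out.shape == (4, 4)                 # output has half the input width
```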
@@ -1342,6 +1402,8 @@ def fused_experts_impl(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ):
     padded_size = padding_size
     if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
@@ -1353,7 +1415,7 @@ def fused_experts_impl(
     else:
         assert (
             hidden_states.shape[1] == w1.shape[2] - padded_size
-        ), "Hidden size mismatch"
+        ), f"Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -1449,6 +1511,7 @@ def fused_experts_impl(
         invoke_fused_moe_kernel(
             curr_hidden_states,
             w1,
+            b1,
             intermediate_cache1,
             a1_scale,
             w1_scale,
@@ -1470,13 +1533,24 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
         if activation == "silu":
-            if _is_cuda:
+            if activation_alpha is not None:
+                assert swiglu_limit is not None
+                intermediate_cache2 = swiglu_with_alpha_and_limit(
+                    intermediate_cache1.view(-1, N),
+                    activation_alpha,
+                    swiglu_limit,
+                )
+            elif _is_cuda:
                 silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
                 vllm_ops.silu_and_mul(
                     intermediate_cache2, intermediate_cache1.view(-1, N)
                 )
         elif activation == "gelu":
+            assert (
+                activation_alpha is None
+            ), "activation_alpha is not supported for gelu"
+            assert swiglu_limit is None, "swiglu_limit is not supported for gelu"
             if _is_cuda:
                 gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
@@ -1489,6 +1563,7 @@ def fused_experts_impl(
         invoke_fused_moe_kernel(
             intermediate_cache2,
             w2,
+            b2,
             (
                 intermediate_cache3
                 if not no_combine and topk_ids.shape[1] != 1
@@ -1567,6 +1642,8 @@ def fused_moe(
     w1: torch.Tensor,
     w2: torch.Tensor,
     topk_output: TopKOutput,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -1584,6 +1661,8 @@ def fused_moe(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1594,6 +1673,8 @@ def fused_moe(
     - w1 (torch.Tensor): The first set of expert weights.
     - w2 (torch.Tensor): The second set of expert weights.
     - topk_output (TopKOutput): The top-k output of the experts.
+    - b1 (Optional[torch.Tensor]): Optional bias for w1.
+    - b2 (Optional[torch.Tensor]): Optional bias for w2.
     - inplace (bool): If True, perform the operation in-place.
         Defaults to False.
     - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
@@ -1615,6 +1696,10 @@ def fused_moe(
         a2.
     - block_shape: (Optional[List[int]]): Optional block size for block-wise
         quantization.
+    - activation_alpha (Optional[float]): Optional alpha for the activation
+        function.
+    - swiglu_limit (Optional[float]): Optional limit for the swiglu activation
+        function.

     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
@@ -1625,6 +1710,8 @@ def fused_moe(
         w1,
         w2,
         topk_output,
+        b1=b1,
+        b2=b2,
         inplace=inplace,
         activation=activation,
         apply_router_weight_on_input=apply_router_weight_on_input,
@@ -1642,4 +1729,6 @@ def fused_moe(
         block_shape=block_shape,
         no_combine=no_combine,
         routed_scaling_factor=routed_scaling_factor,
+        activation_alpha=activation_alpha,
+        swiglu_limit=swiglu_limit,
     )
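A hedged end-to-end sketch of calling fused_moe with the new bias and activation arguments (requires CUDA/Triton; the plain-tuple topk_output and the chosen shapes, alpha, and limit are illustrative only):

```python
import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

E, K, N, T, topk = 8, 512, 1024, 16, 2       # experts, hidden, intermediate, tokens, top-k
x  = torch.randn(T, K, device="cuda", dtype=torch.bfloat16)
w1 = torch.randn(E, 2 * N, K, device="cuda", dtype=torch.bfloat16)   # fused gate/up proj
w2 = torch.randn(E, K, N, device="cuda", dtype=torch.bfloat16)       # down proj
b1 = torch.zeros(E, 2 * N, device="cuda", dtype=torch.bfloat16)
b2 = torch.zeros(E, K, device="cuda", dtype=torch.bfloat16)

topk_weights = torch.full((T, topk), 1.0 / topk, device="cuda", dtype=torch.float32)
topk_ids = torch.randint(0, E, (T, topk), device="cuda", dtype=torch.int32)
topk_output = (topk_weights, topk_ids, None)  # illustrative stand-in for a TopKOutput

out = fused_moe(
    x, w1, w2, topk_output,
    b1=b1, b2=b2,
    activation="silu",
    activation_alpha=1.702,   # selects the clamped-SwiGLU path
    swiglu_limit=7.0,
)
```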
@@ -199,7 +199,7 @@ class FusedMoE(torch.nn.Module):
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
-                self.use_triton_kernels, with_bias=with_bias
+                self.use_triton_kernels
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
@@ -809,7 +809,9 @@ class FusedMoE(torch.nn.Module):
             # If we are in EP mode, we need to move the expert map to GPU.
             self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")

-        if self.expert_map_gpu is not None:
+        if self.expert_map_gpu is not None and isinstance(
+            topk_output, StandardTopKOutput
+        ):
             topk_output = topk_output._replace(
                 topk_ids=self.expert_map_gpu[topk_output.topk_ids]
             )
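The EP guard above only remaps topk_ids for the standard top-k output. For intuition, a small sketch of what the expert-map lookup does, turning global expert ids into local ones (the layout and the -1 sentinel for experts owned by other ranks are illustrative):

```python
import torch

num_experts, num_local_experts, ep_rank = 8, 4, 1   # illustrative EP layout
expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
start = ep_rank * num_local_experts
expert_map[start:start + num_local_experts] = torch.arange(num_local_experts, dtype=torch.int32)

topk_ids = torch.tensor([[1, 5], [6, 2]])
local_ids = expert_map[topk_ids]    # tensor([[-1, 1], [2, -1]]) — off-rank experts masked out
```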
...
@@ -8,6 +8,7 @@ import logging
 from typing import TYPE_CHECKING, List, Optional

 import torch
+import triton.language as tl
 from torch.nn.parameter import Parameter

 from sglang.srt.layers.quantization.base_config import (
@@ -24,6 +25,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_flashinfer_available,
     is_hip,
+    is_triton_kernels_available,
     log_info_on_rank0,
     next_power_of_2,
     round_up,
@@ -31,7 +33,7 @@ from sglang.srt.utils import (
 )

 _is_sm100_supported = is_cuda() and is_sm100_supported()
-has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
+has_triton_kernels = is_triton_kernels_available()

 if is_flashinfer_available():
@@ -188,12 +190,7 @@ class Mxfp4Config(QuantizationConfig):
         ):
             return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
-            use_flashinfer = global_server_args_dict.get(
-                "enable_flashinfer_mxfp4_moe", False
-            )
-            return Mxfp4MoEMethod(
-                use_triton_kernels=True, with_bias=True, use_flashinfer=use_flashinfer
-            )
+            return Mxfp4MoEMethod(prefix)
         else:
             raise NotImplementedError("Mxfp4 attention layer is not implemented")
         return None
@@ -206,15 +203,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
     def __init__(
         self,
-        use_triton_kernels: bool = True,
-        with_bias: bool = True,
-        use_flashinfer: bool = False,
+        prefix: str,
     ):
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+
         super().__init__()
         self.topk_indices_dtype = None
-        self.use_triton_kernels = use_triton_kernels
-        self.with_bias = with_bias
-        self.use_flashinfer = use_flashinfer
+        self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
+        self.with_bias = False
+        self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]

         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
@@ -236,12 +234,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: torch.dtype,
+        with_bias: bool = False,
         **extra_weight_attrs,
     ):
-        # print(f"hi {self=} create_weights {layer=}")
         self.num_experts = num_experts
         weight_dtype = torch.uint8
         scale_dtype = torch.uint8
+        self.with_bias = with_bias
         mxfp4_block = 32

         # pad the intermediate size to be a multiple of 2 * mxfp4_block
@@ -264,7 +263,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         # Fused gate_up_proj (column parallel)
         w13_weight = torch.nn.Parameter(
             torch.zeros(
-                num_experts,
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition_after_pad,
                 hidden_size // 2,
                 dtype=weight_dtype,
@@ -276,7 +275,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
-                num_experts,
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition_after_pad,
                 hidden_size // mxfp4_block,
                 dtype=scale_dtype,
@@ -288,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         w13_weight_bias = torch.nn.Parameter(
             torch.zeros(
-                num_experts,
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition_after_pad,
                 dtype=torch.bfloat16,
             ),
@@ -300,7 +299,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         # down_proj (row parallel)
         w2_weight = torch.nn.Parameter(
             torch.zeros(
-                num_experts,
+                layer.num_local_experts,
                 hidden_size,
                 intermediate_size_per_partition_after_pad // 2,
                 dtype=weight_dtype,
@@ -312,7 +311,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         w2_weight_scale = torch.nn.Parameter(
             torch.zeros(
-                num_experts,
+                layer.num_local_experts,
                 hidden_size,
                 intermediate_size_per_partition_after_pad // mxfp4_block,
                 dtype=scale_dtype,
@@ -323,7 +322,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)

         w2_weight_bias = torch.nn.Parameter(
-            torch.zeros(num_experts, hidden_size, dtype=torch.bfloat16),
+            torch.zeros(layer.num_local_experts, hidden_size, dtype=torch.bfloat16),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight_bias", w2_weight_bias)
@@ -484,38 +483,51 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             )
             return

-        from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
-
-        w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
-        w2_weight_bias = layer.w2_weight_bias.to(torch.float32)
-
-        layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
-        layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)
-
-        num_warps = 8
-
-        w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
-            layer.w13_weight, layer.w13_weight_scale, num_warps
-        )
-        w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
-            layer.w2_weight, layer.w2_weight_scale, num_warps
-        )
-
-        self.w13_precision_config = PrecisionConfig(
-            weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
-        )
-        self.w2_precision_config = PrecisionConfig(
-            weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
-        )
-
-        self.w13_weight_triton_tensor = w13_weight
-        self.w2_weight_triton_tensor = w2_weight
-
-        # need to delete the original weights to save memory on single GPU
-        del layer.w13_weight
-        del layer.w2_weight
-        layer.w13_weight = None
-        layer.w2_weight = None
+        if self.use_triton_kernels:
+            from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
+
+            w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
+            w2_weight_bias = layer.w2_weight_bias.to(torch.float32)
+
+            layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
+            layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)
+
+            num_warps = 8
+
+            w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
+                layer.w13_weight, layer.w13_weight_scale, num_warps
+            )
+            w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
+                layer.w2_weight, layer.w2_weight_scale, num_warps
+            )
+
+            self.w13_precision_config = PrecisionConfig(
+                weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
+            )
+            self.w2_precision_config = PrecisionConfig(
+                weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
+            )
+
+            self.w13_weight_triton_tensor = w13_weight
+            self.w2_weight_triton_tensor = w2_weight
+
+            del layer.w13_weight
+            del layer.w2_weight
+        else:
+            from triton_kernels.numerics_details.mxfp import upcast_from_mxfp
+
+            w13_weight = upcast_from_mxfp(
+                layer.w13_weight, layer.w13_weight_scale, dtype=torch.bfloat16, axis=-1
+            )
+            w2_weight = upcast_from_mxfp(
+                layer.w2_weight, layer.w2_weight_scale, dtype=torch.bfloat16, axis=-1
+            )
+            del layer.w13_weight
+            del layer.w2_weight
+            del layer.w13_weight_scale
+            del layer.w2_weight_scale
+            layer.w13_weight = Parameter(w13_weight.data, requires_grad=False)
+            layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)

         torch.cuda.empty_cache()
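For context on the new non-Triton branch: MXFP4 stores two FP4 (E2M1) values per uint8 plus one shared E8M0 scale per 32-element block, which is why the weight tensors above carry hidden_size // 2 and hidden_size // mxfp4_block trailing dimensions. A rough sketch of the dequantization idea (nibble order and layout are assumptions; this is not the triton_kernels upcast_from_mxfp implementation):

```python
import torch

# FP4 E2M1 code points per the OCP Microscaling spec.
E2M1_VALUES = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])

def dequant_mxfp4_block(packed: torch.Tensor, scale_e8m0: torch.Tensor) -> torch.Tensor:
    """packed: [..., 16] uint8 holding 32 fp4 values; scale_e8m0: [...] uint8 exponent."""
    lo = packed & 0xF                                # first nibble (assumed order)
    hi = packed >> 4                                 # second nibble
    codes = torch.stack([lo, hi], dim=-1).flatten(-2).long()
    scale = torch.exp2(scale_e8m0.float() - 127.0)   # E8M0: pure power-of-two scale
    return E2M1_VALUES[codes] * scale[..., None]
```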
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int): def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
...@@ -580,13 +592,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -580,13 +592,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
None, # output1_scale_scalar None, # output1_scale_scalar
None, # output1_scale_gate_scalar None, # output1_scale_gate_scalar
None, # output2_scale_scalar None, # output2_scale_scalar
self.num_experts, layer.num_experts,
top_k, top_k,
None, # n_group None, # n_group
None, # topk_group None, # topk_group
self.intermediate_size, # padded to multiple of 256 self.intermediate_size, # padded to multiple of 256
0, # local_expert_offset layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
self.num_experts, # local num experts layer.num_local_experts, # local num experts
None, None,
self._get_tile_tokens_dim(x, top_k), self._get_tile_tokens_dim(x, top_k),
1, # routing_method_type, renormalize 1, # routing_method_type, renormalize
...@@ -595,10 +607,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -595,10 +607,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return trtllm_gen_output return trtllm_gen_output
if self.use_triton_kernels: if self.use_triton_kernels:
assert (
layer.moe_ep_size == 1
), "Expert parallel is not supported when using triton kernels"
if self.with_bias: if self.with_bias:
# TODO why we do not put weights on layer?
assert layer.w13_weight is None
assert layer.w2_weight is None
return self.triton_kernel_moe_with_bias_forward( return self.triton_kernel_moe_with_bias_forward(
hidden_states=x, hidden_states=x,
w1=self.w13_weight_triton_tensor, w1=self.w13_weight_triton_tensor,
...@@ -620,4 +632,20 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -620,4 +632,20 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
topk_output=topk_output, topk_output=topk_output,
) )
else: else:
raise NotImplementedError() from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
)
@@ -126,10 +126,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""

-    def __init__(self, use_triton_kernels: bool = False, with_bias: bool = False):
+    def __init__(self, use_triton_kernels: bool = False):
         super().__init__()
         self.use_triton_kernels = use_triton_kernels
-        self.with_bias = with_bias
+        self.with_bias = False

         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
@@ -151,8 +151,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: torch.dtype,
+        with_bias: bool = False,
         **extra_weight_attrs,
     ):
+        self.with_bias = with_bias
+
         # Fused gate_up_proj (column parallel)
         w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
         if self.use_triton_kernels:
@@ -319,12 +322,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             hidden_states=x,
             w1=layer.w13_weight,
             w2=layer.w2_weight,
+            b1=getattr(layer, "w13_weight_bias", None),
+            b2=getattr(layer, "w2_weight_bias", None),
             topk_output=topk_output,
             inplace=inplace and not no_combine,
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             no_combine=no_combine,
             routed_scaling_factor=routed_scaling_factor,
+            activation_alpha=activation_alpha,
+            swiglu_limit=swiglu_limit,
         )

     def forward_cpu(
...
@@ -28,6 +28,7 @@ from sglang.srt.distributed import (
     get_moe_expert_parallel_rank,
     get_moe_expert_parallel_world_size,
     get_moe_tensor_parallel_rank,
+    get_moe_tensor_parallel_world_size,
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -96,11 +97,6 @@ class GptOssSparseMoeBlock(nn.Module):
         self.activation = config.hidden_act
         self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
         self.swiglu_limit = config.swiglu_limit
-        if self.tp_size > config.num_local_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.num_local_experts}."
-            )

         if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
             self.topk = None
@@ -708,22 +704,26 @@ class GptOssForCausalLM(nn.Module):
         loaded_params: set[str] = set()

         mxfp4_block = 32

-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
+        moe_tp_rank = get_moe_tensor_parallel_rank()
+        moe_tp_size = get_moe_tensor_parallel_world_size()
+        moe_ep_rank = get_moe_expert_parallel_rank()
+        moe_ep_size = get_moe_expert_parallel_world_size()

         intermediate_size = self.config.intermediate_size
         intermediate_size_block = intermediate_size // mxfp4_block
-        per_rank_intermediate_size_block = intermediate_size_block // tp_size
+        per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
         per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block

         # Calculate common slicing bounds for current rank
-        tp_rank_start = tp_rank * per_rank_intermediate_size
-        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
-
-        # Attention heads per rank
-        heads_per_rank = self.config.num_attention_heads // tp_size
-        head_start = tp_rank * heads_per_rank
-
-        num_experts = self.config.num_local_experts
+        assert self.config.num_local_experts % moe_ep_size == 0
+        moe_num_global_experts = self.config.num_local_experts
+        moe_num_local_experts = self.config.num_local_experts // moe_ep_size
+        moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size
+        moe_tp_rank_end = min(
+            (moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size
+        )
+        moe_ep_rank_start = moe_ep_rank * moe_num_local_experts
+        moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts

         for name, weight in weights:
             weight = weight.cuda()
@@ -735,10 +735,14 @@ class GptOssForCausalLM(nn.Module):
                 # flat weight from (E, 2 * N, block_size, entry_per_block)
                 # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
                 weight = weight.view(
-                    num_experts, 2 * intermediate_size, -1
+                    moe_num_global_experts, 2 * intermediate_size, -1
                 ).contiguous()

-                narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
+                    ...,
+                ]

                 param = params_dict[new_name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
@@ -757,9 +761,13 @@ class GptOssForCausalLM(nn.Module):
                 # same flatten here, but since 2 mx4 value are packed in 1
                 # uint8, divide by 2
                 weight = weight.view(
-                    num_experts, -1, intermediate_size // 2
+                    moe_num_global_experts, -1, intermediate_size // 2
                 ).contiguous()
-                narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2]
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    ...,
+                    moe_tp_rank_start // 2 : moe_tp_rank_end // 2,
+                ]

                 param = params_dict[new_name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
@@ -775,7 +783,11 @@ class GptOssForCausalLM(nn.Module):
             elif "gate_up_proj_scales" in name:
                 # Handle MLP gate and up projection weights scale
                 new_name = name.replace("gate_up_proj_scales", "w13_weight_scale")
-                narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
+                    ...,
+                ]

                 param = params_dict[new_name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
@@ -792,7 +804,9 @@ class GptOssForCausalLM(nn.Module):
                 # Handle MLP down projection weights
                 new_name = name.replace("down_proj_scales", "w2_weight_scale")
                 narrow_weight = weight[
-                    ..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    ...,
+                    moe_tp_rank_start // mxfp4_block : moe_tp_rank_end // mxfp4_block,
                 ]

                 param = params_dict[new_name]
@@ -809,7 +823,10 @@ class GptOssForCausalLM(nn.Module):
                 # Handle MLP gate and up projection biases
                 new_name = name.replace("gate_up_proj_bias", "w13_weight_bias")

-                narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
+                narrow_weight = weight[
+                    moe_ep_rank_start:moe_ep_rank_end,
+                    2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
+                ]

                 param = params_dict[new_name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
@@ -823,15 +840,20 @@ class GptOssForCausalLM(nn.Module):
                 loaded_params.add(new_name)

             elif "down_proj_bias" in name:
-                if get_moe_tensor_parallel_rank() != 0:
-                    weight = torch.zeros_like(weight)
+                narrow_weight = weight[moe_ep_rank_start:moe_ep_rank_end, ...]
+                if moe_tp_rank != 0:
+                    narrow_weight = torch.zeros_like(narrow_weight)

                 # Handle MLP down projection bias
                 new_name = name.replace("down_proj_bias", "w2_weight_bias")
                 param = params_dict[new_name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(
-                    param, weight, weight_name=new_name, shard_id=None, expert_id=None
+                    param,
+                    narrow_weight,
+                    weight_name=new_name,
+                    shard_id=None,
+                    expert_id=None,
                 )
                 loaded_params.add(new_name)
@@ -910,27 +932,12 @@ class GptOssForCausalLM(nn.Module):
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
         ]
-
-        if self.quant_config is not None and (self.quant_config.get_name() == "mxfp4"):
-            expert_params_mapping = (
-                get_moe_impl_class().make_expert_params_mapping_fused_mxfp4(
-                    ckpt_gate_up_proj_name="gate_up_proj_blocks",
-                    ckpt_down_proj_name="down_proj_blocks",
-                    ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
-                    ckpt_down_proj_bias_name="down_proj_bias",
-                    ckpt_gate_up_proj_scale_name="gate_up_proj_scales",
-                    ckpt_down_proj_scale_name="down_proj_scales",
-                )
-            )
-        else:
-            expert_params_mapping = (
-                get_moe_impl_class().make_expert_params_mapping_fused(
-                    ckpt_gate_up_proj_name="gate_up_proj",
-                    ckpt_down_proj_name="down_proj",
-                    ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
-                    ckpt_down_proj_bias_name="down_proj_bias",
-                )
-            )
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
+            ckpt_gate_up_proj_name="gate_up_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
+            ckpt_down_proj_bias_name="down_proj_bias",
+        )

         params_dict = dict(self.named_parameters())
         params_checker = {k: False for k, v in params_dict.items()}
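The loader now combines expert-parallel and tensor-parallel partitioning: each rank keeps a contiguous block of experts and a contiguous slice of the intermediate dimension. A worked sketch of the slicing bounds with illustrative sizes:

```python
# Illustrative: 32 experts, intermediate_size 2880, EP=4, TP=2, this rank = (ep 1, tp 1).
num_experts, intermediate_size, mxfp4_block = 32, 2880, 32
moe_ep_size, moe_tp_size = 4, 2
moe_ep_rank, moe_tp_rank = 1, 1

per_rank_blocks = (intermediate_size // mxfp4_block) // moe_tp_size   # 45 blocks
per_rank_intermediate = per_rank_blocks * mxfp4_block                 # 1440 columns

moe_tp_rank_start = moe_tp_rank * per_rank_intermediate               # 1440
moe_tp_rank_end = min((moe_tp_rank + 1) * per_rank_intermediate, intermediate_size)  # 2880

moe_num_local_experts = num_experts // moe_ep_size                    # 8 experts per EP rank
moe_ep_rank_start = moe_ep_rank * moe_num_local_experts               # 8
moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts           # 16

# gate_up_proj blocks are then sliced as
# weight[moe_ep_rank_start:moe_ep_rank_end, 2 * moe_tp_rank_start : 2 * moe_tp_rank_end, ...]
```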
...
@@ -37,6 +37,7 @@ from sglang.srt.utils import (
     is_hip,
     is_port_available,
     is_remote_url,
+    is_triton_kernels_available,
     is_valid_ipv6_address,
     nullable_str,
 )
@@ -492,10 +493,15 @@ class ServerArgs:
                     "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
                 )
             else:
-                self.enable_triton_kernel_moe = True
-                logger.info(
-                    "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
-                )
+                if self.enable_triton_kernel_moe:
+                    assert (
+                        self.ep_size == 1
+                    ), "Triton kernel MoE is only supported when ep_size == 1"
+                if not self.enable_triton_kernel_moe and self.ep_size == 1:
+                    self.enable_triton_kernel_moe = True
+                    logger.info(
+                        "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
+                    )

             self.disable_hybrid_swa_memory = True
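Restated as a standalone sketch, the revised GPT-OSS defaulting rule is: an explicitly set enable_triton_kernel_moe now requires ep_size == 1, and the kernel is only auto-enabled when expert parallelism is off, so EP deployments fall back to the fused_experts path. Names mirror the diff; the surrounding SM100/MXFP4 branch is omitted:

```python
def resolve_gpt_oss_triton_moe(enable_triton_kernel_moe: bool, ep_size: int) -> bool:
    """Return whether the triton_kernels MoE path should be enabled."""
    if enable_triton_kernel_moe:
        # Explicit request: only valid without expert parallelism.
        assert ep_size == 1, "Triton kernel MoE is only supported when ep_size == 1"
        return True
    # Auto-enable only when expert parallelism is off.
    return ep_size == 1
```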
...
@@ -2961,3 +2961,8 @@ class ConcurrentCounter:
         other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
         """
         self.wait_for(lambda count: count == 0)
+
+
+@lru_cache(maxsize=1)
+def is_triton_kernels_available() -> bool:
+    return importlib.util.find_spec("triton_kernels") is not None