Unverified Commit 65f09131 authored by Yi Zhang's avatar Yi Zhang Committed by GitHub
Browse files

refactor qwen moe code, use communicator to support tp+dp (#6581)

parent fc419b62
......@@ -95,6 +95,7 @@ from sglang.srt.utils import (
get_int_env_var,
is_cuda,
is_hip,
is_non_idle_and_non_empty,
log_info_on_rank0,
)
......@@ -206,14 +207,6 @@ class MoEGate(nn.Module):
return logits
def is_non_idle_and_non_empty(forward_mode, hidden_states):
return (
(forward_mode is not None)
and not forward_mode.is_idle()
and hidden_states.shape[0] > 0
)
class DeepseekV2MoE(nn.Module):
def __init__(
......
......@@ -32,6 +32,7 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather,
attn_tp_reduce_scatter,
......@@ -49,7 +50,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
......@@ -114,22 +115,22 @@ class Qwen2MoeMLP(nn.Module):
class Qwen2MoeSparseMoeBlock(nn.Module):
def __init__(
self,
layer_id: int,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.layer_id = layer_id
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}."
)
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
self.experts = MoEImpl(
self.experts = get_moe_impl_class()(
layer_id=self.layer_id,
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
......@@ -159,7 +160,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
self.shared_expert = None
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
......@@ -276,19 +279,6 @@ class Qwen2MoeAttention(nn.Module):
return output
class _FFNInputMode(Enum):
# The MLP sublayer requires 1/tp_size tokens as input
SCATTERED = auto()
# The MLP sublayer requires all tokens as input
FULL = auto()
@dataclass
class _DecoderLayerInfo:
is_sparse: bool
ffn_input_mode: _FFNInputMode
class Qwen2MoeDecoderLayer(nn.Module):
def __init__(
self,
......@@ -298,6 +288,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
......@@ -322,16 +313,20 @@ class Qwen2MoeDecoderLayer(nn.Module):
self.attn_tp_rank = get_attention_tp_rank()
self.local_dp_size = get_local_attention_dp_size()
self.info = self._compute_info(config, layer_id=layer_id)
previous_layer_info = self._compute_info(config, layer_id=layer_id - 1)
self.input_is_scattered = (
layer_id > 0
and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
# Qwen2MoE all layers are sparse and have no nextn now
self.is_layer_sparse = True
is_previous_layer_sparse = True
self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id,
num_layers=config.num_hidden_layers,
is_layer_sparse=self.is_layer_sparse,
is_previous_layer_sparse=is_previous_layer_sparse,
)
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
if self.info.is_sparse:
if self.is_layer_sparse:
self.mlp = Qwen2MoeSparseMoeBlock(
layer_id=layer_id,
config=config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
......@@ -348,27 +343,11 @@ class Qwen2MoeDecoderLayer(nn.Module):
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
@staticmethod
def _enable_moe_dense_fully_dp():
return global_server_args_dict["moe_dense_tp_size"] == 1
@staticmethod
def _compute_info(config: PretrainedConfig, layer_id: int):
# WARN: Qwen2MOE has no dense_layer, it is only for compatibility.
mlp_only_layers = (
[] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
)
is_sparse = (layer_id not in mlp_only_layers) and (
config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
)
ffn_input_mode = (
_FFNInputMode.SCATTERED
if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
or (Qwen2MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
else _FFNInputMode.FULL
self.layer_communicator = LayerCommunicator(
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
)
return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
def forward(
self,
......@@ -377,108 +356,11 @@ class Qwen2MoeDecoderLayer(nn.Module):
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
return self.forward_ffn_with_scattered_input(
positions, hidden_states, forward_batch, residual
)
elif self.info.ffn_input_mode == _FFNInputMode.FULL:
return self.forward_ffn_with_full_input(
positions, hidden_states, forward_batch, residual
)
else:
raise NotImplementedError
def forward_ffn_with_full_input(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if hidden_states.shape[0] == 0:
residual = hidden_states
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
# Gather
if get_tensor_model_parallel_world_size() > 1:
# all gather and all reduce
if self.local_dp_size != 1:
if self.attn_tp_rank == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer,
hidden_states,
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
dp_scatter(residual, hidden_states, forward_batch)
# TODO extract this bugfix
if hidden_states.shape[0] != 0:
hidden_states = self.post_attention_layernorm(hidden_states)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
# TODO extract this bugfix
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
elif hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
# Fully Connected
hidden_states = self.mlp(hidden_states)
# TODO: use reduce-scatter in MLP to avoid this scatter
# Scatter
if self.local_dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
dp_scatter(hidden_states, global_hidden_states, forward_batch)
return hidden_states, residual
def forward_ffn_with_scattered_input(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if hidden_states.shape[0] == 0:
residual = hidden_states
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
if self.attn_tp_size != 1 and self.input_is_scattered:
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
)
# Self Attention
if hidden_states.shape[0] != 0:
hidden_states = self.self_attn(
positions=positions,
......@@ -486,47 +368,15 @@ class Qwen2MoeDecoderLayer(nn.Module):
forward_batch=forward_batch,
)
if self.attn_tp_size != 1:
if self.input_is_scattered:
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
attn_tp_reduce_scatter(hidden_states, tensor_list)
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
else:
if self.attn_tp_rank == 0:
hidden_states += residual
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
attn_tp_reduce_scatter(hidden_states, tensor_list)
residual = hidden_states
if hidden_states.shape[0] != 0:
hidden_states = self.post_attention_layernorm(hidden_states)
else:
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states, residual, forward_batch
)
if not (
self._enable_moe_dense_fully_dp()
and (not self.info.is_sparse)
and hidden_states.shape[0] == 0
):
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
hidden_states = self.mlp(hidden_states, forward_batch)
if self.is_last_layer and self.attn_tp_size != 1:
hidden_states += residual
residual = None
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
)
return hidden_states, residual
......
......@@ -38,6 +38,7 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather,
attn_tp_reduce_scatter,
......@@ -78,7 +79,7 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
from sglang.srt.utils import DeepEPMode, add_prefix
from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty
Qwen3MoeConfig = None
......@@ -150,13 +151,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
)
def forward(
self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
) -> torch.Tensor:
if not global_server_args_dict["enable_deepep_moe"]:
return self.forward_normal(hidden_states)
else:
return self.forward_deepep(hidden_states, forward_mode)
return self.forward_deepep(hidden_states, forward_batch)
def get_moe_weights(self):
return [
......@@ -180,13 +181,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
return final_hidden_states.view(num_tokens, hidden_dim)
def forward_deepep(
self, hidden_states: torch.Tensor, forward_mode: ForwardMode
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> torch.Tensor:
if (
forward_mode is not None
and not forward_mode.is_idle()
and hidden_states.shape[0] > 0
):
forward_mode = forward_batch.forward_mode
if is_non_idle_and_non_empty(forward_mode, hidden_states):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
......@@ -356,19 +354,6 @@ class Qwen3MoeAttention(nn.Module):
return output
class _FFNInputMode(Enum):
# The MLP sublayer requires 1/tp_size tokens as input
SCATTERED = auto()
# The MLP sublayer requires all tokens as input
FULL = auto()
@dataclass
class _DecoderLayerInfo:
is_sparse: bool
ffn_input_mode: _FFNInputMode
class Qwen3MoeDecoderLayer(nn.Module):
def __init__(
self,
......@@ -378,6 +363,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
......@@ -408,15 +394,18 @@ class Qwen3MoeDecoderLayer(nn.Module):
self.attn_tp_rank = get_attention_tp_rank()
self.local_dp_size = get_local_attention_dp_size()
self.info = self._compute_info(config, layer_id=layer_id)
previous_layer_info = self._compute_info(config, layer_id=layer_id - 1)
self.input_is_scattered = (
layer_id > 0
and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
# Qwen3MoE all layers are sparse and have no nextn now
self.is_layer_sparse = True
is_previous_layer_sparse = True
self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id,
num_layers=config.num_hidden_layers,
is_layer_sparse=self.is_layer_sparse,
is_previous_layer_sparse=is_previous_layer_sparse,
)
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
if self.info.is_sparse:
if self.is_layer_sparse:
self.mlp = Qwen3MoeSparseMoeBlock(
layer_id=self.layer_id,
config=config,
......@@ -436,26 +425,11 @@ class Qwen3MoeDecoderLayer(nn.Module):
config.hidden_size, eps=config.rms_norm_eps
)
@staticmethod
def _enable_moe_dense_fully_dp():
return global_server_args_dict["moe_dense_tp_size"] == 1
@staticmethod
def _compute_info(config: PretrainedConfig, layer_id: int):
# WARN: Qwen3MOE has no dense_layer, it is only for compatibility.
mlp_only_layers = (
[] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
)
is_sparse = (layer_id not in mlp_only_layers) and (
config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
self.layer_communicator = LayerCommunicator(
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
)
ffn_input_mode = (
_FFNInputMode.SCATTERED
if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
or (Qwen3MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
else _FFNInputMode.FULL
)
return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
def forward(
self,
......@@ -464,105 +438,11 @@ class Qwen3MoeDecoderLayer(nn.Module):
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
return self.forward_ffn_with_scattered_input(
positions, hidden_states, forward_batch, residual
)
elif self.info.ffn_input_mode == _FFNInputMode.FULL:
return self.forward_ffn_with_full_input(
positions, hidden_states, forward_batch, residual
)
else:
raise NotImplementedError
def forward_ffn_with_full_input(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if hidden_states.shape[0] == 0:
residual = hidden_states
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
# Gather
if get_tensor_model_parallel_world_size() > 1:
if self.local_dp_size != 1:
if self.attn_tp_rank == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer,
hidden_states,
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
dp_scatter(residual, hidden_states, forward_batch)
hidden_states = self.post_attention_layernorm(hidden_states)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
# TODO extract this bugfix
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
elif hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
# Fully Connected
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
# TODO: use reduce-scatter in MLP to avoid this scatter
# Scatter
if self.local_dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
dp_scatter(hidden_states, global_hidden_states, forward_batch)
return hidden_states, residual
def forward_ffn_with_scattered_input(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if hidden_states.shape[0] == 0:
residual = hidden_states
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
if self.attn_tp_size != 1 and self.input_is_scattered:
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
)
# Self Attention
if hidden_states.shape[0] != 0:
hidden_states = self.self_attn(
positions=positions,
......@@ -570,47 +450,15 @@ class Qwen3MoeDecoderLayer(nn.Module):
forward_batch=forward_batch,
)
if self.attn_tp_size != 1:
if self.input_is_scattered:
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
attn_tp_reduce_scatter(hidden_states, tensor_list)
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
else:
if self.attn_tp_rank == 0:
hidden_states += residual
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
attn_tp_reduce_scatter(hidden_states, tensor_list)
residual = hidden_states
if hidden_states.shape[0] != 0:
hidden_states = self.post_attention_layernorm(hidden_states)
else:
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states, residual, forward_batch
)
if not (
self._enable_moe_dense_fully_dp()
and (not self.info.is_sparse)
and hidden_states.shape[0] == 0
):
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
if self.is_last_layer and self.attn_tp_size != 1:
hidden_states += residual
residual = None
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
hidden_states = self.mlp(hidden_states, forward_batch)
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
)
return hidden_states, residual
......
......@@ -2026,6 +2026,14 @@ class DeepEPMode(Enum):
return DeepEPMode.normal
def is_non_idle_and_non_empty(forward_mode, hidden_states):
return (
(forward_mode is not None)
and not forward_mode.is_idle()
and hidden_states.shape[0] > 0
)
def fast_topk(values, topk, dim):
if topk == 1:
# Use max along the specified dimension to get both value and index
......
......@@ -146,7 +146,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
self.assertGreater(metrics["accuracy"], 0.62)
def test_logprob(self):
prompt = "The capital of taiwan is "
prompt = "The capital of france is "
response = requests.post(
self.lb_url + "/generate",
json={
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment