Unverified commit 65f09131, authored by Yi Zhang and committed by GitHub

refactor qwen moe code, use communicator to support tp+dp (#6581)

parent fc419b62
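Summary of the change: the per-model _FFNInputMode / _DecoderLayerInfo dispatch and the hand-written DP gather / TP reduce-scatter branches in the Qwen2/Qwen3 MoE decoder layers are replaced by the shared LayerCommunicator and LayerScatterModes helpers from sglang.srt.layers.communicator, and the is_non_idle_and_non_empty guard moves from the DeepSeek model file into sglang.srt.utils so it can be reused. The sketch below is illustrative only: it reproduces the control flow of the refactored decoder layer, but the no-op communicator and the linear stand-ins for attention and the MoE MLP are placeholders, not sglang APIs.

# Illustrative sketch only: the real LayerCommunicator / LayerScatterModes live in
# sglang.srt.layers.communicator; the no-op communicator below is a stand-in so the
# control flow of the refactored decoder layer can be run in isolation.
import torch
from torch import nn


class _NoOpCommunicator:
    """Stand-in for LayerCommunicator: each hook just passes its inputs through."""

    def prepare_attn(self, hidden_states, residual, forward_batch):
        # The real hook applies input_layernorm and any gather needed before attention.
        if residual is None:
            residual = hidden_states.clone()
        return hidden_states, residual

    def prepare_mlp(self, hidden_states, residual, forward_batch):
        # The real hook applies post_attention_layernorm and reduce-scatter/all-reduce.
        return hidden_states, residual

    def postprocess_layer(self, hidden_states, residual, forward_batch):
        # The real hook restores the layout expected by the next layer.
        return hidden_states + residual, residual


class _SketchDecoderLayer(nn.Module):
    # Mirrors the refactored Qwen2/Qwen3 forward: all gather/scatter decisions are
    # delegated to the communicator instead of per-layer _FFNInputMode branches.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.self_attn = nn.Linear(hidden_size, hidden_size)  # placeholder attention
        self.mlp = nn.Linear(hidden_size, hidden_size)        # placeholder MoE MLP
        self.layer_communicator = _NoOpCommunicator()

    def forward(self, positions, hidden_states, forward_batch, residual):
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )
        if hidden_states.shape[0] != 0:
            hidden_states = self.self_attn(hidden_states)
        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )
        hidden_states = self.mlp(hidden_states)
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )
        return hidden_states, residual


if __name__ == "__main__":
    layer = _SketchDecoderLayer(hidden_size=16)
    x = torch.randn(4, 16)
    out, res = layer(positions=None, hidden_states=x, forward_batch=None, residual=None)
    print(out.shape, res.shape)

The point of this structure is that the layer body no longer cares whether the FFN input is scattered (1/tp_size of the tokens) or full; that decision, and the corresponding all-gather / reduce-scatter / DP scatter calls, are owned by the communicator, which is what allows combining tensor parallelism with data-parallel attention as the commit title describes.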
@@ -95,6 +95,7 @@ from sglang.srt.utils import (
     get_int_env_var,
     is_cuda,
     is_hip,
+    is_non_idle_and_non_empty,
     log_info_on_rank0,
 )
@@ -206,14 +207,6 @@ class MoEGate(nn.Module):
         return logits


-def is_non_idle_and_non_empty(forward_mode, hidden_states):
-    return (
-        (forward_mode is not None)
-        and not forward_mode.is_idle()
-        and hidden_states.shape[0] > 0
-    )
-
-
 class DeepseekV2MoE(nn.Module):
     def __init__(
...
@@ -32,6 +32,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather,
     attn_tp_reduce_scatter,
@@ -49,7 +50,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -114,22 +115,22 @@ class Qwen2MoeMLP(nn.Module):
 class Qwen2MoeSparseMoeBlock(nn.Module):
     def __init__(
         self,
+        layer_id: int,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
+        self.layer_id = layer_id
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {config.num_experts}."
             )
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
-
-        self.experts = MoEImpl(
+        self.experts = get_moe_impl_class()(
+            layer_id=self.layer_id,
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
@@ -159,7 +160,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+    ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         shared_output = None
@@ -276,19 +279,6 @@ class Qwen2MoeAttention(nn.Module):
         return output


-class _FFNInputMode(Enum):
-    # The MLP sublayer requires 1/tp_size tokens as input
-    SCATTERED = auto()
-    # The MLP sublayer requires all tokens as input
-    FULL = auto()
-
-
-@dataclass
-class _DecoderLayerInfo:
-    is_sparse: bool
-    ffn_input_mode: _FFNInputMode
-
-
 class Qwen2MoeDecoderLayer(nn.Module):
     def __init__(
         self,
@@ -298,6 +288,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.config = config
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -322,16 +313,20 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self.attn_tp_rank = get_attention_tp_rank()
         self.local_dp_size = get_local_attention_dp_size()

-        self.info = self._compute_info(config, layer_id=layer_id)
-        previous_layer_info = self._compute_info(config, layer_id=layer_id - 1)
-        self.input_is_scattered = (
-            layer_id > 0
-            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
+        # Qwen2MoE all layers are sparse and have no nextn now
+        self.is_layer_sparse = True
+        is_previous_layer_sparse = True
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
         )
-        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1

-        if self.info.is_sparse:
+        if self.is_layer_sparse:
             self.mlp = Qwen2MoeSparseMoeBlock(
+                layer_id=layer_id,
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
@@ -348,27 +343,11 @@
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )
-
-    @staticmethod
-    def _enable_moe_dense_fully_dp():
-        return global_server_args_dict["moe_dense_tp_size"] == 1
-
-    @staticmethod
-    def _compute_info(config: PretrainedConfig, layer_id: int):
-        # WARN: Qwen2MOE has no dense_layer, it is only for compatibility.
-        mlp_only_layers = (
-            [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
-        )
-        is_sparse = (layer_id not in mlp_only_layers) and (
-            config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
-        )
-        ffn_input_mode = (
-            _FFNInputMode.SCATTERED
-            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
-            or (Qwen2MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
-            else _FFNInputMode.FULL
-        )
-        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )

     def forward(
         self,
@@ -377,108 +356,11 @@ class Qwen2MoeDecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
-            return self.forward_ffn_with_scattered_input(
-                positions, hidden_states, forward_batch, residual
-            )
-        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
-            return self.forward_ffn_with_full_input(
-                positions, hidden_states, forward_batch, residual
-            )
-        else:
-            raise NotImplementedError
-
-    def forward_ffn_with_full_input(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-        residual: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        # Self Attention
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )
-
-        # Gather
-        if get_tensor_model_parallel_world_size() > 1:
-            # all gather and all reduce
-            if self.local_dp_size != 1:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                dp_scatter(residual, hidden_states, forward_batch)
-                # TODO extract this bugfix
-                if hidden_states.shape[0] != 0:
-                    hidden_states = self.post_attention_layernorm(hidden_states)
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-                # TODO extract this bugfix
-                if hidden_states.shape[0] != 0:
-                    hidden_states, residual = self.post_attention_layernorm(
-                        hidden_states, residual
-                    )
-        elif hidden_states.shape[0] != 0:
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
-
-        # Fully Connected
-        hidden_states = self.mlp(hidden_states)
-
-        # TODO: use reduce-scatter in MLP to avoid this scatter
-        # Scatter
-        if self.local_dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
-        return hidden_states, residual
-
-    def forward_ffn_with_scattered_input(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-        residual: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        if self.attn_tp_size != 1 and self.input_is_scattered:
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            attn_tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
-            )
-
-        # Self Attention
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
         if hidden_states.shape[0] != 0:
             hidden_states = self.self_attn(
                 positions=positions,
@@ -486,47 +368,15 @@ class Qwen2MoeDecoderLayer(nn.Module):
                 forward_batch=forward_batch,
             )

-        if self.attn_tp_size != 1:
-            if self.input_is_scattered:
-                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
-                hidden_states = tensor_list[self.attn_tp_rank]
-                attn_tp_reduce_scatter(hidden_states, tensor_list)
-                if hidden_states.shape[0] != 0:
-                    hidden_states, residual = self.post_attention_layernorm(
-                        hidden_states, residual
-                    )
-            else:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
-                hidden_states = tensor_list[self.attn_tp_rank]
-                attn_tp_reduce_scatter(hidden_states, tensor_list)
-                residual = hidden_states
-                if hidden_states.shape[0] != 0:
-                    hidden_states = self.post_attention_layernorm(hidden_states)
-        else:
-            if hidden_states.shape[0] != 0:
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )

-        if not (
-            self._enable_moe_dense_fully_dp()
-            and (not self.info.is_sparse)
-            and hidden_states.shape[0] == 0
-        ):
-            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+        hidden_states = self.mlp(hidden_states, forward_batch)

-        if self.is_last_layer and self.attn_tp_size != 1:
-            hidden_states += residual
-            residual = None
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            attn_tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
-            )
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )

         return hidden_states, residual
...
@@ -38,6 +38,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather,
     attn_tp_reduce_scatter,
@@ -78,7 +79,7 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import DeepEPMode, add_prefix
+from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty

 Qwen3MoeConfig = None
@@ -150,13 +151,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         )

     def forward(
-        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
         if not global_server_args_dict["enable_deepep_moe"]:
             return self.forward_normal(hidden_states)
         else:
-            return self.forward_deepep(hidden_states, forward_mode)
+            return self.forward_deepep(hidden_states, forward_batch)

     def get_moe_weights(self):
         return [
@@ -180,13 +181,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         return final_hidden_states.view(num_tokens, hidden_dim)

     def forward_deepep(
-        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> torch.Tensor:
-        if (
-            forward_mode is not None
-            and not forward_mode.is_idle()
-            and hidden_states.shape[0] > 0
-        ):
+        forward_mode = forward_batch.forward_mode
+        if is_non_idle_and_non_empty(forward_mode, hidden_states):
             # router_logits: (num_tokens, n_experts)
             router_logits, _ = self.gate(hidden_states)
@@ -356,19 +354,6 @@ class Qwen3MoeAttention(nn.Module):
         return output


-class _FFNInputMode(Enum):
-    # The MLP sublayer requires 1/tp_size tokens as input
-    SCATTERED = auto()
-    # The MLP sublayer requires all tokens as input
-    FULL = auto()
-
-
-@dataclass
-class _DecoderLayerInfo:
-    is_sparse: bool
-    ffn_input_mode: _FFNInputMode
-
-
 class Qwen3MoeDecoderLayer(nn.Module):
     def __init__(
         self,
@@ -378,6 +363,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.config = config
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -408,15 +394,18 @@ class Qwen3MoeDecoderLayer(nn.Module):
         self.attn_tp_rank = get_attention_tp_rank()
         self.local_dp_size = get_local_attention_dp_size()

-        self.info = self._compute_info(config, layer_id=layer_id)
-        previous_layer_info = self._compute_info(config, layer_id=layer_id - 1)
-        self.input_is_scattered = (
-            layer_id > 0
-            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
+        # Qwen3MoE all layers are sparse and have no nextn now
+        self.is_layer_sparse = True
+        is_previous_layer_sparse = True
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
         )
-        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1

-        if self.info.is_sparse:
+        if self.is_layer_sparse:
             self.mlp = Qwen3MoeSparseMoeBlock(
                 layer_id=self.layer_id,
                 config=config,
@@ -436,26 +425,11 @@
             config.hidden_size, eps=config.rms_norm_eps
         )

-    @staticmethod
-    def _enable_moe_dense_fully_dp():
-        return global_server_args_dict["moe_dense_tp_size"] == 1
-
-    @staticmethod
-    def _compute_info(config: PretrainedConfig, layer_id: int):
-        # WARN: Qwen3MOE has no dense_layer, it is only for compatibility.
-        mlp_only_layers = (
-            [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
-        )
-        is_sparse = (layer_id not in mlp_only_layers) and (
-            config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
-        )
-        ffn_input_mode = (
-            _FFNInputMode.SCATTERED
-            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
-            or (Qwen3MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
-            else _FFNInputMode.FULL
-        )
-        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )

     def forward(
         self,
@@ -464,105 +438,11 @@
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
-            return self.forward_ffn_with_scattered_input(
-                positions, hidden_states, forward_batch, residual
-            )
-        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
-            return self.forward_ffn_with_full_input(
-                positions, hidden_states, forward_batch, residual
-            )
-        else:
-            raise NotImplementedError
-
-    def forward_ffn_with_full_input(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-        residual: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        # Self Attention
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )
-
-        # Gather
-        if get_tensor_model_parallel_world_size() > 1:
-            if self.local_dp_size != 1:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                dp_scatter(residual, hidden_states, forward_batch)
-                hidden_states = self.post_attention_layernorm(hidden_states)
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-                # TODO extract this bugfix
-                if hidden_states.shape[0] != 0:
-                    hidden_states, residual = self.post_attention_layernorm(
-                        hidden_states, residual
-                    )
-        elif hidden_states.shape[0] != 0:
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
-
-        # Fully Connected
-        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
-
-        # TODO: use reduce-scatter in MLP to avoid this scatter
-        # Scatter
-        if self.local_dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
-        return hidden_states, residual
-
-    def forward_ffn_with_scattered_input(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-        residual: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        if self.attn_tp_size != 1 and self.input_is_scattered:
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            attn_tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
-            )
-
-        # Self Attention
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
         if hidden_states.shape[0] != 0:
             hidden_states = self.self_attn(
                 positions=positions,
@@ -570,47 +450,15 @@
                 forward_batch=forward_batch,
             )

-        if self.attn_tp_size != 1:
-            if self.input_is_scattered:
-                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
-                hidden_states = tensor_list[self.attn_tp_rank]
-                attn_tp_reduce_scatter(hidden_states, tensor_list)
-                if hidden_states.shape[0] != 0:
-                    hidden_states, residual = self.post_attention_layernorm(
-                        hidden_states, residual
-                    )
-            else:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
-                hidden_states = tensor_list[self.attn_tp_rank]
-                attn_tp_reduce_scatter(hidden_states, tensor_list)
-                residual = hidden_states
-                if hidden_states.shape[0] != 0:
-                    hidden_states = self.post_attention_layernorm(hidden_states)
-        else:
-            if hidden_states.shape[0] != 0:
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )

-        if not (
-            self._enable_moe_dense_fully_dp()
-            and (not self.info.is_sparse)
-            and hidden_states.shape[0] == 0
-        ):
-            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+        hidden_states = self.mlp(hidden_states, forward_batch)

-        if self.is_last_layer and self.attn_tp_size != 1:
-            hidden_states += residual
-            residual = None
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            attn_tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
-            )
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )

         return hidden_states, residual
...
@@ -2026,6 +2026,14 @@ class DeepEPMode(Enum):
         return DeepEPMode.normal


+def is_non_idle_and_non_empty(forward_mode, hidden_states):
+    return (
+        (forward_mode is not None)
+        and not forward_mode.is_idle()
+        and hidden_states.shape[0] > 0
+    )
+
+
 def fast_topk(values, topk, dim):
     if topk == 1:
         # Use max along the specified dimension to get both value and index
...
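For reference, the relocated helper can be exercised on its own. A minimal check, assuming a fake forward-mode object as a stand-in for illustration (real callers such as forward_deepep above pass forward_batch.forward_mode and the layer's hidden_states):

# Hypothetical stand-ins: a minimal forward-mode object and a real tensor, just to
# exercise the guard the same way forward_deepep does in this commit.
from dataclasses import dataclass

import torch


@dataclass
class _FakeForwardMode:
    idle: bool

    def is_idle(self) -> bool:
        return self.idle


def is_non_idle_and_non_empty(forward_mode, hidden_states):
    # Same body as the helper added to sglang.srt.utils in this commit.
    return (
        (forward_mode is not None)
        and not forward_mode.is_idle()
        and hidden_states.shape[0] > 0
    )


print(is_non_idle_and_non_empty(_FakeForwardMode(idle=False), torch.randn(3, 8)))  # True
print(is_non_idle_and_non_empty(_FakeForwardMode(idle=True), torch.randn(3, 8)))   # False
print(is_non_idle_and_non_empty(None, torch.empty(0, 8)))                          # False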
@@ -146,7 +146,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
         self.assertGreater(metrics["accuracy"], 0.62)

     def test_logprob(self):
-        prompt = "The capital of taiwan is "
+        prompt = "The capital of france is "
         response = requests.post(
             self.lb_url + "/generate",
             json={
...