Unverified commit 53dcf388, authored by fzyzcjy, committed by GitHub

Introduce moe_dense_tp_size to fix dense layer errors in DeepSeek V3 + 4x8xH100 (#4836)

parent 1effba4c
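For context (editorial note, not part of the diff): the errors this commit targets come from sharding the dense MLP layers of DeepSeek V3 across a large tensor-parallel group, as described in the help text of the new flag below. A rough, illustrative calculation, assuming DeepSeek V3's published dense `intermediate_size` of 18432:

```python
# Illustrative only: per-rank shard width of a dense MLP weight when it is
# column-split across the tensor-parallel group. With 4x8xH100 (tp_size=32)
# the slice becomes thin enough to hit the minimum-dimension limits of some
# GEMM kernels, which is what --moe-dense-tp-size 1 avoids by not sharding
# these layers at all.
intermediate_size = 18432  # DeepSeek V3 dense MLP width (from the model config)
for tp_size in (8, 16, 32):
    print(f"tp_size={tp_size:2d} -> per-rank shard width {intermediate_size // tp_size}")
```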
@@ -78,6 +78,7 @@ global_server_args_dict = {
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+    "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
     "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
...
@@ -159,6 +159,7 @@ class ModelRunner:
     "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
     "disable_radix_cache": server_args.disable_radix_cache,
     "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
+    "moe_dense_tp_size": server_args.moe_dense_tp_size,
     "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
     "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
     "n_share_experts_fusion": server_args.n_share_experts_fusion,
...
@@ -1066,12 +1066,18 @@ class DeepseekV2DecoderLayer(nn.Module):
                 prefix=add_prefix("mlp", prefix),
             )
         else:
+            if self._enable_moe_dense_fully_dp():
+                mlp_tp_rank, mlp_tp_size = 0, 1
+            else:
+                mlp_tp_rank, mlp_tp_size = None, None
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
+                tp_rank=mlp_tp_rank,
+                tp_size=mlp_tp_size,
             )
 
         self.input_is_scattered = (
@@ -1084,6 +1090,10 @@ class DeepseekV2DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+    @staticmethod
+    def _enable_moe_dense_fully_dp():
+        return global_server_args_dict["moe_dense_tp_size"] == 1
+
     @staticmethod
     def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
         is_sparse = is_nextn or (
@@ -1094,6 +1104,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         ffn_input_mode = (
             _FFNInputMode.SCATTERED
             if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
+            or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
             else _FFNInputMode.FULL
         )
         return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
@@ -1240,7 +1251,12 @@ class DeepseekV2DecoderLayer(nn.Module):
                 hidden_states, residual
             )
 
-        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+        if not (
+            self._enable_moe_dense_fully_dp()
+            and (not self.info.is_sparse)
+            and hidden_states.shape[0] == 0
+        ):
+            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
 
         if self.is_last_layer and self.attn_tp_size != 1:
             hidden_states += residual
...
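A minimal standalone sketch (editorial, not sglang code) of the behavior the deepseek_v2.py changes add: when `moe_dense_tp_size == 1`, the dense MLP is built with `tp_size=1` (weights replicated on every rank) and the layer switches to the scattered input mode, so each rank runs the MLP only on its local tokens and skips the call entirely when it holds zero tokens. The names below are illustrative; only the `moe_dense_tp_size` semantics mirror the diff above.

```python
# Standalone sketch of the decision logic added in this commit; illustrative only.
from dataclasses import dataclass
from typing import Optional


@dataclass
class DenseMLPPlan:
    tp_rank: Optional[int]  # None -> inherit the global TP rank
    tp_size: Optional[int]  # None -> inherit the global TP size
    input_mode: str         # "SCATTERED" (per-rank tokens) or "FULL" (gathered)


def plan_dense_mlp(moe_dense_tp_size: Optional[int]) -> DenseMLPPlan:
    if moe_dense_tp_size == 1:
        # Fully data-parallel: replicate the MLP weights (tp_size=1) and feed each
        # rank only its scattered slice of the batch.
        return DenseMLPPlan(tp_rank=0, tp_size=1, input_mode="SCATTERED")
    # Default behavior: shard the MLP over the whole TP group on gathered tokens.
    return DenseMLPPlan(tp_rank=None, tp_size=None, input_mode="FULL")


def should_run_dense_mlp(num_local_tokens: int, plan: DenseMLPPlan) -> bool:
    # Mirrors the forward-pass guard: on the fully-DP path, a rank that received
    # no tokens skips the MLP instead of launching a GEMM on an empty input.
    return not (plan.input_mode == "SCATTERED" and num_local_tokens == 0)


print(plan_dense_mlp(1))                          # replicated, scattered input
print(should_run_dense_mlp(0, plan_dense_mlp(1)))    # False -> skip the MLP call
print(should_run_dense_mlp(7, plan_dense_mlp(None)))  # True -> run as before
```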
@@ -181,6 +181,7 @@ class ServerArgs:
     hicache_ratio: float = 2.0
     flashinfer_mla_disable_ragged: bool = False
     warmups: Optional[str] = None
+    moe_dense_tp_size: Optional[int] = None
     n_share_experts_fusion: int = 0
     disable_shared_experts_fusion: bool = False
     disable_chunked_prefix_cache: bool = False
@@ -252,6 +253,11 @@ class ServerArgs:
         assert self.chunked_prefill_size % self.page_size == 0
 
+        assert self.moe_dense_tp_size in {
+            1,
+            None,
+        }, "moe_dense_tp_size currently only supports 1 or None"
+
         if self.attention_backend == "flashmla":
             logger.warning(
                 "FlashMLA only supports a page_size of 64, change page_size to 64."
@@ -1101,6 +1107,12 @@ class ServerArgs:
             action="store_true",
             help="Enabling DeepEP MoE implementation for EP MoE.",
         )
+        parser.add_argument(
+            "--moe-dense-tp-size",
+            type=int,
+            default=ServerArgs.moe_dense_tp_size,
+            help="TP size for MoE dense MLP layers. This flag is useful when the TP size is large and the sharded MLP weights would have a dimension smaller than the minimum dimension GEMM supports, which causes errors.",
+        )
         parser.add_argument(
             "--deepep-mode",
             type=str,
...
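Usage-wise, the new option is passed at server launch alongside the usual tensor-parallel size flag, e.g. `--moe-dense-tp-size 1` for the 4x8xH100 DeepSeek V3 setup named in the title. A standalone argparse sketch (editorial; it mirrors, but does not reuse, the sglang parser) showing the flag and the accepted values:

```python
# Standalone sketch of the CLI surface added above; the parser here is illustrative,
# not the sglang argument parser itself.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--moe-dense-tp-size",
    type=int,
    default=None,
    help="TP size for MoE dense MLP layers (currently only 1 or unset is accepted).",
)

args = parser.parse_args(["--moe-dense-tp-size", "1"])
# Same validation the diff adds to ServerArgs: only 1 or None are accepted for now.
assert args.moe_dense_tp_size in {1, None}
print(args.moe_dense_tp_size)  # -> 1
```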