Unverified Commit 8beb356f authored by fzyzcjy, committed by GitHub

Refactor DeepSeek decoder layer branches (#5205)

parent c776234b
@@ -18,7 +18,8 @@
 import logging
 import os
-from enum import IntEnum, auto
+from dataclasses import dataclass
+from enum import Enum, IntEnum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
@@ -28,6 +29,7 @@ from tqdm import tqdm
 from transformers import PretrainedConfig

 from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
@@ -146,7 +148,7 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x):
+    def forward(self, x, forward_mode: Optional[ForwardMode] = None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
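The widened `forward` signature keeps the dense `DeepseekV2MLP` call-compatible with the MoE sublayer, so the decoder layer can invoke whichever `self.mlp` it holds with the same arguments; the dense path simply ignores the extra argument. A minimal sketch of that duck-typed call pattern, using illustrative stand-in classes rather than the real sglang modules:

```python
# Minimal sketch of the shared sublayer call signature.
# DenseMLP / MoEMLP / ForwardMode are illustrative stand-ins, not sglang classes.
from enum import Enum, auto
from typing import Optional


class ForwardMode(Enum):
    EXTEND = auto()
    DECODE = auto()


class DenseMLP:
    def forward(self, x, forward_mode: Optional[ForwardMode] = None):
        return x  # the dense path ignores forward_mode entirely


class MoEMLP:
    def forward(self, x, forward_mode: Optional[ForwardMode] = None):
        # an MoE implementation could pick a dispatch strategy based on forward_mode
        return x


def run_ffn(mlp, x, forward_mode):
    # the caller no longer needs to know which sublayer variant it holds
    return mlp.forward(x, forward_mode)


print(run_ffn(DenseMLP(), [1.0, 2.0], ForwardMode.DECODE))
print(run_ffn(MoEMLP(), [1.0, 2.0], ForwardMode.EXTEND))
```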
@@ -999,6 +1001,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output

+class _FFNInputMode(Enum):
+    # The MLP sublayer requires 1/tp_size tokens as input
+    SCATTERED = auto()
+    # The MLP sublayer requires all tokens as input
+    FULL = auto()
+
+
+@dataclass
+class _DecoderLayerInfo:
+    is_sparse: bool
+    ffn_input_mode: _FFNInputMode
+
+
 class DeepseekV2DecoderLayer(nn.Module):
     def __init__(
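The new `_FFNInputMode` enum distinguishes whether the FFN sublayer expects the full token batch (`FULL`) or only this rank's 1/tp_size share of the tokens (`SCATTERED`, the DeepEP layout), and `_DecoderLayerInfo` bundles that with the sparse/dense flag. A rough, self-contained illustration of the two layouts with invented sizes:

```python
# Rough illustration of FULL vs. SCATTERED token layouts (sizes are invented).
import torch

tp_size, num_tokens, hidden = 4, 8, 16

full_input = torch.randn(num_tokens, hidden)   # FULL: every rank sees all 8 tokens
shards = full_input.chunk(tp_size, dim=0)      # SCATTERED: each rank keeps 8 / 4 = 2 tokens

assert full_input.shape == (num_tokens, hidden)
assert shards[0].shape == (num_tokens // tp_size, hidden)
```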
@@ -1009,14 +1024,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         is_nextn: bool = False,
         prefix: str = "",
     ) -> None:
-        def is_sparse_layer(l: int):
-            return (
-                config.n_routed_experts is not None
-                and l >= config.first_k_dense_replace
-                and l % config.moe_layer_freq == 0
-            )
-
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -1047,13 +1054,17 @@ class DeepseekV2DecoderLayer(nn.Module):
             prefix=add_prefix("self_attn", prefix),
         )

-        if is_nextn or is_sparse_layer(layer_id):
+        self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
+        previous_layer_info = self._compute_info(
+            config, layer_id=layer_id - 1, is_nextn=False
+        )
+
+        if self.info.is_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
-            self.is_sparse = True
         else:
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
@@ -1062,11 +1073,9 @@ class DeepseekV2DecoderLayer(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
-            self.is_sparse = False

         self.input_is_scattered = (
-            is_sparse_layer(layer_id - 1)
-            and global_server_args_dict["enable_deepep_moe"]
+            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
         self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
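With `previous_layer_info` computed once in `__init__`, the old condition `is_sparse_layer(layer_id - 1) and enable_deepep_moe` becomes "the previous layer's FFN ran in SCATTERED mode", which is the same predicate by construction of `_compute_info`. A tiny sketch of that equivalence (helper names and config values here are illustrative, not the sglang code):

```python
# Sketch: the two formulations of input_is_scattered agree for every layer.
def is_sparse_layer(layer_id, first_k_dense_replace=1, moe_layer_freq=1):
    # simplified version of the removed helper
    return layer_id >= first_k_dense_replace and layer_id % moe_layer_freq == 0


def ffn_input_is_scattered(layer_id, enable_deepep_moe):
    # mirrors _compute_info: SCATTERED iff DeepEP is enabled and the layer is sparse
    return enable_deepep_moe and is_sparse_layer(layer_id)


for enable_deepep_moe in (False, True):
    for layer_id in range(8):
        old = is_sparse_layer(layer_id - 1) and enable_deepep_moe
        new = ffn_input_is_scattered(layer_id - 1, enable_deepep_moe)
        assert old == new
```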
@@ -1075,6 +1084,20 @@ class DeepseekV2DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )

+    @staticmethod
+    def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
+        is_sparse = is_nextn or (
+            config.n_routed_experts is not None
+            and layer_id >= config.first_k_dense_replace
+            and layer_id % config.moe_layer_freq == 0
+        )
+        ffn_input_mode = (
+            _FFNInputMode.SCATTERED
+            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
+            else _FFNInputMode.FULL
+        )
+        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+
     def forward(
         self,
         positions: torch.Tensor,
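For a hypothetical DeepSeek-style config (values below are invented for illustration), `_compute_info` classifies the layers as follows; note that the `previous_layer_info` lookup for layer 0 uses `layer_id - 1 == -1`, which fails the `>= first_k_dense_replace` check and therefore comes back dense with `FULL` input. A standalone sketch of the same logic:

```python
# Standalone re-implementation of the classification, with invented config values.
from dataclasses import dataclass
from enum import Enum, auto


class FFNInputMode(Enum):
    SCATTERED = auto()
    FULL = auto()


@dataclass
class FakeConfig:
    n_routed_experts: int = 64
    first_k_dense_replace: int = 1
    moe_layer_freq: int = 1


def compute_info(cfg, layer_id, is_nextn, enable_deepep_moe):
    is_sparse = is_nextn or (
        cfg.n_routed_experts is not None
        and layer_id >= cfg.first_k_dense_replace
        and layer_id % cfg.moe_layer_freq == 0
    )
    mode = (
        FFNInputMode.SCATTERED
        if (enable_deepep_moe and is_sparse)
        else FFNInputMode.FULL
    )
    return is_sparse, mode


cfg = FakeConfig()
for layer_id in range(-1, 4):  # -1 stands for "the previous layer of layer 0"
    print(layer_id, compute_info(cfg, layer_id, is_nextn=False, enable_deepep_moe=True))
# -1 -> (False, FULL): layer 0's attention input is never scattered
#  0 -> (False, FULL): first dense layer
#  1..3 -> (True, SCATTERED): sparse layers fed 1/tp_size tokens under DeepEP
```

Under these invented values, layer 1 is the first sparse layer, but its attention input is still `FULL` because layer 0's FFN ran dense; only from layer 2 onward does `input_is_scattered` become true.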
@@ -1082,16 +1105,18 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
-            return self.forward_deepep(
+        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
+            return self.forward_ffn_with_scattered_input(
                 positions, hidden_states, forward_batch, residual
             )
-        else:
-            return self.forward_normal(
+        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
+            return self.forward_ffn_with_full_input(
                 positions, hidden_states, forward_batch, residual
             )
+        else:
+            raise NotImplementedError

-    def forward_normal(
+    def forward_ffn_with_full_input(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
@@ -1158,7 +1183,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         return hidden_states, residual

-    def forward_deepep(
+    def forward_ffn_with_scattered_input(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
@@ -1214,6 +1239,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual
             )
+        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

         if self.is_last_layer and self.attn_tp_size != 1: