Unverified commit 8beb356f, authored by fzyzcjy and committed by GitHub

Refactor DeepSeek decoder layer branches (#5205)

parent c776234b
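
For orientation: the refactor below replaces the local `is_sparse_layer` helper and the `self.is_sparse` / `self.input_is_scattered` bookkeeping with a `_DecoderLayerInfo` dataclass computed once per layer by `_compute_info`, and splits the decoder-layer FFN path into `forward_ffn_with_full_input` and `forward_ffn_with_scattered_input`, dispatched on an `_FFNInputMode` enum. The standalone sketch below mirrors how `_compute_info` decides the mode; the plain-dict config and the `enable_deepep_moe` argument are illustrative stand-ins, not the actual sglang objects:

# Minimal sketch of the _compute_info logic from this commit, using a plain
# dict in place of the HF PretrainedConfig and an explicit flag in place of
# global_server_args_dict["enable_deepep_moe"].
from dataclasses import dataclass
from enum import Enum, auto


class _FFNInputMode(Enum):
    SCATTERED = auto()  # the MLP sublayer receives 1/tp_size of the tokens
    FULL = auto()       # the MLP sublayer receives all tokens


@dataclass
class _DecoderLayerInfo:
    is_sparse: bool
    ffn_input_mode: _FFNInputMode


def compute_info(
    config: dict, layer_id: int, is_nextn: bool, enable_deepep_moe: bool
) -> _DecoderLayerInfo:
    # A layer is sparse (MoE) when it is the next-n layer or when it lies in
    # the routed-expert range defined by the config.
    is_sparse = is_nextn or (
        config["n_routed_experts"] is not None
        and layer_id >= config["first_k_dense_replace"]
        and layer_id % config["moe_layer_freq"] == 0
    )
    # Scattered FFN input is only used for sparse layers under DeepEP MoE.
    ffn_input_mode = (
        _FFNInputMode.SCATTERED
        if (enable_deepep_moe and is_sparse)
        else _FFNInputMode.FULL
    )
    return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)


if __name__ == "__main__":
    cfg = {"n_routed_experts": 64, "first_k_dense_replace": 1, "moe_layer_freq": 1}
    for layer_id in range(3):
        print(layer_id, compute_info(cfg, layer_id, is_nextn=False, enable_deepep_moe=True))
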
@@ -18,7 +18,8 @@
 import logging
 import os
-from enum import IntEnum, auto
+from dataclasses import dataclass
+from enum import Enum, IntEnum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -28,6 +29,7 @@ from tqdm import tqdm
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
@@ -146,7 +148,7 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x):
+    def forward(self, x, forward_mode: Optional[ForwardMode] = None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -999,6 +1001,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output
 
 
+class _FFNInputMode(Enum):
+    # The MLP sublayer requires 1/tp_size tokens as input
+    SCATTERED = auto()
+    # The MLP sublayer requires all tokens as input
+    FULL = auto()
+
+
+@dataclass
+class _DecoderLayerInfo:
+    is_sparse: bool
+    ffn_input_mode: _FFNInputMode
+
+
 class DeepseekV2DecoderLayer(nn.Module):
     def __init__(
@@ -1009,14 +1024,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         is_nextn: bool = False,
         prefix: str = "",
     ) -> None:
-        def is_sparse_layer(l: int):
-            return (
-                config.n_routed_experts is not None
-                and l >= config.first_k_dense_replace
-                and l % config.moe_layer_freq == 0
-            )
-
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -1047,13 +1054,17 @@ class DeepseekV2DecoderLayer(nn.Module):
             prefix=add_prefix("self_attn", prefix),
         )
 
-        if is_nextn or is_sparse_layer(layer_id):
+        self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
+        previous_layer_info = self._compute_info(
+            config, layer_id=layer_id - 1, is_nextn=False
+        )
+
+        if self.info.is_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
-            self.is_sparse = True
         else:
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
@@ -1062,11 +1073,9 @@ class DeepseekV2DecoderLayer(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
-            self.is_sparse = False
 
         self.input_is_scattered = (
-            is_sparse_layer(layer_id - 1)
-            and global_server_args_dict["enable_deepep_moe"]
+            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
        )
         self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
 
@@ -1075,6 +1084,20 @@ class DeepseekV2DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+    @staticmethod
+    def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
+        is_sparse = is_nextn or (
+            config.n_routed_experts is not None
+            and layer_id >= config.first_k_dense_replace
+            and layer_id % config.moe_layer_freq == 0
+        )
+        ffn_input_mode = (
+            _FFNInputMode.SCATTERED
+            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
+            else _FFNInputMode.FULL
+        )
+        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -1082,16 +1105,18 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
-            return self.forward_deepep(
+        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
+            return self.forward_ffn_with_scattered_input(
                 positions, hidden_states, forward_batch, residual
             )
-        else:
-            return self.forward_normal(
+        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
+            return self.forward_ffn_with_full_input(
                 positions, hidden_states, forward_batch, residual
            )
+        else:
+            raise NotImplementedError
 
-    def forward_normal(
+    def forward_ffn_with_full_input(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
@@ -1158,7 +1183,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         return hidden_states, residual
 
-    def forward_deepep(
+    def forward_ffn_with_scattered_input(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
@@ -1214,6 +1239,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual = self.post_attention_layernorm(
                 hidden_states, residual
            )
         hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
         if self.is_last_layer and self.attn_tp_size != 1:
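
A toy illustration of the forward dispatch and the previous-layer coupling shown above: the layer branches on info computed once in __init__, and input_is_scattered is derived from the previous layer's FFN input mode rather than a fresh is_sparse_layer(layer_id - 1) check. All names here are hypothetical stand-ins, not the sglang classes:

# Self-contained toy version of the dispatch pattern; the real layer also
# threads positions, forward_batch and residual through, which are omitted here.
from enum import Enum, auto


class FFNInputMode(Enum):
    SCATTERED = auto()
    FULL = auto()


class ToyDecoderLayer:
    def __init__(self, ffn_input_mode: FFNInputMode, previous_ffn_input_mode: FFNInputMode):
        self.ffn_input_mode = ffn_input_mode
        # Whether this layer receives scattered input depends on how the
        # previous layer handed over its hidden states.
        self.input_is_scattered = previous_ffn_input_mode == FFNInputMode.SCATTERED

    def forward(self, hidden_states):
        if self.ffn_input_mode == FFNInputMode.SCATTERED:
            return self._forward_ffn_with_scattered_input(hidden_states)
        elif self.ffn_input_mode == FFNInputMode.FULL:
            return self._forward_ffn_with_full_input(hidden_states)
        else:
            raise NotImplementedError

    def _forward_ffn_with_scattered_input(self, hidden_states):
        return hidden_states  # placeholder for the DeepEP / scattered-token path

    def _forward_ffn_with_full_input(self, hidden_states):
        return hidden_states  # placeholder for the dense / all-token path


layer = ToyDecoderLayer(FFNInputMode.SCATTERED, previous_ffn_input_mode=FFNInputMode.FULL)
print(layer.input_is_scattered, layer.forward([1.0, 2.0]))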