change / sglang · Commits · 1b19df4b

Unverified commit 1b19df4b, authored May 20, 2025 by fzyzcjy, committed by GitHub May 19, 2025

Refactor communication logic of DeepSeek for extensibility and understandability (#6321)
Parent: f0653886

Showing 2 changed files with 496 additions and 188 deletions (+496 −188)
python/sglang/srt/layers/communicator.py    +451  −0
python/sglang/srt/models/deepseek_v2.py      +45  −188
python/sglang/srt/layers/communicator.py (new file, mode 100644)
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict, Optional, Tuple

import torch.distributed

from sglang.srt.distributed import (
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.dp_attention import (
    attn_tp_all_gather,
    attn_tp_reduce_scatter,
    dp_gather_partial,
    dp_scatter,
    get_attention_tp_rank,
    get_attention_tp_size,
    get_local_attention_dp_size,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

class ScatterMode(Enum):
    SCATTERED = auto()
    TP_ATTN_FULL = auto()
    FULL = auto()


@dataclass
class _LayerModeComputationContext:
    num_layers: int
    layer_id: int
    is_layer_sparse: bool
    is_previous_layer_sparse: Optional[bool]

    def previous_layer(self):
        assert self.is_previous_layer_sparse is not None
        return _LayerModeComputationContext(
            layer_id=self.layer_id - 1,
            is_layer_sparse=self.is_previous_layer_sparse,
            is_previous_layer_sparse=None,
            num_layers=self.num_layers,
        )

@dataclass
class LayerScatterModes:
    layer_input_mode: ScatterMode
    attn_mode: ScatterMode
    # Can be further split into e.g. mlp_input_mode and mlp_output_mode if needed
    mlp_mode: ScatterMode
    middle_residual_mode: ScatterMode
    layer_output_mode: ScatterMode

    @classmethod
    def init_new(cls, **kwargs):
        context = _LayerModeComputationContext(**kwargs)
        return cls(
            layer_input_mode=cls._compute_layer_input_mode(context),
            attn_mode=ScatterMode.TP_ATTN_FULL,
            mlp_mode=cls._compute_mlp_mode(context),
            middle_residual_mode=cls._compute_middle_residual_mode(context),
            layer_output_mode=cls._compute_layer_output_mode(context),
        )

    @classmethod
    def _compute_layer_input_mode(cls, context: _LayerModeComputationContext):
        if context.layer_id == 0:
            return ScatterMode.TP_ATTN_FULL
        return cls._compute_layer_output_mode(context.previous_layer())

    @classmethod
    def _compute_mlp_mode(cls, context: _LayerModeComputationContext):
        if context.is_layer_sparse:
            return (
                ScatterMode.SCATTERED
                if global_server_args_dict["enable_deepep_moe"]
                else ScatterMode.FULL
            )
        else:
            return (
                ScatterMode.SCATTERED
                if enable_moe_dense_fully_dp()
                else ScatterMode.FULL
            )

    @classmethod
    def _compute_middle_residual_mode(cls, context: _LayerModeComputationContext):
        mlp_mode = cls._compute_mlp_mode(context)
        if mlp_mode == ScatterMode.SCATTERED:
            return ScatterMode.SCATTERED
        if mlp_mode == ScatterMode.FULL:
            return ScatterMode.TP_ATTN_FULL
        raise NotImplementedError

    @classmethod
    def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
        mlp_mode = cls._compute_mlp_mode(context)
        if context.layer_id == context.num_layers - 1:
            return ScatterMode.TP_ATTN_FULL
        if mlp_mode == ScatterMode.SCATTERED:
            return ScatterMode.SCATTERED
        if mlp_mode == ScatterMode.FULL:
            return ScatterMode.TP_ATTN_FULL
        raise NotImplementedError


def enable_moe_dense_fully_dp():
    return global_server_args_dict["moe_dense_tp_size"] == 1

class LayerCommunicator:
    def __init__(
        self,
        layer_scatter_modes: LayerScatterModes,
        input_layernorm: torch.nn.Module,
        post_attention_layernorm: torch.nn.Module,
    ):
        self.layer_scatter_modes = layer_scatter_modes
        self.input_layernorm = input_layernorm
        self.post_attention_layernorm = post_attention_layernorm

        self.attn_tp_rank = get_attention_tp_rank()
        self.attn_tp_size = get_attention_tp_size()
        self.local_attn_dp_size = get_local_attention_dp_size()
        self.tp_size = get_tensor_model_parallel_world_size()

        self.process_group_sizes = {
            ScatterMode.SCATTERED: 1,
            ScatterMode.TP_ATTN_FULL: self.attn_tp_size,
            ScatterMode.FULL: self.tp_size,
        }

    def prepare_attn(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        if hidden_states.shape[0] == 0:
            residual = hidden_states
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(hidden_states, residual)

        hidden_states = _communicate_simple(
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            input_mode=self.layer_scatter_modes.layer_input_mode,
            output_mode=self.layer_scatter_modes.attn_mode,
            context=self._compute_context(forward_batch),
        )

        return hidden_states, residual

    def prepare_mlp(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        return _communicate_with_all_reduce_and_layer_norm(
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=forward_batch,
            hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
            residual_input_mode=self.layer_scatter_modes.layer_input_mode,
            hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
            residual_output_mode=self.layer_scatter_modes.middle_residual_mode,
            layernorm=self.post_attention_layernorm,
            context=self._compute_context(forward_batch),
        )

    def postprocess_layer(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        return _communicate_summable_tensor_pair(
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=forward_batch,
            hidden_states_input_mode=self.layer_scatter_modes.mlp_mode,
            residual_input_mode=self.layer_scatter_modes.middle_residual_mode,
            output_mode=self.layer_scatter_modes.layer_output_mode,
            context=self._compute_context(forward_batch),
        )

    def _compute_context(self, forward_batch: ForwardBatch):
        return _Context(
            num_tokens_of_mode=_compute_num_tokens_of_mode(
                forward_batch,
                attn_tp_rank=self.attn_tp_rank,
                attn_tp_size=self.attn_tp_size,
            ),
            process_group_sizes=self.process_group_sizes,
            attn_tp_rank=self.attn_tp_rank,
            attn_tp_size=self.attn_tp_size,
            local_attn_dp_size=self.local_attn_dp_size,
            tp_size=self.tp_size,
        )

def _compute_num_tokens_of_mode(
    forward_batch: ForwardBatch, attn_tp_rank: int, attn_tp_size: int
):
    tp_attn_full_num_tokens = forward_batch.input_ids.shape[0]
    return {
        ScatterMode.SCATTERED: _torch_tensor_split_len(
            tp_attn_full_num_tokens, attn_tp_size, attn_tp_rank
        ),
        ScatterMode.TP_ATTN_FULL: tp_attn_full_num_tokens,
        ScatterMode.FULL: (
            forward_batch.gathered_buffer.shape[0]
            if global_server_args_dict["enable_dp_attention"]
            else forward_batch.input_ids.shape[0]
        ),
    }


def _torch_tensor_split_len(tensor_len: int, n: int, output_index: int):
    if output_index < int(tensor_len % n):
        return int(tensor_len / n) + 1
    else:
        return int(tensor_len / n)

@dataclass
class _Context:
    num_tokens_of_mode: Dict["ScatterMode", int]
    process_group_sizes: Dict["ScatterMode", int]
    attn_tp_rank: int
    attn_tp_size: int
    local_attn_dp_size: int
    tp_size: int

    def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
        return self.process_group_sizes[a] == self.process_group_sizes[b]

    def check_shape(self, x: torch.Tensor, mode: ScatterMode):
        if x is None:
            return

        actual_num_tokens = x.shape[0]
        expect_num_tokens = self.num_tokens_of_mode[mode]
        assert (
            actual_num_tokens == expect_num_tokens
        ), f"{actual_num_tokens=} {expect_num_tokens=} {mode=} {x.shape=} {self.num_tokens_of_mode=} {self.process_group_sizes=}"
        return x

    def check_shapes(
        self, xs: Tuple[torch.Tensor, ...], modes: Tuple[ScatterMode, ...]
    ) -> Tuple[torch.Tensor, ...]:
        return tuple(
            [self.check_shape(x, mode) for x, mode in zip(xs, modes, strict=True)]
        )

def _communicate_simple(
    hidden_states: torch.Tensor,
    forward_batch: ForwardBatch,
    input_mode: ScatterMode,
    output_mode: ScatterMode,
    context: _Context,
) -> torch.Tensor:
    def _inner():
        nonlocal hidden_states

        if context.is_same_group_size(input_mode, output_mode):
            return hidden_states

        if (input_mode == ScatterMode.SCATTERED) and (
            output_mode == ScatterMode.TP_ATTN_FULL
        ):
            hidden_states, local_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            attn_tp_all_gather(
                list(hidden_states.tensor_split(context.attn_tp_size)),
                local_hidden_states,
            )
            return hidden_states

        raise NotImplementedError(f"{input_mode=} {output_mode=}")

    context.check_shape(hidden_states, input_mode)
    return context.check_shape(_inner(), output_mode)

def _communicate_with_all_reduce_and_layer_norm(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    hidden_states_input_mode: ScatterMode,
    residual_input_mode: ScatterMode,
    hidden_states_output_mode: ScatterMode,
    residual_output_mode: ScatterMode,
    forward_batch: ForwardBatch,
    layernorm: torch.nn.Module,
    context: _Context,
):
    """Besides communication, needs to
    1. All reduce in tp_attn_group on hidden_states
    2. Apply layer norm
    """

    def _inner():
        nonlocal hidden_states, residual

        if (
            context.is_same_group_size(
                hidden_states_input_mode, hidden_states_output_mode
            )
            and context.is_same_group_size(residual_input_mode, residual_output_mode)
            and context.attn_tp_size == 1
        ):
            # TODO move these `if shape != 0` into LayerNorm itself
            if hidden_states.shape[0] != 0:
                hidden_states, residual = layernorm(hidden_states, residual)
            return hidden_states, residual

        if (
            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
            and (hidden_states_output_mode == ScatterMode.FULL)
            and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
        ):
            if context.local_attn_dp_size != 1:
                if context.attn_tp_rank == 0:
                    hidden_states += residual
                hidden_states, local_hidden_states = (
                    forward_batch.gathered_buffer,
                    hidden_states,
                )
                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
                dp_scatter(residual, hidden_states, forward_batch)
                if hidden_states.shape[0] != 0:
                    hidden_states = layernorm(hidden_states)
            else:
                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                hidden_states, residual = layernorm(hidden_states, residual)
            return hidden_states, residual

        if (
            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
            and (
                residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
            )
            and (hidden_states_output_mode == ScatterMode.SCATTERED)
            and (residual_output_mode == ScatterMode.SCATTERED)
        ):
            tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
            hidden_states = tensor_list[context.attn_tp_rank]
            attn_tp_reduce_scatter(hidden_states, tensor_list)
            if residual_input_mode == ScatterMode.TP_ATTN_FULL:
                residual = residual.tensor_split(context.attn_tp_size)[
                    context.attn_tp_rank
                ]
            if hidden_states.shape[0] != 0:
                hidden_states, residual = layernorm(hidden_states, residual)
            return hidden_states, residual

        raise NotImplementedError(
            f"{hidden_states_input_mode=} {hidden_states_output_mode=} "
            f"{residual_input_mode=} {residual_output_mode=}"
        )

    context.check_shapes(
        (hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
    )
    return context.check_shapes(
        _inner(), (hidden_states_output_mode, residual_output_mode)
    )

def _communicate_summable_tensor_pair(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    forward_batch: ForwardBatch,
    hidden_states_input_mode: ScatterMode,
    residual_input_mode: ScatterMode,
    output_mode: ScatterMode,
    context: _Context,
):
    """It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""

    def _inner():
        nonlocal hidden_states, residual

        if context.is_same_group_size(
            hidden_states_input_mode, output_mode
        ) and context.is_same_group_size(residual_input_mode, output_mode):
            return hidden_states, residual

        if (
            (hidden_states_input_mode == ScatterMode.FULL)
            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
            and (output_mode == ScatterMode.TP_ATTN_FULL)
        ):
            # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
            # important: forward_batch.gathered_buffer is used both after scatter and after gather.
            # be careful about this!
            hidden_states, global_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            dp_scatter(hidden_states, global_hidden_states, forward_batch)
            return hidden_states, residual

        if (
            (hidden_states_input_mode == ScatterMode.SCATTERED)
            and (residual_input_mode == ScatterMode.SCATTERED)
            and (output_mode == ScatterMode.TP_ATTN_FULL)
        ):
            hidden_states += residual
            residual = None
            hidden_states, local_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            attn_tp_all_gather(
                list(hidden_states.tensor_split(context.attn_tp_size)),
                local_hidden_states,
            )
            return hidden_states, residual

        raise NotImplementedError(
            f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
        )

    context.check_shapes(
        (hidden_states, residual), (hidden_states_input_mode, residual_input_mode)
    )
    return context.check_shapes(_inner(), (output_mode, output_mode))
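
The SCATTERED entry above relies on _torch_tensor_split_len agreeing with the chunk sizes that torch.Tensor.tensor_split actually produces (the first len % n chunks get one extra row). The following snippet is an illustration only, not part of the commit; the helper is re-stated locally so it runs without sglang:

import torch


def torch_tensor_split_len(tensor_len: int, n: int, output_index: int) -> int:
    # Same rule as _torch_tensor_split_len above, written with integer ops.
    return tensor_len // n + (1 if output_index < tensor_len % n else 0)


# Compare against the chunk lengths torch.tensor_split actually returns,
# including the empty-batch case (tensor_len == 0).
for tensor_len in range(0, 17):
    for n in (1, 2, 3, 4, 5):
        expected = [chunk.shape[0] for chunk in torch.arange(tensor_len).tensor_split(n)]
        assert [torch_tensor_split_len(tensor_len, n, i) for i in range(n)] == expected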
python/sglang/srt/models/deepseek_v2.py
@@ -18,8 +18,7 @@
 import logging
 import os
 from dataclasses import dataclass
-from enum import Enum, IntEnum, auto
+from enum import IntEnum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
@@ -29,17 +28,17 @@ from tqdm import tqdm
 from transformers import PretrainedConfig

 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.communicator import (
+    LayerCommunicator,
+    LayerScatterModes,
+    enable_moe_dense_fully_dp,
+)
 from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
-    dp_gather_partial,
-    dp_scatter,
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
@@ -52,9 +51,8 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE, get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
@@ -72,7 +70,7 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -141,6 +139,8 @@ class DeepseekV2MLP(nn.Module):
         tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
+        self.tp_size = tp_size
+
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
             [intermediate_size] * 2,
@@ -167,7 +167,10 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x, forward_batch: Optional[ForwardBatch] = None):
+    def forward(self, x, forward_batch=None):
+        if (self.tp_size == 1) and x.shape[0] == 0:
+            return x
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -1097,19 +1100,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output


-class _FFNInputMode(Enum):
-    # The MLP sublayer requires 1/tp_size tokens as input
-    SCATTERED = auto()
-    # The MLP sublayer requires all tokens as input
-    FULL = auto()
-
-
-@dataclass
-class _DecoderLayerInfo:
-    is_sparse: bool
-    ffn_input_mode: _FFNInputMode
-
-
 class DeepseekV2DecoderLayer(nn.Module):
     def __init__(
@@ -1123,14 +1113,12 @@ class DeepseekV2DecoderLayer(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.config = config
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
-        self.local_dp_size = get_local_attention_dp_size()
-        self.attn_tp_size = get_attention_tp_size()
-        self.attn_tp_rank = get_attention_tp_rank()
         self.self_attn = DeepseekV2AttentionMLA(
             config=config,
             hidden_size=self.hidden_size,
@@ -1152,19 +1140,24 @@ class DeepseekV2DecoderLayer(nn.Module):
             alt_stream=alt_stream,
         )

-        self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
-        previous_layer_info = self._compute_info(
-            config, layer_id=layer_id - 1, is_nextn=False
-        )
+        self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
+        is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
+        )

-        if self.info.is_sparse:
+        if self.is_layer_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
         else:
-            if self._enable_moe_dense_fully_dp():
+            if enable_moe_dense_fully_dp():
                 mlp_tp_rank, mlp_tp_size = 0, 1
             else:
                 mlp_tp_rank, mlp_tp_size = None, None
@@ -1178,35 +1171,23 @@ class DeepseekV2DecoderLayer(nn.Module):
                 tp_size=mlp_tp_size,
             )

-        self.input_is_scattered = (
-            layer_id > 0
-            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
-        )
-        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
-
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )

-    @staticmethod
-    def _enable_moe_dense_fully_dp():
-        return global_server_args_dict["moe_dense_tp_size"] == 1
-
-    @staticmethod
-    def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
-        is_sparse = is_nextn or (
-            config.n_routed_experts is not None
-            and layer_id >= config.first_k_dense_replace
-            and layer_id % config.moe_layer_freq == 0
-        )
-        ffn_input_mode = (
-            _FFNInputMode.SCATTERED
-            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
-            or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
-            else _FFNInputMode.FULL
-        )
-        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )
+
+    def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
+        return is_nextn or (
+            self.config.n_routed_experts is not None
+            and layer_id >= self.config.first_k_dense_replace
+            and layer_id % self.config.moe_layer_freq == 0
+        )

     def forward(
         self,
@@ -1216,114 +1197,10 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
-        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
-            return self.forward_ffn_with_scattered_input(
-                positions, hidden_states, forward_batch, residual, zero_allocator
-            )
-        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
-            return self.forward_ffn_with_full_input(
-                positions, hidden_states, forward_batch, residual, zero_allocator
-            )
-        else:
-            raise NotImplementedError
-
-    def forward_ffn_with_full_input(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-        residual: Optional[torch.Tensor],
-        zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        assert not (
-            self.attn_tp_size != 1 and self.input_is_scattered
-        ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
-
-        # Self Attention
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-            zero_allocator=zero_allocator,
-        )
-
-        # Gather
-        if get_tensor_model_parallel_world_size() > 1:
-            # all gather and all reduce
-            if self.local_dp_size != 1:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                dp_scatter(residual, hidden_states, forward_batch)
-                hidden_states = self.post_attention_layernorm(hidden_states)
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
-        else:
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
-
-        # Fully Connected
-        hidden_states = self.mlp(hidden_states, forward_batch)
-
-        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
-        # Scatter
-        if self.local_dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
-
-        return hidden_states, residual
-
-    def forward_ffn_with_scattered_input(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-        residual: Optional[torch.Tensor],
-        zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
-        if self.attn_tp_size != 1 and self.input_is_scattered:
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            attn_tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)),
-                local_hidden_states,
-            )
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )

         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -1331,35 +1208,15 @@ class DeepseekV2DecoderLayer(nn.Module):
             zero_allocator=zero_allocator,
         )

-        if self.attn_tp_size != 1:
-            tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
-            hidden_states = tensor_list[self.attn_tp_rank]
-            attn_tp_reduce_scatter(hidden_states, tensor_list)
-            if not self.input_is_scattered:
-                residual = residual.tensor_split(self.attn_tp_size)[self.attn_tp_rank]
-
-        if hidden_states.shape[0] != 0:
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )

-        if not (
-            self._enable_moe_dense_fully_dp()
-            and (not self.info.is_sparse)
-            and hidden_states.shape[0] == 0
-        ):
-            hidden_states = self.mlp(hidden_states, forward_batch)
+        hidden_states = self.mlp(hidden_states, forward_batch)

-        if self.is_last_layer and self.attn_tp_size != 1:
-            hidden_states += residual
-            residual = None
-            hidden_states, local_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            attn_tp_all_gather(
-                list(hidden_states.tensor_split(self.attn_tp_size)),
-                local_hidden_states,
-            )
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )

         return hidden_states, residual
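
To make the scatter-mode bookkeeping that the decoder layer now delegates to LayerScatterModes more concrete, here is a minimal standalone sketch of the rules it applies, with plain booleans standing in for the enable_deepep_moe and moe_dense_tp_size == 1 server args; it is illustrative only and not part of the commit:

from enum import Enum, auto


class Mode(Enum):
    SCATTERED = auto()
    TP_ATTN_FULL = auto()
    FULL = auto()


def mlp_mode(
    is_layer_sparse: bool, enable_deepep_moe: bool, moe_dense_fully_dp: bool
) -> Mode:
    # Mirrors LayerScatterModes._compute_mlp_mode: sparse layers scatter only under
    # DeepEP MoE; dense layers scatter only when moe_dense_tp_size == 1.
    if is_layer_sparse:
        return Mode.SCATTERED if enable_deepep_moe else Mode.FULL
    return Mode.SCATTERED if moe_dense_fully_dp else Mode.FULL


def layer_output_mode(mlp: Mode, is_last_layer: bool) -> Mode:
    # Mirrors LayerScatterModes._compute_layer_output_mode: the last layer always
    # hands TP_ATTN_FULL tokens back to the model tail.
    if is_last_layer:
        return Mode.TP_ATTN_FULL
    return Mode.SCATTERED if mlp == Mode.SCATTERED else Mode.TP_ATTN_FULL


# With DeepEP enabled, a sparse middle layer keeps tokens scattered into the next
# layer; the final layer still returns them to TP_ATTN_FULL.
assert layer_output_mode(mlp_mode(True, True, False), is_last_layer=False) == Mode.SCATTERED
assert layer_output_mode(mlp_mode(True, True, False), is_last_layer=True) == Mode.TP_ATTN_FULL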