sglang commit a38376fa (Unverified)

Refactor attention into multiple stages (#6477)

Authored May 25, 2025 by fzyzcjy; committed by GitHub May 24, 2025.
Parent: 7a5e6ce1

Showing 2 changed files with 121 additions and 26 deletions (+121 -26):
  python/sglang/srt/models/deepseek_v2.py     +117  -24
  python/sglang/srt/operations_strategy.py      +4   -2
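This commit splits each attention code path in DeepseekV2AttentionMLA into a *_prepare stage (projections, RoPE, KV-cache writes) and a *_core stage (the attention computation plus output projection), and exposes the pair to the layer operation plan as self_attn.op_prepare / self_attn.op_core. The sketch below is a minimal, self-contained illustration of that prepare/core pattern; the TwoStageAttention class and its fields are hypothetical simplifications, not sglang's actual API.

# Minimal illustration of the "prepare / core" split introduced by this commit.
# TwoStageAttention is a hypothetical stand-in, not sglang code.
import torch


class TwoStageAttention(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.qkv_proj = torch.nn.Linear(dim, 3 * dim)
        self.o_proj = torch.nn.Linear(dim, dim)

    def forward_prepare(self, hidden_states: torch.Tensor):
        # Stage 1: projections and any cache bookkeeping, no attention yet.
        q, k, v = self.qkv_proj(hidden_states).chunk(3, dim=-1)
        return q, k, v

    def forward_core(self, inner_state):
        # Stage 2: the attention computation and the output projection.
        q, k, v = inner_state
        attn = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        return self.o_proj(attn @ v)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Plain forward() keeps its old behavior by chaining the two stages,
        # while a scheduler remains free to call them separately.
        return self.forward_core(self.forward_prepare(hidden_states))


if __name__ == "__main__":
    layer = TwoStageAttention(dim=16)
    x = torch.randn(4, 16)
    assert torch.allclose(layer(x), layer.forward_core(layer.forward_prepare(x)))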
python/sglang/srt/models/deepseek_v2.py
@@ -677,44 +677,94 @@ class DeepseekV2AttentionMLA(nn.Module):
         else:
             return _dispatch_mla_subtype()
 
+    def op_prepare(self, state):
+        state.attn_intermediate_state = self.forward_prepare(
+            positions=state.positions,
+            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+            forward_batch=state.forward_batch,
+            zero_allocator=state.zero_allocator,
+        )
+
+    def op_core(self, state):
+        state.hidden_states_after_attn = self.forward_core(
+            state.pop("attn_intermediate_state")
+        )
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
+    ):
+        s = self.forward_prepare(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
+        )
+        return self.forward_core(s)
+
+    def forward_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
         if hidden_states.shape[0] == 0:
             assert (
                 not self.o_proj.reduce_results
             ), "short-circuiting allreduce will lead to hangs"
-            return hidden_states
+            return hidden_states, None, forward_batch, None
 
         attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
 
         if attn_forward_method == AttnForwardMethod.MHA:
-            return self.forward_normal(positions, hidden_states, forward_batch)
+            inner_state = self.forward_normal_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
-            return self.forward_normal_chunked_kv(
-                positions, hidden_states, forward_batch
+            inner_state = self.forward_normal_chunked_kv_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA:
-            return self.forward_absorb(
-                positions, hidden_states, forward_batch, zero_allocator
+            inner_state = self.forward_absorb_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
-            return self.forward_absorb_fused_mla_rope(
-                positions, hidden_states, forward_batch
+            inner_state = self.forward_absorb_fused_mla_rope_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
            )
         else:
             raise NotImplementedError
+        return None, attn_forward_method, forward_batch, inner_state
+
+    def forward_core(self, intermediate_state):
+        hidden_states, attn_forward_method, forward_batch, inner_state = (
+            intermediate_state
+        )
+        if inner_state is None:
+            return hidden_states
+
+        if attn_forward_method == AttnForwardMethod.MHA:
+            return self.forward_normal_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
+            return self.forward_normal_chunked_kv_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA:
+            return self.forward_absorb_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
+            return self.forward_absorb_fused_mla_rope_core(*inner_state)
+        else:
+            raise NotImplementedError
 
-    def forward_normal(
+    def forward_normal_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+        zero_allocator: BumpAllocator,
+    ):
         if self.q_lora_rank is not None:
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
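Note the contract above: forward_prepare returns a 4-tuple (hidden_states, attn_forward_method, forward_batch, inner_state). On an empty batch it short-circuits with inner_state=None and forward_core simply returns the hidden states; otherwise forward_core re-dispatches on attn_forward_method and unpacks inner_state positionally into the matching *_core method. The op_prepare / op_core wrappers hand this tuple through the per-forward state object. The State class below is a hypothetical stand-in for sglang's actual execution state, shown only to make the pop/set flow concrete.

# Hypothetical stand-in for the per-layer execution state used by op_prepare/op_core.
# Attributes are set normally; pop() reads a value once and releases it, which is how
# intermediate tensors are handed from one operation to the next.
class State:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def pop(self, name):
        return self.__dict__.pop(name)


# Usage mirroring the diff: the scheduler runs op_prepare, possibly other ops,
# then op_core on the same state object.
# state = State(positions=..., forward_batch=..., zero_allocator=...,
#               hidden_states_after_comm_pre_attn=...)
# layer.self_attn.op_prepare(state)   # sets state.attn_intermediate_state
# layer.self_attn.op_core(state)      # sets state.hidden_states_after_attn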
@@ -749,18 +799,22 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch.token_to_kv_pool.set_kv_buffer(
             self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
         )
 
+        return q, k, v, forward_batch
+
+    def forward_normal_core(self, q, k, v, forward_batch):
         attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
         attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
         output, _ = self.o_proj(attn_output)
         return output
 
-    def forward_absorb(
+    def forward_absorb_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
+    ):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
         if self.q_lora_rank is not None:
@@ -829,6 +883,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_nope_out.transpose(0, 1)
 
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
+        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+
+    def forward_absorb_core(
+        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+    ):
         if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
@@ -881,13 +940,13 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return output
 
-    def forward_absorb_fused_mla_rope(
+    def forward_absorb_fused_mla_rope_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
+    ):
         enable_rope_fusion = (
             os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
         )
@@ -976,6 +1035,44 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]
 
+        return (
+            q_input,
+            key_cache_buf,
+            val_cache_buf,
+            attn_output,
+            kv_indptr,
+            kv_indices,
+            k_pe_output,
+            cos_sin_cache,
+            positions,
+            attn_logits,
+            num_kv_split,
+            sm_scale,
+            enable_rope_fusion,
+            k_input,
+            forward_batch,
+            zero_allocator,
+        )
+
+    def forward_absorb_fused_mla_rope_core(
+        self,
+        q_input,
+        key_cache_buf,
+        val_cache_buf,
+        attn_output,
+        kv_indptr,
+        kv_indices,
+        k_pe_output,
+        cos_sin_cache,
+        positions,
+        attn_logits,
+        num_kv_split,
+        sm_scale,
+        enable_rope_fusion,
+        k_input,
+        forward_batch,
+        zero_allocator,
+    ):
         decode_attention_fwd_grouped_rope(
             q_input,
             key_cache_buf,
@@ -1082,12 +1179,13 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return accum_output
 
-    def forward_normal_chunked_kv(
+    def forward_normal_chunked_kv_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+        zero_allocator: BumpAllocator,
+    ):
         # In normal mha, the k and v tensors will become overly large when the prefix length is long.
         # To avoid this, we split the kv cache into chunks and process them one after another.
         # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
@@ -1130,6 +1228,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
         )
 
+        return q, k, v, forward_batch
+
+    def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
         # Do mha for extended part without prefix
         forward_batch.set_attn_attend_prefix_cache(False)
         attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
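In the chunked-KV path, attention over the extended tokens is computed first with save_kv_cache=False and the per-head log-sum-exp (lse) is kept alongside the output; partial results over prefix chunks can then be combined exactly with a log-sum-exp merge. The snippet below is an illustrative, self-contained sketch of that merge, not the code sglang uses.

# Illustrative log-sum-exp merge of attention outputs computed over disjoint
# key/value chunks. Each chunk yields an output o_i [tokens, heads, head_dim]
# and lse_i [tokens, heads]; the merge reproduces attention over all chunks.
import torch


def merge_chunked_attention(outputs, lses):
    lse_all = torch.stack(lses, dim=0)               # [chunks, tokens, heads]
    out_all = torch.stack(outputs, dim=0)            # [chunks, tokens, heads, dim]
    max_lse = lse_all.max(dim=0, keepdim=True).values
    weights = torch.exp(lse_all - max_lse)           # per-chunk softmax mass
    merged = (weights.unsqueeze(-1) * out_all).sum(dim=0)
    return merged / weights.sum(dim=0).unsqueeze(-1)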
@@ -1283,14 +1384,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             )
         )
 
-    def op_attn(self, state):
-        state.hidden_states_after_attn = self.self_attn(
-            positions=state.positions,
-            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
-            forward_batch=state.forward_batch,
-            zero_allocator=state.zero_allocator,
-        )
-
     def op_comm_prepare_mlp(self, state):
         state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
             self.layer_communicator.prepare_mlp(
python/sglang/srt/operations_strategy.py
@@ -7,7 +7,8 @@ def compute_layer_operations(
     if not layer.is_layer_sparse:
         return [
             layer.op_comm_prepare_attn,
-            layer.op_attn,
+            layer.self_attn.op_prepare,
+            layer.self_attn.op_core,
             layer.op_comm_prepare_mlp,
             layer.op_mlp,
             layer.op_comm_postprocess_layer,
@@ -16,7 +17,8 @@ def compute_layer_operations(
     # Will add TBO operation orders here
     return [
         layer.op_comm_prepare_attn,
-        layer.op_attn,
+        layer.self_attn.op_prepare,
+        layer.self_attn.op_core,
         layer.op_comm_prepare_mlp,
         layer.mlp.op_gate,
         layer.mlp.op_shared_experts,
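With this change a dense layer's plan goes from op_comm_prepare_attn, op_attn, op_comm_prepare_mlp, ... to op_comm_prepare_attn, self_attn.op_prepare, self_attn.op_core, op_comm_prepare_mlp, ..., so the attention projections and the attention kernel become separately schedulable stages. A minimal sketch of running such a plan is below; the run_operations helper is hypothetical, shown only to illustrate how each op reads and writes the shared state in order.

# Hypothetical runner for a layer's operation list: each op is a callable that
# reads from and writes to the same mutable state object, in plan order.
def run_operations(operations, state):
    for op in operations:
        op(state)
    return state


# e.g. for a non-sparse layer after this commit:
# run_operations(
#     [
#         layer.op_comm_prepare_attn,
#         layer.self_attn.op_prepare,
#         layer.self_attn.op_core,
#         layer.op_comm_prepare_mlp,
#         layer.op_mlp,
#         layer.op_comm_postprocess_layer,
#     ],
#     state,
# )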