sglang · Commits

Commit a38376fa (unverified)
Refactor attention into multiple stages (#6477)
Authored May 25, 2025 by fzyzcjy; committed by GitHub May 24, 2025.
Parent: 7a5e6ce1
Showing 2 changed files with 121 additions and 26 deletions (+121 −26):

  python/sglang/srt/models/deepseek_v2.py     +117 −24
  python/sglang/srt/operations_strategy.py    +4 −2
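The commit splits DeepseekV2AttentionMLA.forward into a forward_prepare / forward_core pair and exposes the two halves to the layer's operation list as self_attn.op_prepare and self_attn.op_core (see the diffs below). As rough orientation only, here is a minimal sketch of how such an operation list might be driven over a shared per-layer state object; the ExecutorState class and run_operations helper are hypothetical illustrations, not code from this commit.

# Minimal sketch (not from this commit): each op reads and writes named fields
# on a shared state, so a scheduler is free to run the ops in whatever order
# (or interleaving) it wants.
from types import SimpleNamespace


class ExecutorState(SimpleNamespace):
    """Hypothetical state object: attribute access plus dict-style pop()."""

    def pop(self, name):
        value = getattr(self, name)
        delattr(self, name)
        return value


def run_operations(operations, state):
    # Run the per-layer operations sequentially; a real scheduler may reorder
    # or interleave them across micro-batches.
    for op in operations:
        op(state)
    return state


# Usage outline (names follow the diff; the objects are placeholders):
# ops = compute_layer_operations(layer)   # e.g. [layer.op_comm_prepare_attn,
#                                          #       layer.self_attn.op_prepare,
#                                          #       layer.self_attn.op_core, ...]
# state = ExecutorState(positions=positions, forward_batch=forward_batch,
#                       zero_allocator=zero_allocator,
#                       hidden_states_after_comm_pre_attn=hidden_states)
# run_operations(ops, state)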
python/sglang/srt/models/deepseek_v2.py
@@ -677,44 +677,94 @@ class DeepseekV2AttentionMLA(nn.Module):
         else:
             return _dispatch_mla_subtype()

+    def op_prepare(self, state):
+        state.attn_intermediate_state = self.forward_prepare(
+            positions=state.positions,
+            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+            forward_batch=state.forward_batch,
+            zero_allocator=state.zero_allocator,
+        )
+
+    def op_core(self, state):
+        state.hidden_states_after_attn = self.forward_core(
+            state.pop("attn_intermediate_state")
+        )
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
+    ):
+        s = self.forward_prepare(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
+        )
+        return self.forward_core(s)
+
+    def forward_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
         if hidden_states.shape[0] == 0:
             assert (
                 not self.o_proj.reduce_results
             ), "short-circuiting allreduce will lead to hangs"
-            return hidden_states
+            return hidden_states, None, forward_batch, None

         attn_forward_method = self.dispatch_attn_forward_method(forward_batch)

         if attn_forward_method == AttnForwardMethod.MHA:
-            return self.forward_normal(positions, hidden_states, forward_batch)
+            inner_state = self.forward_normal_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
-            return self.forward_normal_chunked_kv(
-                positions, hidden_states, forward_batch
+            inner_state = self.forward_normal_chunked_kv_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA:
-            return self.forward_absorb(
+            inner_state = self.forward_absorb_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
-            return self.forward_absorb_fused_mla_rope(
-                positions, hidden_states, forward_batch
+            inner_state = self.forward_absorb_fused_mla_rope_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
             )
         else:
             raise NotImplementedError
+        return None, attn_forward_method, forward_batch, inner_state

-    def forward_normal(
+    def forward_core(self, intermediate_state):
+        hidden_states, attn_forward_method, forward_batch, inner_state = (
+            intermediate_state
+        )
+        if inner_state is None:
+            return hidden_states
+
+        if attn_forward_method == AttnForwardMethod.MHA:
+            return self.forward_normal_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
+            return self.forward_normal_chunked_kv_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA:
+            return self.forward_absorb_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
+            return self.forward_absorb_fused_mla_rope_core(*inner_state)
+        else:
+            raise NotImplementedError
+
+    def forward_normal_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+        zero_allocator: BumpAllocator,
+    ):
         if self.q_lora_rank is not None:
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
@@ -749,18 +799,22 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch.token_to_kv_pool.set_kv_buffer(
             self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
         )

+        return q, k, v, forward_batch
+
+    def forward_normal_core(self, q, k, v, forward_batch):
         attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
         attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
         output, _ = self.o_proj(attn_output)
         return output

-    def forward_absorb(
+    def forward_absorb_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
+    ):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

         if self.q_lora_rank is not None:
@@ -829,6 +883,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_nope_out.transpose(0, 1)

         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

+        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+
+    def forward_absorb_core(
+        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+    ):
         if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
@@ -881,13 +940,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output

-    def forward_absorb_fused_mla_rope(
+    def forward_absorb_fused_mla_rope_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
+    ):
         enable_rope_fusion = (
             os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
         )
@@ -976,6 +1035,44 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

+        return (
+            q_input,
+            key_cache_buf,
+            val_cache_buf,
+            attn_output,
+            kv_indptr,
+            kv_indices,
+            k_pe_output,
+            cos_sin_cache,
+            positions,
+            attn_logits,
+            num_kv_split,
+            sm_scale,
+            enable_rope_fusion,
+            k_input,
+            forward_batch,
+            zero_allocator,
+        )
+
+    def forward_absorb_fused_mla_rope_core(
+        self,
+        q_input,
+        key_cache_buf,
+        val_cache_buf,
+        attn_output,
+        kv_indptr,
+        kv_indices,
+        k_pe_output,
+        cos_sin_cache,
+        positions,
+        attn_logits,
+        num_kv_split,
+        sm_scale,
+        enable_rope_fusion,
+        k_input,
+        forward_batch,
+        zero_allocator,
+    ):
         decode_attention_fwd_grouped_rope(
             q_input,
             key_cache_buf,
@@ -1082,12 +1179,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         return accum_output

-    def forward_normal_chunked_kv(
+    def forward_normal_chunked_kv_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+        zero_allocator: BumpAllocator,
+    ):
         # In normal mha, the k and v tensors will become overly large when the prefix length is long.
         # To avoid this, we split the kv cache into chunks and process them one after another.
         # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
@@ -1130,6 +1228,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
         )

+        return q, k, v, forward_batch
+
+    def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
         # Do mha for extended part without prefix
         forward_batch.set_attn_attend_prefix_cache(False)
         attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
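forward_normal_chunked_kv_core still follows the chunked-prefix scheme described in the comment above: attention is run over the extend tokens and over each prefix chunk separately, and the partial (output, lse) pairs are combined afterwards. For orientation only, a generic log-sum-exp merge of two partial attention results looks roughly like the sketch below; this is an illustration of the idea, not the accumulation code actually used in this file.

# Minimal sketch (assumption: out_* has shape [num_tokens, num_heads, head_dim]
# and lse_* has shape [num_tokens, num_heads], with lse the natural-log
# sum of exponentiated attention logits within each chunk).
import torch


def merge_attention_chunks(out_a, lse_a, out_b, lse_b):
    max_lse = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - max_lse)          # softmax mass contributed by chunk A
    w_b = torch.exp(lse_b - max_lse)          # softmax mass contributed by chunk B
    denom = w_a + w_b
    merged_out = (
        out_a * w_a.unsqueeze(-1) + out_b * w_b.unsqueeze(-1)
    ) / denom.unsqueeze(-1)
    merged_lse = max_lse + torch.log(denom)   # lse over the union of both chunks
    return merged_out, merged_lse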
@@ -1283,14 +1384,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             )
         )

-    def op_attn(self, state):
-        state.hidden_states_after_attn = self.self_attn(
-            positions=state.positions,
-            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
-            forward_batch=state.forward_batch,
-            zero_allocator=state.zero_allocator,
-        )
-
     def op_comm_prepare_mlp(self, state):
         state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
             self.layer_communicator.prepare_mlp(
python/sglang/srt/operations_strategy.py
@@ -7,7 +7,8 @@ def compute_layer_operations(
     if not layer.is_layer_sparse:
         return [
             layer.op_comm_prepare_attn,
-            layer.op_attn,
+            layer.self_attn.op_prepare,
+            layer.self_attn.op_core,
             layer.op_comm_prepare_mlp,
             layer.op_mlp,
             layer.op_comm_postprocess_layer,
@@ -16,7 +17,8 @@ def compute_layer_operations(
     # Will add TBO operation orders here
     return [
         layer.op_comm_prepare_attn,
-        layer.op_attn,
+        layer.self_attn.op_prepare,
+        layer.self_attn.op_core,
         layer.op_comm_prepare_mlp,
         layer.mlp.op_gate,
         layer.mlp.op_shared_experts,
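The finer-grained op_prepare / op_core split is what makes stage-level interleaving possible: a scheduler can alternate operations from two micro-batches so that, for example, one batch's communication step overlaps with the other batch's attention compute. The "# Will add TBO operation orders here" comment hints at this. The round-robin interleaver below is only a hypothetical illustration of the idea, not the scheduling sglang actually uses.

# Hypothetical illustration (not sglang code): run the per-layer operations of
# two micro-batches in alternation so their stages can overlap.
def interleave_two_batches(operations, state_a, state_b):
    for op in operations:
        # In a real two-batch-overlap schedule these two calls would be issued
        # so that one batch's communication hides behind the other's compute;
        # here they simply alternate to show the stage boundaries.
        op(state_a)
        op(state_b)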