Unverified commit a38376fa, authored May 25, 2025 by fzyzcjy, committed by GitHub May 24, 2025

Refactor attention into multiple stages (#6477)
Parent: 7a5e6ce1

Showing 2 changed files with 121 additions and 26 deletions
python/sglang/srt/models/deepseek_v2.py (+117, -24)
python/sglang/srt/operations_strategy.py (+4, -2)
python/sglang/srt/models/deepseek_v2.py @ a38376fa
@@ -677,44 +677,94 @@ class DeepseekV2AttentionMLA(nn.Module):
         else:
             return _dispatch_mla_subtype()
 
+    def op_prepare(self, state):
+        state.attn_intermediate_state = self.forward_prepare(
+            positions=state.positions,
+            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+            forward_batch=state.forward_batch,
+            zero_allocator=state.zero_allocator,
+        )
+
+    def op_core(self, state):
+        state.hidden_states_after_attn = self.forward_core(
+            state.pop("attn_intermediate_state")
+        )
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
+    ):
+        s = self.forward_prepare(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
+        )
+        return self.forward_core(s)
+
+    def forward_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
         if hidden_states.shape[0] == 0:
             assert (
                 not self.o_proj.reduce_results
             ), "short-circuiting allreduce will lead to hangs"
-            return hidden_states
+            return hidden_states, None, forward_batch, None
 
         attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
 
         if attn_forward_method == AttnForwardMethod.MHA:
-            return self.forward_normal(positions, hidden_states, forward_batch)
+            inner_state = self.forward_normal_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
-            return self.forward_normal_chunked_kv(
-                positions, hidden_states, forward_batch
+            inner_state = self.forward_normal_chunked_kv_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA:
-            return self.forward_absorb(
+            inner_state = self.forward_absorb_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
-            return self.forward_absorb_fused_mla_rope(
-                positions, hidden_states, forward_batch
+            inner_state = self.forward_absorb_fused_mla_rope_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
             )
         else:
             raise NotImplementedError
+        return None, attn_forward_method, forward_batch, inner_state
+
+    def forward_core(self, intermediate_state):
+        hidden_states, attn_forward_method, forward_batch, inner_state = (
+            intermediate_state
+        )
+        if inner_state is None:
+            return hidden_states
+
+        if attn_forward_method == AttnForwardMethod.MHA:
+            return self.forward_normal_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
+            return self.forward_normal_chunked_kv_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA:
+            return self.forward_absorb_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
+            return self.forward_absorb_fused_mla_rope_core(*inner_state)
+        else:
+            raise NotImplementedError
 
-    def forward_normal(
+    def forward_normal_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+        zero_allocator: BumpAllocator,
+    ):
         if self.q_lora_rank is not None:
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
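The hunk above is the heart of the refactor: forward() becomes a thin wrapper that chains forward_prepare() and forward_core(), and the op_prepare/op_core pair exposes the same two stages to the per-layer operation list. The following is a minimal, self-contained sketch of that staging pattern with a toy module; the class, method bodies, and tensor shapes are illustrative only and not part of sglang.

```python
# Minimal sketch (not sglang code): a forward pass split into a "prepare"
# stage that builds an intermediate state and a "core" stage that consumes it.
# Keeping forward() == core(prepare()) preserves the original single-call path
# while letting a scheduler run other work between the two stages.
import torch


class ToyAttention(torch.nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward_prepare(self, hidden_states: torch.Tensor):
        # Stage 1: everything up to (but not including) the heavy kernel.
        if hidden_states.shape[0] == 0:
            return hidden_states, None  # short-circuit, mirroring the diff
        q = self.proj(hidden_states)
        return None, (q, hidden_states)

    def forward_core(self, intermediate_state):
        # Stage 2: the heavy part, fed only by the intermediate state.
        early_out, inner = intermediate_state
        if inner is None:
            return early_out
        q, residual = inner
        return q + residual

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The original one-shot path is recovered by chaining the stages.
        return self.forward_core(self.forward_prepare(hidden_states))


if __name__ == "__main__":
    m = ToyAttention()
    x = torch.randn(4, 8)
    assert torch.allclose(m(x), m.forward_core(m.forward_prepare(x)))
```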
@@ -749,18 +799,22 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch.token_to_kv_pool.set_kv_buffer(
             self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
         )
+
+        return q, k, v, forward_batch
+
+    def forward_normal_core(self, q, k, v, forward_batch):
         attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
         attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
         output, _ = self.o_proj(attn_output)
         return output
 
-    def forward_absorb(
+    def forward_absorb_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
+    ):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
         if self.q_lora_rank is not None:
@@ -829,6 +883,11 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope_out.transpose(0, 1)
 
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
+        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+
+    def forward_absorb_core(
+        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+    ):
         if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
@@ -881,13 +940,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output
 
-    def forward_absorb_fused_mla_rope(
+    def forward_absorb_fused_mla_rope_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
+    ):
         enable_rope_fusion = (
             os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
         )
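As this hunk shows, rope fusion on this path is gated by the SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION environment variable and defaults to on ("1"). A minimal usage sketch, assuming you want to disable it for debugging, would set the variable before sglang reads it:

```python
# Illustrative only: turn the fused-rope path off by overriding the env var
# before the model code calls os.getenv (the default is "1" per the diff above).
import os

os.environ["SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION"] = "0"
```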
@@ -976,6 +1035,44 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]
 
+        return (
+            q_input,
+            key_cache_buf,
+            val_cache_buf,
+            attn_output,
+            kv_indptr,
+            kv_indices,
+            k_pe_output,
+            cos_sin_cache,
+            positions,
+            attn_logits,
+            num_kv_split,
+            sm_scale,
+            enable_rope_fusion,
+            k_input,
+            forward_batch,
+            zero_allocator,
+        )
+
+    def forward_absorb_fused_mla_rope_core(
+        self,
+        q_input,
+        key_cache_buf,
+        val_cache_buf,
+        attn_output,
+        kv_indptr,
+        kv_indices,
+        k_pe_output,
+        cos_sin_cache,
+        positions,
+        attn_logits,
+        num_kv_split,
+        sm_scale,
+        enable_rope_fusion,
+        k_input,
+        forward_batch,
+        zero_allocator,
+    ):
         decode_attention_fwd_grouped_rope(
             q_input,
             key_cache_buf,
@@ -1082,12 +1179,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         return accum_output
 
-    def forward_normal_chunked_kv(
+    def forward_normal_chunked_kv_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+        zero_allocator: BumpAllocator,
+    ):
         # In normal mha, the k and v tensors will become overly large when the prefix length is long.
         # To avoid this, we split the kv cache into chunks and process them one after another.
         # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
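The comments kept in this hunk explain the chunked-KV path: with long prefixes, the full k and v tensors for normal MHA become too large, so the prefix is processed chunk by chunk and the partial results are merged (the chunked-KV core in the next hunk unpacks an lse alongside the attention output for exactly this purpose). The sketch below shows the standard log-sum-exp merge that makes such chunking exact; it is a generic illustration, not the specific kernel sglang invokes.

```python
# Generic sketch of merging two partial attention results exactly, using the
# per-query log-sum-exp (lse) each partial softmax produces. Chunked-KV
# attention can accumulate chunks pairwise with this rule.
import torch


def merge_attention_states(o1, lse1, o2, lse2):
    # o*: [num_tokens, num_heads, head_dim], lse*: [num_tokens, num_heads]
    lse_max = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - lse_max)
    w2 = torch.exp(lse2 - lse_max)
    denom = w1 + w2
    o = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / denom.unsqueeze(-1)
    lse = lse_max + torch.log(denom)
    return o, lse


if __name__ == "__main__":
    # Check against single-pass attention over the concatenated keys/values.
    torch.manual_seed(0)
    q = torch.randn(2, 3, 4)                      # [tokens, heads, dim]
    k1, v1 = torch.randn(5, 3, 4), torch.randn(5, 3, 4)
    k2, v2 = torch.randn(7, 3, 4), torch.randn(7, 3, 4)

    def attn(q, k, v):
        s = torch.einsum("thd,khd->thk", q, k)    # scores: [tokens, heads, kv]
        lse = torch.logsumexp(s, dim=-1)
        p = torch.softmax(s, dim=-1)
        return torch.einsum("thk,khd->thd", p, v), lse

    o1, lse1 = attn(q, k1, v1)
    o2, lse2 = attn(q, k2, v2)
    o_merged, _ = merge_attention_states(o1, lse1, o2, lse2)
    o_full, _ = attn(q, torch.cat([k1, k2]), torch.cat([v1, v2]))
    assert torch.allclose(o_merged, o_full, atol=1e-5)
```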
@@ -1130,6 +1228,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
         )
 
+        return q, k, v, forward_batch
+
+    def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
         # Do mha for extended part without prefix
         forward_batch.set_attn_attend_prefix_cache(False)
         attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
@@ -1283,14 +1384,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             )
         )
 
-    def op_attn(self, state):
-        state.hidden_states_after_attn = self.self_attn(
-            positions=state.positions,
-            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
-            forward_batch=state.forward_batch,
-            zero_allocator=state.zero_allocator,
-        )
-
     def op_comm_prepare_mlp(self, state):
         state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
             self.layer_communicator.prepare_mlp(
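With the strategy list (next file) now calling self_attn.op_prepare and self_attn.op_core directly, the decoder layer's op_attn wrapper above becomes dead code and is deleted. All op_* methods communicate through a shared per-forward state object with attribute access and pop() semantics; the sketch below is a hedged stand-in for that container (the _State class and the toy op bodies are assumptions for illustration, not sglang's implementation).

```python
# Hedged sketch: a dict-backed state object with attribute access and pop(),
# the shape of interface the op_* methods above rely on.
class _State(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        self[name] = value


def op_prepare(state):
    # Mirrors the flow of DeepseekV2AttentionMLA.op_prepare: consume the
    # pre-attention hidden states and stash an intermediate result for op_core.
    state.attn_intermediate_state = (
        "prepared",
        state.pop("hidden_states_after_comm_pre_attn"),
    )


def op_core(state):
    state.hidden_states_after_attn = state.pop("attn_intermediate_state")[1]


if __name__ == "__main__":
    s = _State(hidden_states_after_comm_pre_attn=[1.0, 2.0])
    op_prepare(s)
    op_core(s)
    print(s.hidden_states_after_attn)  # [1.0, 2.0]
```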
python/sglang/srt/operations_strategy.py @ a38376fa
@@ -7,7 +7,8 @@ def compute_layer_operations(
     if not layer.is_layer_sparse:
         return [
             layer.op_comm_prepare_attn,
-            layer.op_attn,
+            layer.self_attn.op_prepare,
+            layer.self_attn.op_core,
             layer.op_comm_prepare_mlp,
             layer.op_mlp,
             layer.op_comm_postprocess_layer,
@@ -16,7 +17,8 @@ def compute_layer_operations(
     # Will add TBO operation orders here
     return [
         layer.op_comm_prepare_attn,
-        layer.op_attn,
+        layer.self_attn.op_prepare,
+        layer.self_attn.op_core,
         layer.op_comm_prepare_mlp,
         layer.mlp.op_gate,
         layer.mlp.op_shared_experts,
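compute_layer_operations now emits a finer-grained operation list in both branches, with attention split into self_attn.op_prepare and self_attn.op_core. The payoff is scheduling freedom: whatever drives the list can run other work, such as a second micro-batch (the TBO ordering the comment above alludes to), between the two attention stages. Below is a hedged sketch of such a driver; run_layer, run_two_microbatches, and the toy ops are hypothetical names, not sglang's scheduler.

```python
# Illustrative only: a tiny driver for an operation list shaped like the one
# compute_layer_operations returns.
from typing import Callable, Dict, List

Op = Callable[[Dict], None]


def run_layer(ops: List[Op], state: Dict) -> Dict:
    # Baseline: run the stages of one micro-batch back to back.
    for op in ops:
        op(state)
    return state


def run_two_microbatches(ops: List[Op], state_a: Dict, state_b: Dict):
    # Finer-grained stages make this interleaving possible: while micro-batch A
    # sits between op_prepare and op_core, micro-batch B can make progress, so
    # communication of one batch can overlap compute of the other.
    for op in ops:
        op(state_a)
        op(state_b)
    return state_a, state_b


if __name__ == "__main__":
    def toy_prepare(state):
        state["intermediate"] = state.pop("input") * 2

    def toy_core(state):
        state["output"] = state.pop("intermediate") + 1

    ops = [toy_prepare, toy_core]
    print(run_layer(ops, {"input": 3}))                              # {'output': 7}
    print(run_two_microbatches(ops, {"input": 1}, {"input": 2}))     # ({'output': 3}, {'output': 5})
```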