sglang — Commit d6e1d28c (unverified)

Refactor DeepSeek attention dispatching (#6476)

Authored May 21, 2025 by fzyzcjy; committed by GitHub on May 21, 2025.
parent 7c347259

Showing 1 changed file with 27 additions and 19 deletions:
python/sglang/srt/models/deepseek_v2.py (+27, -19)
python/sglang/srt/models/deepseek_v2.py

@@ -127,6 +127,9 @@ class AttnForwardMethod(IntEnum):
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
 
+    # Use MLA but with fused RoPE
+    MLA_FUSED_ROPE = auto()
+
 
 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -609,6 +612,18 @@ class DeepseekV2AttentionMLA(nn.Module):
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
+        def _dispatch_mla_subtype():
+            if _is_hip:
+                if (
+                    self.rocm_fused_decode_mla
+                    and forward_batch.forward_mode.is_decode()
+                ):
+                    return AttnForwardMethod.MLA_FUSED_ROPE
+                else:
+                    return AttnForwardMethod.MLA
+            else:
+                return AttnForwardMethod.MLA
+
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             if (
@@ -620,7 +635,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()
         elif self.attention_backend == "fa3":
             # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
             if forward_batch.extend_prefix_lens_cpu is not None:
@@ -637,7 +652,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             if (
@@ -648,7 +663,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()
 
     def forward(
         self,
@@ -671,23 +686,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_chunked_kv(
                 positions, hidden_states, forward_batch
             )
+        elif attn_forward_method == AttnForwardMethod.MLA:
+            return self.forward_absorb(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
+            return self.forward_absorb_fused_mla_rope(
+                positions, hidden_states, forward_batch
+            )
         else:
-            if _is_hip:
-                if (
-                    self.rocm_fused_decode_mla
-                    and forward_batch.forward_mode.is_decode()
-                ):
-                    return self.forward_absorb_fused_mla_rope(
-                        positions, hidden_states, forward_batch
-                    )
-                else:
-                    return self.forward_absorb(
-                        positions, hidden_states, forward_batch, zero_allocator
-                    )
-            else:
-                return self.forward_absorb(
-                    positions, hidden_states, forward_batch, zero_allocator
-                )
+            raise NotImplementedError
 
     def forward_normal(
         self,
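For readers skimming the diff, the sketch below condenses what the patch does: every backend branch of dispatch_attn_forward_method now ends in a shared _dispatch_mla_subtype() helper that returns MLA_FUSED_ROPE on ROCm decode (when rocm_fused_decode_mla is set) and plain MLA otherwise, so forward() shrinks to a flat enum-to-handler mapping. This is a simplified, self-contained illustration, not the real module: AttentionSketch, wants_mha, is_decode, and the string return values are stand-ins for DeepseekV2AttentionMLA, the backend-specific MHA heuristics, ForwardBatch.forward_mode, and the actual forward_* methods.

from enum import IntEnum, auto


class AttnForwardMethod(IntEnum):
    MHA = auto()
    MLA = auto()
    MLA_FUSED_ROPE = auto()  # new in this commit: MLA with fused RoPE


class AttentionSketch:
    """Stand-in for DeepseekV2AttentionMLA; only the dispatch shape is kept."""

    def __init__(self, is_hip: bool, rocm_fused_decode_mla: bool):
        self.is_hip = is_hip
        self.rocm_fused_decode_mla = rocm_fused_decode_mla

    def dispatch_attn_forward_method(
        self, is_decode: bool, wants_mha: bool
    ) -> AttnForwardMethod:
        # Before the patch every backend branch ended in "return MLA" and
        # forward() re-checked the ROCm flags; now the MLA-vs-fused-RoPE
        # decision is made once, inside the dispatcher.
        def _dispatch_mla_subtype() -> AttnForwardMethod:
            if self.is_hip and self.rocm_fused_decode_mla and is_decode:
                return AttnForwardMethod.MLA_FUSED_ROPE
            return AttnForwardMethod.MLA

        # Backend-specific MHA heuristics are collapsed into `wants_mha` here.
        if wants_mha:
            return AttnForwardMethod.MHA
        return _dispatch_mla_subtype()

    def forward(self, is_decode: bool, wants_mha: bool) -> str:
        # forward() becomes a flat enum -> handler mapping.
        method = self.dispatch_attn_forward_method(is_decode, wants_mha)
        if method == AttnForwardMethod.MHA:
            return "forward_normal"
        elif method == AttnForwardMethod.MLA:
            return "forward_absorb"
        elif method == AttnForwardMethod.MLA_FUSED_ROPE:
            return "forward_absorb_fused_mla_rope"
        else:
            raise NotImplementedError


if __name__ == "__main__":
    attn = AttentionSketch(is_hip=True, rocm_fused_decode_mla=True)
    # ROCm decode with the fused kernel enabled resolves to the fused-RoPE path...
    assert attn.forward(is_decode=True, wants_mha=False) == "forward_absorb_fused_mla_rope"
    # ...while prefill (or non-HIP builds) still resolves to plain weight-absorbed MLA.
    assert attn.forward(is_decode=False, wants_mha=False) == "forward_absorb"

The net effect, visible in the hunk at @@ -671,23 +686,16 @@, is that the HIP-specific branching appears exactly once, in the dispatcher, instead of being duplicated inside forward().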