Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
0bfe1d14
Unverified
Commit
0bfe1d14
authored
Oct 31, 2025
by
Xinyuan Tong
Committed by
GitHub
Oct 31, 2025
Browse files
fa3 & trtllm_mha spec overlap (#11874)
parent
5f98b7fe
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
21 deletions
+16
-21
python/sglang/srt/layers/attention/flashattention_backend.py
python/sglang/srt/layers/attention/flashattention_backend.py
+7
-6
python/sglang/srt/layers/attention/trtllm_mha_backend.py
python/sglang/srt/layers/attention/trtllm_mha_backend.py
+3
-4
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+6
-11
No files found.
python/sglang/srt/layers/attention/flashattention_backend.py
View file @
0bfe1d14
...
@@ -584,7 +584,9 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -584,7 +584,9 @@ class FlashAttentionBackend(AttentionBackend):
metadata
,
metadata_expand
metadata
,
metadata_expand
)
)
elif
forward_batch
.
forward_mode
.
is_extend_or_draft_extend_or_mixed
():
elif
forward_batch
.
forward_mode
.
is_extend_or_draft_extend_or_mixed
(
include_draft_extend_v2
=
True
):
metadata
.
cache_seqlens_int32
=
seqlens_in_batch
.
to
(
torch
.
int32
)
metadata
.
cache_seqlens_int32
=
seqlens_in_batch
.
to
(
torch
.
int32
)
metadata
.
max_seq_len_k
=
forward_batch
.
seq_lens_cpu
.
max
().
item
()
metadata
.
max_seq_len_k
=
forward_batch
.
seq_lens_cpu
.
max
().
item
()
metadata
.
cu_seqlens_k
=
torch
.
nn
.
functional
.
pad
(
metadata
.
cu_seqlens_k
=
torch
.
nn
.
functional
.
pad
(
...
@@ -594,10 +596,9 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -594,10 +596,9 @@ class FlashAttentionBackend(AttentionBackend):
forward_batch
.
req_pool_indices
,
:
metadata
.
max_seq_len_k
forward_batch
.
req_pool_indices
,
:
metadata
.
max_seq_len_k
]
]
if
(
if
any
(
any
(
forward_batch
.
extend_prefix_lens_cpu
)
forward_batch
.
extend_prefix_lens_cpu
or
forward_batch
.
forward_mode
==
ForwardMode
.
DRAFT_EXTEND
)
or
forward_batch
.
forward_mode
.
is_draft_extend
(
include_v2
=
True
):
):
extend_seq_lens
=
forward_batch
.
extend_seq_lens
extend_seq_lens
=
forward_batch
.
extend_seq_lens
metadata
.
max_seq_len_q
=
max
(
forward_batch
.
extend_seq_lens_cpu
)
metadata
.
max_seq_len_q
=
max
(
forward_batch
.
extend_seq_lens_cpu
)
metadata
.
cu_seqlens_q
=
torch
.
nn
.
functional
.
pad
(
metadata
.
cu_seqlens_q
=
torch
.
nn
.
functional
.
pad
(
...
@@ -826,7 +827,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -826,7 +827,7 @@ class FlashAttentionBackend(AttentionBackend):
if
(
if
(
forward_batch
.
attn_attend_prefix_cache
is
not
None
forward_batch
.
attn_attend_prefix_cache
is
not
None
and
not
forward_batch
.
forward_mode
.
is_target_verify
()
and
not
forward_batch
.
forward_mode
.
is_target_verify
()
and
not
forward_batch
.
forward_mode
.
is_draft_extend
()
and
not
forward_batch
.
forward_mode
.
is_draft_extend
(
include_v2
=
True
)
):
):
# Do multi-head attention with chunked prefix cache
# Do multi-head attention with chunked prefix cache
if
forward_batch
.
attn_attend_prefix_cache
:
if
forward_batch
.
attn_attend_prefix_cache
:
...
...
python/sglang/srt/layers/attention/trtllm_mha_backend.py
View file @
0bfe1d14
...
@@ -488,10 +488,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
...
@@ -488,10 +488,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
forward_batch
.
req_pool_indices
,
:
metadata
.
max_seq_len_k
forward_batch
.
req_pool_indices
,
:
metadata
.
max_seq_len_k
]
]
if
(
if
any
(
any
(
forward_batch
.
extend_prefix_lens_cpu
)
forward_batch
.
extend_prefix_lens_cpu
or
forward_batch
.
forward_mode
==
ForwardMode
.
DRAFT_EXTEND
)
or
forward_batch
.
forward_mode
.
is_draft_extend
(
include_v2
=
True
):
):
extend_seq_lens
=
forward_batch
.
extend_seq_lens
extend_seq_lens
=
forward_batch
.
extend_seq_lens
metadata
.
max_seq_len_q
=
max
(
forward_batch
.
extend_seq_lens_cpu
)
metadata
.
max_seq_len_q
=
max
(
forward_batch
.
extend_seq_lens_cpu
)
metadata
.
cu_seqlens_q
=
torch
.
nn
.
functional
.
pad
(
metadata
.
cu_seqlens_q
=
torch
.
nn
.
functional
.
pad
(
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
0bfe1d14
...
@@ -90,11 +90,7 @@ class ForwardMode(IntEnum):
...
@@ -90,11 +90,7 @@ class ForwardMode(IntEnum):
self
==
ForwardMode
.
EXTEND
self
==
ForwardMode
.
EXTEND
or
self
==
ForwardMode
.
MIXED
or
self
==
ForwardMode
.
MIXED
or
self
==
ForwardMode
.
DRAFT_EXTEND
or
self
==
ForwardMode
.
DRAFT_EXTEND
or
(
or
(
include_draft_extend_v2
and
self
==
ForwardMode
.
DRAFT_EXTEND_V2
)
self
==
ForwardMode
.
DRAFT_EXTEND_V2
if
include_draft_extend_v2
else
False
)
or
self
==
ForwardMode
.
TARGET_VERIFY
or
self
==
ForwardMode
.
TARGET_VERIFY
or
self
==
ForwardMode
.
SPLIT_PREFILL
or
self
==
ForwardMode
.
SPLIT_PREFILL
)
)
...
@@ -115,22 +111,21 @@ class ForwardMode(IntEnum):
...
@@ -115,22 +111,21 @@ class ForwardMode(IntEnum):
return
self
==
ForwardMode
.
TARGET_VERIFY
return
self
==
ForwardMode
.
TARGET_VERIFY
def
is_draft_extend
(
self
,
include_v2
:
bool
=
False
):
def
is_draft_extend
(
self
,
include_v2
:
bool
=
False
):
if
include_v2
:
return
self
==
ForwardMode
.
DRAFT_EXTEND
or
(
return
(
include_v2
and
self
==
ForwardMode
.
DRAFT_EXTEND_V2
self
==
ForwardMode
.
DRAFT_EXTEND_V2
or
self
==
ForwardMode
.
DRAFT_EXTEND
)
)
return
self
==
ForwardMode
.
DRAFT_EXTEND
def
is_draft_extend_v2
(
self
):
def
is_draft_extend_v2
(
self
):
# For fixed shape logits output in v2 eagle worker
# For fixed shape logits output in v2 eagle worker
return
self
==
ForwardMode
.
DRAFT_EXTEND_V2
return
self
==
ForwardMode
.
DRAFT_EXTEND_V2
def
is_extend_or_draft_extend_or_mixed
(
self
):
def
is_extend_or_draft_extend_or_mixed
(
self
,
include_draft_extend_v2
:
bool
=
False
):
return
(
return
(
self
==
ForwardMode
.
EXTEND
self
==
ForwardMode
.
EXTEND
or
self
==
ForwardMode
.
DRAFT_EXTEND
or
self
==
ForwardMode
.
DRAFT_EXTEND
or
self
==
ForwardMode
.
MIXED
or
self
==
ForwardMode
.
MIXED
or
self
==
ForwardMode
.
SPLIT_PREFILL
or
self
==
ForwardMode
.
SPLIT_PREFILL
or
(
include_draft_extend_v2
and
self
==
ForwardMode
.
DRAFT_EXTEND_V2
)
)
)
def
is_cuda_graph
(
self
):
def
is_cuda_graph
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment