Commit 27815038 (change/sglang)

mtp

Authored Nov 18, 2025 by niuhb
Parent: de61a992

Showing 5 changed files with 68 additions and 8 deletions (+68 -8):

    python/sglang/srt/layers/attention/dcu_mla_backend.py          +53 -2
    python/sglang/srt/layers/attention/flashattention_backend.py    +3 -2
    python/sglang/srt/managers/scheduler.py                          +1 -1
    python/sglang/srt/model_executor/forward_batch_info.py           +1 -0
    python/sglang/srt/speculative/draft_utils.py                    +10 -3
python/sglang/srt/layers/attention/dcu_mla_backend.py
```diff
@@ -151,8 +151,59 @@ class DCUMLABackend(AttentionBackend):
             )
         else:
             if not self.skip_prefill:
                 self.flashattn_backend.init_forward_metadata(forward_batch)
+            # === DRAFT_EXTEND_V2 MLA metadata === nhb
+            if forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2:
+                bs = forward_batch.batch_size
+                seq_lens_cpu = forward_batch.seq_lens_cpu
+                seq_lens = forward_batch.seq_lens
+                max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+                block_kv_indices = torch.full(
+                    (bs, max_seqlen_pad),
+                    -1,
+                    dtype=torch.int32,
+                    device=seq_lens.device,
+                )
+                # Call the Triton kernel to generate block_kv_indices
+                if use_sglang_create_flashmla_kv_indices_triton:
+                    dcu_create_flashmla_kv_indices(
+                        req_to_token_ptr=self.req_to_token.to(torch.int32),
+                        req_pool_indices_ptr=forward_batch.req_pool_indices.to(torch.int32),
+                        page_kernel_lens_ptr=forward_batch.seq_lens.to(torch.int32),
+                        kv_start_idx=None,
+                        kv_indices_ptr=block_kv_indices.to(torch.int32),
+                        req_to_token_ptr_stride=self.req_to_token.stride(0),
+                        kv_indices_ptr_stride=max_seqlen_pad,
+                    )
+                else:
+                    create_flashmla_kv_indices_triton[(bs,)](
+                        self.req_to_token,
+                        forward_batch.req_pool_indices,
+                        forward_batch.seq_lens,
+                        None,
+                        block_kv_indices,
+                        self.req_to_token.stride(0),
+                        max_seqlen_pad,
+                    )
+                # MLA
+                mla_metadata, num_splits = get_mla_metadata(
+                    seq_lens.to(torch.int32),
+                    self.num_q_heads,
+                    1,
+                )
+                # save forward_metadata
+                self.forward_metadata = VllmMLADecodeMetadata(
+                    mla_metadata,
+                    num_splits,
+                    block_kv_indices,
+                )
+                self.flashattn_backend.init_forward_metadata(forward_batch)

     def init_cuda_graph_state(
         self,
         max_bs: int,
```
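For orientation: `block_kv_indices` is a per-request table of KV-cache page numbers, padded with -1 out to `max_seqlen_pad` columns, and it is stored in the backend's forward metadata for the decode kernels. The snippet below is a plain-PyTorch sketch of that layout, assuming pages are `PAGE_SIZE` contiguous slots taken from `req_to_token`; the helper name and the toy shapes are invented for illustration, and this is not the Triton kernel the diff calls.

```python
import torch

PAGE_SIZE = 64  # assumption for this sketch; the backend defines its own page size


def build_block_kv_indices_ref(
    req_to_token: torch.Tensor,      # [num_req_slots, max_context_len], KV slot per token position
    req_pool_indices: torch.Tensor,  # [bs], which req_to_token row each request uses
    seq_lens: torch.Tensor,          # [bs], current length of each request
) -> torch.Tensor:
    """Pure-PyTorch stand-in for the paged block table the Triton kernels fill."""
    bs = req_pool_indices.numel()
    max_seqlen_pad = (int(seq_lens.max().item()) + PAGE_SIZE - 1) // PAGE_SIZE  # triton.cdiv
    block_kv_indices = torch.full(
        (bs, max_seqlen_pad), -1, dtype=torch.int32, device=seq_lens.device
    )
    for i in range(bs):
        n_pages = (int(seq_lens[i].item()) + PAGE_SIZE - 1) // PAGE_SIZE
        row = req_to_token[req_pool_indices[i]]
        # One entry per page: the KV slot at the start of the page, converted to a page number.
        starts = torch.arange(n_pages, device=seq_lens.device) * PAGE_SIZE
        block_kv_indices[i, :n_pages] = (row[starts] // PAGE_SIZE).to(torch.int32)
    return block_kv_indices


# Two requests of length 100 and 3 -> a [2, cdiv(100, 64)] = [2, 2] table; unused slots stay -1.
req_to_token = torch.arange(4 * 256, dtype=torch.int32).reshape(4, 256)
table = build_block_kv_indices_ref(
    req_to_token,
    req_pool_indices=torch.tensor([0, 2]),
    seq_lens=torch.tensor([100, 3]),
)
print(table)  # tensor([[ 0,  1], [ 8, -1]], dtype=torch.int32)
```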
```diff
@@ -431,7 +482,7 @@ class DCUMLABackend(AttentionBackend):
         k_rope: Optional[torch.Tensor] = None,
         sinks=None,
     ):
-        if save_kv_cache:
+        if save_kv_cache and self.num_draft_tokens == 0:  #nhb
             return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
         if ((
```
python/sglang/srt/layers/attention/flashattention_backend.py
```diff
@@ -598,6 +598,7 @@ class FlashAttentionBackend(AttentionBackend):
             if (
                 any(forward_batch.extend_prefix_lens_cpu)
                 or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2  #nhb
             ):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
```
```diff
@@ -668,9 +669,9 @@ class FlashAttentionBackend(AttentionBackend):
             if not layer.is_cross_attention
             else forward_batch.encoder_out_cache_loc
         )
         if not self.use_mla:
             if k_rope is None:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    layer, cache_loc, k, v,
+                    # layer.k_scale, layer.v_scale
                 )
             else:
                 forward_batch.token_to_kv_pool.set_mla_kv_buffer(
```
python/sglang/srt/managers/scheduler.py
```diff
@@ -1940,7 +1940,7 @@ class Scheduler(
                 batch.spec_info = batch_result.next_draft_input
                 batch.spec_info.future_indices = future_indices
+                batch.sampling_info.is_all_greedy = True  #nhb
                 # batch.spec_info = EagleDraftInput(
                 #     future_indices=future_indices,
                 #     verify_done=batch_result.next_draft_input.verify_done,
```
python/sglang/srt/model_executor/forward_batch_info.py
```diff
@@ -129,6 +129,7 @@ class ForwardMode(IntEnum):
             or self == ForwardMode.DRAFT_EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.SPLIT_PREFILL
+            or self == ForwardMode.DRAFT_EXTEND_V2  #nhb
         )

     def is_cuda_graph(self):
```
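The added line means `is_extend()` now also returns True for DRAFT_EXTEND_V2, so such batches take the extend code paths wherever that predicate is checked. Below is a self-contained sketch of the pattern; the enum members shown are an illustrative subset, not the values sglang actually defines.

```python
from enum import IntEnum, auto


class ForwardMode(IntEnum):
    # Illustrative subset; the real enum in forward_batch_info.py defines more modes.
    EXTEND = auto()
    DECODE = auto()
    MIXED = auto()
    DRAFT_EXTEND = auto()
    DRAFT_EXTEND_V2 = auto()
    SPLIT_PREFILL = auto()

    def is_extend(self) -> bool:
        # After this commit, DRAFT_EXTEND_V2 counts as an extend mode as well.
        return self in (
            ForwardMode.EXTEND,
            ForwardMode.MIXED,
            ForwardMode.DRAFT_EXTEND,
            ForwardMode.SPLIT_PREFILL,
            ForwardMode.DRAFT_EXTEND_V2,
        )


assert ForwardMode.DRAFT_EXTEND_V2.is_extend()
assert not ForwardMode.DECODE.is_extend()
```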
python/sglang/srt/speculative/draft_utils.py
```diff
@@ -237,7 +237,14 @@ class DraftBackendFactory:
         return None

     def _create_dcumla_prefill_backend(self):
-        logger.warning(
-            "flashmla prefill backend is not yet supported for draft extend."
-        )
-        return None
+        # logger.warning(
+        #     "flashmla prefill backend is not yet supported for draft extend."
+        # )
+        # return None
+        #nhb
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+
+        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
```
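This change turns `_create_dcumla_prefill_backend` from a stub that only warned and returned None into a fallback that hands draft-extend prefill to a FlashAttention backend with prefill enabled. The sketch below mirrors that shape with dummy classes so it runs standalone; none of these names are the real sglang types.

```python
from typing import Optional


class _DummyFlashAttentionBackend:
    """Stand-in for FlashAttentionBackend so this sketch runs without sglang."""

    def __init__(self, model_runner, skip_prefill: bool = True):
        self.model_runner = model_runner
        self.skip_prefill = skip_prefill


class DraftBackendFactorySketch:
    """Toy mirror of DraftBackendFactory's helper; not the real sglang class."""

    def __init__(self, draft_model_runner):
        self.draft_model_runner = draft_model_runner

    def _create_dcumla_prefill_backend(self) -> Optional[_DummyFlashAttentionBackend]:
        # Previously: log a warning and return None, leaving draft extend without a
        # prefill backend. Now: fall back to a FlashAttention-style backend with
        # prefill enabled (skip_prefill=False).
        return _DummyFlashAttentionBackend(self.draft_model_runner, skip_prefill=False)


backend = DraftBackendFactorySketch("dummy-runner")._create_dcumla_prefill_backend()
assert backend is not None and backend.skip_prefill is False
```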