change / sglang · Commits

Commit a55cb8b2
Authored Nov 12, 2025 by linhai1
Parent: 39f99208

    update to start mtp.

Showing 10 changed files with 263 additions and 53 deletions (+263 -53)
python/sglang/srt/layers/attention/dcu_mla_backend.py         +101 -39
python/sglang/srt/layers/attention/flashattention_backend.py    +6  -2
python/sglang/srt/managers/scheduler.py                         +1  -1
python/sglang/srt/model_executor/forward_batch_info.py          +3  -2
python/sglang/srt/model_executor/model_runner.py                +1  -0
python/sglang/srt/speculative/eagle_info.py                    +22  -9
sgl-kernel/csrc/common_extension_rocm.cc                        +3  -0
sgl-kernel/csrc/kvcacheio/transfer.cu                         +102  -0
sgl-kernel/include/sgl_kernel_ops.h                             +8  -0
sgl-kernel/python/sgl_kernel/kvcacheio.py                      +16  -0
python/sglang/srt/layers/attention/dcu_mla_backend.py

@@ -103,54 +103,112 @@ class DCUMLABackend(AttentionBackend):
             skip_prefill=False,
         )
 
+    def _build_decode_metadata(
+        self, forward_batch: ForwardBatch, seq_lens: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+        bs = forward_batch.batch_size
+        max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+        # Paging scheme following the official vLLM blog
+        block_kv_indices = torch.full(
+            (bs, max_seqlen_pad), -1, dtype=torch.int32, device=seq_lens.device
+        )
+        create_flashmla_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            forward_batch.req_pool_indices,
+            seq_lens,
+            None,
+            block_kv_indices,
+            self.req_to_token.stride(0),
+            max_seqlen_pad,
+        )
+        mla_metadata, num_splits = get_mla_metadata(
+            seq_lens.to(torch.int32), self.num_q_heads, 1
+        )
+        return (mla_metadata, num_splits), num_splits, block_kv_indices
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        bs = forward_batch.batch_size
         if forward_batch.forward_mode.is_decode_or_idle():
-            max_seqlen_pad = triton.cdiv(
-                forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
-            )
-            block_kv_indices = torch.full(
-                (bs, max_seqlen_pad),
-                -1,
-                dtype=torch.int32,
-                device=forward_batch.seq_lens.device,
-            )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
-            mla_metadata, num_splits = get_mla_metadata(
-                forward_batch.seq_lens.to(torch.int32), self.num_q_heads, 1
-            )
-            self.forward_metadata = VllmMLADecodeMetadata(
-                mla_metadata, num_splits, block_kv_indices
-            )
+            # Use FlashMLA for decode
+            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
+                self._build_decode_metadata(forward_batch, forward_batch.seq_lens)
+            )
+            self.forward_metadata = VllmMLADecodeMetadata(
+                mla_metadata, num_splits_t, block_kv_indices
+            )
         elif forward_batch.forward_mode.is_target_verify():
-            seq_lens_cpu = forward_batch.seq_lens_cpu + self.num_draft_tokens
             seq_lens = forward_batch.seq_lens + self.num_draft_tokens
-            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
-            block_kv_indices = torch.full(
-                (bs, max_seqlen_pad),
-                -1,
-                dtype=torch.int32,
-                device=seq_lens.device,
-            )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
-            mla_metadata, num_splits = get_mla_metadata(
-                seq_lens.to(torch.int32),
-                self.num_draft_tokens * self.num_q_heads,
-                1,
-            )
-            self.forward_metadata = VllmMLADecodeMetadata(
-                mla_metadata, num_splits, block_kv_indices
-            )
+            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
+                self._build_decode_metadata(forward_batch, seq_lens)
+            )
+            self.forward_metadata = VllmMLADecodeMetadata(
+                mla_metadata, num_splits_t, block_kv_indices
+            )
         else:
             if not self.skip_prefill:
+                # === DRAFT_EXTEND_V2 MLA metadata ===
+                if forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2:
+                    bs = forward_batch.batch_size
+                    seq_lens_cpu = forward_batch.seq_lens_cpu
+                    seq_lens = forward_batch.seq_lens
+                    max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+                    block_kv_indices = torch.full(
+                        (bs, max_seqlen_pad),
+                        -1,
+                        dtype=torch.int32,
+                        device=seq_lens.device,
+                    )
+                    # Call the Triton kernel to build block_kv_indices
+                    create_flashmla_kv_indices_triton[(bs,)](
+                        self.req_to_token,
+                        forward_batch.req_pool_indices,
+                        seq_lens,
+                        None,
+                        block_kv_indices,
+                        self.req_to_token.stride(0),
+                        max_seqlen_pad,
+                    )
+                    # MLA metadata
+                    mla_metadata, num_splits = get_mla_metadata(
+                        seq_lens.to(torch.int32), self.num_q_heads, 1,
+                    )
+                    # Save forward_metadata
+                    self.forward_metadata = VllmMLADecodeMetadata(
+                        mla_metadata, num_splits, block_kv_indices,
+                    )
                 self.flashattn_backend.init_forward_metadata(forward_batch)
 
     def init_cuda_graph_state(
...
...
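The new _build_decode_metadata helper pads every batch row up to a whole number of KV-cache pages and marks unused page slots with -1. A minimal standalone sketch of that paging arithmetic (illustrative only; the real PAGE_SIZE and block-table contents come from the backend, and triton.cdiv is plain ceiling division):

PAGE_SIZE = 64  # assumed page size for illustration


def ceil_div(a: int, b: int) -> int:
    # Equivalent to triton.cdiv: ceiling division.
    return (a + b - 1) // b


# A batch of 3 requests with these sequence lengths:
seq_lens = [17, 64, 130]

# Every row of block_kv_indices is padded to the longest sequence in the
# batch, measured in pages; -1 marks unused page slots.
max_seqlen_pad = ceil_div(max(seq_lens), PAGE_SIZE)  # ceil(130 / 64) = 3

for i, s in enumerate(seq_lens):
    used = ceil_div(s, PAGE_SIZE)
    print(f"request {i}: {used} page(s) used, {max_seqlen_pad - used} slot(s) left at -1")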
@@ -431,10 +489,10 @@ class DCUMLABackend(AttentionBackend):
     ):
+        if save_kv_cache:
+            return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
         if (
             forward_batch.forward_mode == ForwardMode.EXTEND
             or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
         ):
             if not self.skip_prefill:
                 return self.flashattn_backend.forward_extend(
...
...
@@ -447,7 +505,12 @@ class DCUMLABackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
         bs = forward_batch.batch_size
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
...
...
@@ -482,7 +545,6 @@ class DCUMLABackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
 
 class DCUMLAMultiStepDraftBackend:
     """
     Wrap multiple flashmla attention backends as one for multiple consecutive
...
...
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -598,6 +598,7 @@ class FlashAttentionBackend(AttentionBackend):
         if (
             any(forward_batch.extend_prefix_lens_cpu)
             or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2
         ):
             extend_seq_lens = forward_batch.extend_seq_lens
             metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
...
...
@@ -608,10 +609,13 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_q = metadata.max_seq_len_k
             metadata.cu_seqlens_q = metadata.cu_seqlens_k
 
         # Setup local attention if enabled
-        if forward_batch.forward_mode == ForwardMode.EXTEND:
+        if forward_batch.forward_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND_V2):
             self._init_local_attn_metadata(forward_batch, metadata, device)
 
         # Encoder metadata for cross attention
         if forward_batch.encoder_lens is not None:
             assert (
...
...
python/sglang/srt/managers/scheduler.py

@@ -1940,7 +1940,7 @@ class Scheduler(
         batch.spec_info = batch_result.next_draft_input
         batch.spec_info.future_indices = future_indices
+        batch.sampling_info.is_all_greedy = True
         # batch.spec_info = EagleDraftInput(
         #     future_indices=future_indices,
         #     verify_done=batch_result.next_draft_input.verify_done,
...
...
python/sglang/srt/model_executor/forward_batch_info.py

@@ -123,12 +123,13 @@ class ForwardMode(IntEnum):
         # For fixed shape logits output in v2 eagle worker
         return self == ForwardMode.DRAFT_EXTEND_V2
 
     def is_extend_or_draft_extend_or_mixed(self):
         return (
             self == ForwardMode.EXTEND
             or self == ForwardMode.DRAFT_EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.SPLIT_PREFILL
+            or self == ForwardMode.DRAFT_EXTEND_V2
         )
 
     def is_cuda_graph(self):
...
...
python/sglang/srt/model_executor/model_runner.py

@@ -2241,6 +2241,7 @@ class ModelRunner:
             and self.graph_runner
             and self.graph_runner.can_run(forward_batch)
         )
         if can_run_graph:
             ret = self.graph_runner.replay(
                 forward_batch,
...
...
python/sglang/srt/speculative/eagle_info.py

@@ -37,7 +37,8 @@ from sglang.srt.speculative.spec_utils import (
     get_src_tgt_cache_loc,
     get_target_cache_loc,
 )
-from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
+from sglang.srt.utils import is_cuda, is_hip, next_power_of_2, get_bool_env_var
+from sgl_kernel.kvcacheio import dcu_create_extend_after_decode_spec_info
 
 if is_cuda():
     from sgl_kernel import (
...
...
@@ -615,6 +616,8 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
     new_seq_lens: Optional[torch.Tensor] = None
     verify_done: Optional[torch.cuda.Event] = None
 
+    use_sglang_create_extend_after_decode_spec_info = get_bool_env_var(
+        "SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO"
+    )
+
     def __post_init__(self):
         super().__init__(SpecInputType.EAGLE_DRAFT)
...
...
@@ -679,14 +682,24 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
         self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
         self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
-        create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
-            batch.input_ids,
-            batch.seq_lens,
-            self.accept_length,
-            self.positions,
-            self.verified_id,
-            next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
-        )
+        if self.use_sglang_create_extend_after_decode_spec_info:
+            dcu_create_extend_after_decode_spec_info(
+                verified_id=batch.input_ids,
+                seq_lens=batch.seq_lens,
+                accept_lens=self.accept_length,
+                positions=self.positions,
+                new_verified_id=self.verified_id,
+                bs=max(speculative_num_steps + 1, len(batch.seq_lens)),
+            )
+        else:
+            create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
+                batch.input_ids,
+                batch.seq_lens,
+                self.accept_length,
+                self.positions,
+                self.verified_id,
+                next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
+            )
 
     def generate_attn_arg_prefill(
         self,
...
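The DCU kernel path is opt-in via an environment variable and falls back to the existing Triton kernel by default. Note that get_bool_env_var is evaluated once, into a class attribute, so the flag must be set before this module is imported. A small sketch of enabling it (illustrative only):

import os

# Must be set before sglang imports eagle_info; the flag is read once,
# at class-definition time, into a class attribute.
os.environ["SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO"] = "1"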
sgl-kernel/csrc/common_extension_rocm.cc

@@ -131,6 +131,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/kvcacheio
    */
+  m.def(
+      "dcu_create_extend_after_decode_spec_info(Tensor verified_id, Tensor seq_lens, Tensor accept_lens, Tensor positions, Tensor new_verified_id, int bs) -> ()");
+  m.impl("dcu_create_extend_after_decode_spec_info", torch::kCUDA, &dcu_create_extend_after_decode_spec_info);
   m.def(
       "dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
   m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
   m.def(
       "dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
...
...
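Once this schema/impl pair is registered, the op becomes callable from Python through the torch.ops namespace, which is exactly how the wrapper added in sgl-kernel/python/sgl_kernel/kvcacheio.py (below) invokes it. A minimal sketch, assuming the sgl_kernel extension library has already been loaded:

import torch

# Resolve the op registered under TORCH_LIBRARY_EXPAND(sgl_kernel, m).
op = torch.ops.sgl_kernel.dcu_create_extend_after_decode_spec_info

# The call mirrors the schema string:
# op(verified_id, seq_lens, accept_lens, positions, new_verified_id, bs)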
sgl-kernel/csrc/kvcacheio/transfer.cu

@@ -693,6 +693,65 @@ __global__ void launch_alloc_extend_kernel(
     out_indices[output_idx] = start_loc * page_size + offset;
   }
 }
 
+__global__ void launch_create_extend_after_decode_spec_info_int32_kernel(
+    const int32_t* verified_id_ptr,
+    const int64_t* seq_lens_ptr,
+    const int32_t* accept_lens_ptr,
+    int64_t* positions_ptr,
+    int32_t* new_verified_id_ptr,
+    int64_t bs) {
+  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (pid >= bs) return;
+
+  int64_t seq_length = seq_lens_ptr[pid];
+  int32_t accept_length = accept_lens_ptr[pid];
+
+  // Serial prefix sum of accept_lens over the preceding requests.
+  int32_t accept_len_cumsum = 0;
+  for (int32_t offset = 0; offset < pid; offset++) {
+    accept_len_cumsum += accept_lens_ptr[offset];
+  }
+
+  // Positions of this request's accepted tokens at the tail of the sequence.
+  int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
+  for (int32_t offset = 0; offset < accept_length && offset < bs; offset++) {
+    positions_ptr1[offset] = seq_length - accept_length + offset;
+  }
+
+  // The last accepted token becomes the new verified id for this request.
+  int32_t verified_idx = accept_len_cumsum + accept_length - 1;
+  new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
+}
+
+__global__ void launch_create_extend_after_decode_spec_info_int64_kernel(
+    const int32_t* verified_id_ptr,
+    const int64_t* seq_lens_ptr,
+    const int64_t* accept_lens_ptr,
+    int64_t* positions_ptr,
+    int32_t* new_verified_id_ptr,
+    int64_t bs) {
+  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (pid >= bs) return;
+
+  int64_t seq_length = seq_lens_ptr[pid];
+  int64_t accept_length = accept_lens_ptr[pid];
+
+  int64_t accept_len_cumsum = 0;
+  for (int64_t offset = 0; offset < pid; offset++) {
+    accept_len_cumsum += accept_lens_ptr[offset];
+  }
+
+  int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
+  for (int64_t offset = 0; offset < accept_length && offset < bs; offset++) {
+    positions_ptr1[offset] = seq_length - accept_length + offset;
+  }
+
+  int64_t verified_idx = accept_len_cumsum + accept_length - 1;
+  new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
+}
 
 void dcu_alloc_decode_kernel(
     const at::Tensor seq_lens_ptr,
...
@@ -714,6 +773,49 @@ void dcu_alloc_decode_kernel(
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
+void dcu_create_extend_after_decode_spec_info(
+    const at::Tensor verified_id,
+    const at::Tensor seq_lens,
+    const at::Tensor accept_lens,
+    at::Tensor positions,
+    at::Tensor new_verified_id,
+    int64_t bs) {
+  const int32_t* verified_id_ptr;
+  const int64_t* seq_lens_ptr;
+  const int32_t* accept_lens_ptr_int32;
+  const int64_t* accept_lens_ptr_int64;
+  int64_t* positions_ptr;
+  int32_t* new_verified_id_ptr;
+
+  int64_t block_size = 64;
+  int64_t grid_size = (bs + block_size - 1) / block_size;
+  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
+
+  // Dispatch on the accept_lens dtype; the two kernels are otherwise identical.
+  if (accept_lens.dtype() == torch::kInt32) {
+    verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
+    seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
+    accept_lens_ptr_int32 = static_cast<const int32_t*>(accept_lens.data_ptr());
+    positions_ptr = static_cast<int64_t*>(positions.data_ptr());
+    new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
+    launch_create_extend_after_decode_spec_info_int32_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
+        verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int32, positions_ptr, new_verified_id_ptr, bs);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else {
+    verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
+    seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
+    accept_lens_ptr_int64 = static_cast<const int64_t*>(accept_lens.data_ptr());
+    positions_ptr = static_cast<int64_t*>(positions.data_ptr());
+    new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
+    launch_create_extend_after_decode_spec_info_int64_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
+        verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int64, positions_ptr, new_verified_id_ptr, bs);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+}
 
 void dcu_alloc_extend_kernel(
     const at::Tensor pre_lens_ptr,
     const at::Tensor seq_lens_ptr,
...
...
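Per request, each thread recomputes a serial prefix sum of accept_lens (O(bs) per thread, O(bs²) overall; fine for typical batch sizes, though a parallel scan would scale better), writes the positions of the accepted tail tokens, and copies out the last accepted token id. A pure-PyTorch reference of the same computation (a hypothetical helper, handy for unit-testing the kernels; it omits the kernels' extra offset < bs loop guard):

import torch


def create_extend_after_decode_spec_info_ref(
    verified_id: torch.Tensor,  # flattened accepted token ids, int32
    seq_lens: torch.Tensor,     # [bs], int64
    accept_lens: torch.Tensor,  # [bs], int32 or int64
):
    bs = seq_lens.numel()
    total = int(accept_lens.sum())
    positions = torch.empty(total, dtype=torch.int64)
    new_verified_id = torch.empty(bs, dtype=torch.int32)
    cumsum = 0
    for i in range(bs):
        n = int(accept_lens[i])
        s = int(seq_lens[i])
        # Accepted tokens occupy the tail positions [s - n, s) of the sequence.
        positions[cumsum : cumsum + n] = torch.arange(s - n, s, dtype=torch.int64)
        # The last accepted token becomes the new verified id for request i.
        new_verified_id[i] = verified_id[cumsum + n - 1]
        cumsum += n
    return positions, new_verified_id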
sgl-kernel/include/sgl_kernel_ops.h

@@ -538,6 +538,14 @@ void segment_packbits(
 /*
  * From csrc/kvcacheio
  */
+void dcu_create_extend_after_decode_spec_info(
+    const at::Tensor verified_id,
+    const at::Tensor seq_lens,
+    const at::Tensor accept_lens,
+    at::Tensor positions,
+    at::Tensor new_verified_id,
+    int64_t bs);
+
 void dcu_alloc_extend_kernel(
     const at::Tensor pre_lens_ptr,
     const at::Tensor seq_lens_ptr,
...
...
sgl-kernel/python/sgl_kernel/kvcacheio.py

@@ -9,6 +9,22 @@ def is_hip() -> bool:
 _is_hip = is_hip()
 
 
+def dcu_create_extend_after_decode_spec_info(
+    verified_id: torch.Tensor,
+    seq_lens: torch.Tensor,
+    accept_lens: torch.Tensor,
+    positions: torch.Tensor,
+    new_verified_id: torch.Tensor,
+    bs: int,
+):
+    torch.ops.sgl_kernel.dcu_create_extend_after_decode_spec_info(
+        verified_id,
+        seq_lens,
+        accept_lens,
+        positions,
+        new_verified_id,
+        bs,
+    )
+
+
 def dcu_alloc_extend_kernel(
     pre_lens_ptr: torch.Tensor,
...
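A usage sketch for the new wrapper (shapes and values are illustrative; the tensors must live on the GPU since the op is registered for torch::kCUDA, and the sgl-kernel extension must be built):

import torch
from sgl_kernel.kvcacheio import dcu_create_extend_after_decode_spec_info

bs = 4
accept_lens = torch.tensor([2, 1, 3, 2], dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([10, 7, 12, 9], dtype=torch.int64, device="cuda")
verified_id = torch.arange(int(accept_lens.sum()), dtype=torch.int32, device="cuda")

# Output buffers: one position per accepted token, one verified id per request.
positions = torch.empty(int(accept_lens.sum()), dtype=torch.int64, device="cuda")
new_verified_id = torch.empty(bs, dtype=torch.int32, device="cuda")

dcu_create_extend_after_decode_spec_info(
    verified_id, seq_lens, accept_lens, positions, new_verified_id, bs
)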