zhaoyu6 / sglang
Commit a55cb8b2 authored Nov 12, 2025 by linhai1

update to start mtp.

Parent: 39f99208

Showing 10 changed files with 263 additions and 53 deletions (+263, −53).
Changed files:

- python/sglang/srt/layers/attention/dcu_mla_backend.py (+101 −39)
- python/sglang/srt/layers/attention/flashattention_backend.py (+6 −2)
- python/sglang/srt/managers/scheduler.py (+1 −1)
- python/sglang/srt/model_executor/forward_batch_info.py (+3 −2)
- python/sglang/srt/model_executor/model_runner.py (+1 −0)
- python/sglang/srt/speculative/eagle_info.py (+22 −9)
- sgl-kernel/csrc/common_extension_rocm.cc (+3 −0)
- sgl-kernel/csrc/kvcacheio/transfer.cu (+102 −0)
- sgl-kernel/include/sgl_kernel_ops.h (+8 −0)
- sgl-kernel/python/sgl_kernel/kvcacheio.py (+16 −0)

python/sglang/srt/layers/attention/dcu_mla_backend.py
```diff
@@ -103,54 +103,112 @@ class DCUMLABackend(AttentionBackend):
             skip_prefill=False,
         )
 
+    def _build_decode_metadata(
+        self, forward_batch: ForwardBatch, seq_lens: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+        bs = forward_batch.batch_size
+        max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+        # Paging scheme following the official vLLM blog
+        block_kv_indices = torch.full(
+            (bs, max_seqlen_pad), -1, dtype=torch.int32, device=seq_lens.device
+        )
+        create_flashmla_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            forward_batch.req_pool_indices,
+            seq_lens,
+            None,
+            block_kv_indices,
+            self.req_to_token.stride(0),
+            max_seqlen_pad,
+        )
+        mla_metadata, num_splits = get_mla_metadata(
+            seq_lens.to(torch.int32), self.num_q_heads, 1
+        )
+        return (mla_metadata, num_splits), num_splits, block_kv_indices
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        bs = forward_batch.batch_size
         if forward_batch.forward_mode.is_decode_or_idle():
-            max_seqlen_pad = triton.cdiv(
-                forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
-            )
-            block_kv_indices = torch.full(
-                (bs, max_seqlen_pad),
-                -1,
-                dtype=torch.int32,
-                device=forward_batch.seq_lens.device,
-            )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
-            mla_metadata, num_splits = get_mla_metadata(
-                forward_batch.seq_lens.to(torch.int32), self.num_q_heads, 1
-            )
+            # use flashmla for decode
+            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
+                self._build_decode_metadata(forward_batch, forward_batch.seq_lens)
+            )
             self.forward_metadata = VllmMLADecodeMetadata(
-                mla_metadata, num_splits, block_kv_indices
+                mla_metadata, num_splits_t, block_kv_indices
             )
         elif forward_batch.forward_mode.is_target_verify():
-            seq_lens_cpu = forward_batch.seq_lens_cpu + self.num_draft_tokens
             seq_lens = forward_batch.seq_lens + self.num_draft_tokens
-            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
-            block_kv_indices = torch.full(
-                (bs, max_seqlen_pad),
-                -1,
-                dtype=torch.int32,
-                device=seq_lens.device,
-            )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
-            mla_metadata, num_splits = get_mla_metadata(
-                seq_lens.to(torch.int32),
-                self.num_draft_tokens * self.num_q_heads,
-                1,
-            )
+            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
+                self._build_decode_metadata(forward_batch, seq_lens)
+            )
             self.forward_metadata = VllmMLADecodeMetadata(
-                mla_metadata, num_splits, block_kv_indices
+                mla_metadata, num_splits_t, block_kv_indices
             )
         else:
             if not self.skip_prefill:
+                # === DRAFT_EXTEND_V2 MLA metadata === nhb
+                if forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2:
+                    bs = forward_batch.batch_size
+                    seq_lens_cpu = forward_batch.seq_lens_cpu
+                    seq_lens = forward_batch.seq_lens
+                    max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+                    block_kv_indices = torch.full(
+                        (bs, max_seqlen_pad),
+                        -1,
+                        dtype=torch.int32,
+                        device=seq_lens.device,
+                    )
+                    # Call the Triton kernel to build block_kv_indices
+                    create_flashmla_kv_indices_triton[(bs,)](
+                        self.req_to_token,
+                        forward_batch.req_pool_indices,
+                        seq_lens,
+                        None,
+                        block_kv_indices,
+                        self.req_to_token.stride(0),
+                        max_seqlen_pad,
+                    )
+                    # MLA metadata
+                    mla_metadata, num_splits = get_mla_metadata(
+                        seq_lens.to(torch.int32),
+                        self.num_q_heads,
+                        1,
+                    )
+                    # save forward_metadata
+                    self.forward_metadata = VllmMLADecodeMetadata(
+                        mla_metadata,
+                        num_splits,
+                        block_kv_indices,
+                    )
                 self.flashattn_backend.init_forward_metadata(forward_batch)
 
     def init_cuda_graph_state(
```
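The refactor above deduplicates the decode and target-verify paths into `_build_decode_metadata`, whose shape logic is plain ceil-division paging over the KV cache. A minimal sketch of that arithmetic, assuming an illustrative `PAGE_SIZE` of 64 and a hypothetical `build_block_table` helper (not code from this commit):

```python
import math

import torch

PAGE_SIZE = 64  # illustrative; the backend defines its own constant


def build_block_table(seq_lens: torch.Tensor) -> torch.Tensor:
    """One row per request, padded to the longest sequence's page count;
    -1 marks block slots a request does not use (filled in later by the
    create_flashmla_kv_indices_triton kernel in the real backend)."""
    bs = seq_lens.numel()
    # math.ceil plays the role of triton.cdiv in the backend code
    max_seqlen_pad = math.ceil(seq_lens.max().item() / PAGE_SIZE)
    return torch.full((bs, max_seqlen_pad), -1, dtype=torch.int32)


table = build_block_table(torch.tensor([100, 300]))
print(table.shape)  # torch.Size([2, 5]) since ceil(300 / 64) == 5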
```diff
@@ -431,10 +489,10 @@ class DCUMLABackend(AttentionBackend):
     ):
         if save_kv_cache:
             return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
-        if (
-            forward_batch.forward_mode == ForwardMode.EXTEND
-            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
-        ):
+        if ((
+            forward_batch.forward_mode == ForwardMode.EXTEND
+            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+        ) or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND):
             if not self.skip_prefill:
                 return self.flashattn_backend.forward_extend(

@@ -447,7 +505,12 @@ class DCUMLABackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
 
         bs = forward_batch.batch_size
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

@@ -482,7 +545,6 @@ class DCUMLABackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
 class DCUMLAMultiStepDraftBackend:
     """
     Wrap multiple flashmla attention backends as one for multiple consecutive
```
python/sglang/srt/layers/attention/flashattention_backend.py

```diff
@@ -598,6 +598,7 @@ class FlashAttentionBackend(AttentionBackend):
             if (
                 any(forward_batch.extend_prefix_lens_cpu)
                 or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2  # nhb
             ):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)

@@ -608,10 +609,13 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.max_seq_len_q = metadata.max_seq_len_k
                 metadata.cu_seqlens_q = metadata.cu_seqlens_k
 
-            # Setup local attention if enabled
-            if forward_batch.forward_mode == ForwardMode.EXTEND:
-                self._init_local_attn_metadata(forward_batch, metadata, device)
+            # # Setup local attention if enabled
+            # if forward_batch.forward_mode == ForwardMode.EXTEND:
+            #     self._init_local_attn_metadata(forward_batch, metadata, device)
+            # Setup local attention if enabled
+            if forward_batch.forward_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND_V2):
+                self._init_local_attn_metadata(forward_batch, metadata, device)
 
         # Encoder metadata for cross attention
         if forward_batch.encoder_lens is not None:
             assert (
```
python/sglang/srt/managers/scheduler.py

```diff
@@ -1940,7 +1940,7 @@ class Scheduler(
             batch.spec_info = batch_result.next_draft_input
             batch.spec_info.future_indices = future_indices
-            batch.sampling_info.is_all_greedy = True
+            batch.sampling_info.is_all_greedy = True  # nhb
             # batch.spec_info = EagleDraftInput(
             #     future_indices=future_indices,
             #     verify_done=batch_result.next_draft_input.verify_done,
```
python/sglang/srt/model_executor/forward_batch_info.py

```diff
@@ -123,12 +123,13 @@ class ForwardMode(IntEnum):
         # For fixed shape logits output in v2 eagle worker
         return self == ForwardMode.DRAFT_EXTEND_V2
 
-    def is_extend_or_draft_extend_or_mixed(self):
+    def is_extend_or_draft_extend_or_mixed(self):  # nhb
         return (
             self == ForwardMode.EXTEND
             or self == ForwardMode.DRAFT_EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.SPLIT_PREFILL
+            or self == ForwardMode.DRAFT_EXTEND_V2
         )
 
     def is_cuda_graph(self):
```
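For context, the predicate this hunk extends is the usual IntEnum membership pattern. A minimal sketch with placeholder members, not the real `ForwardMode` definition:

```python
from enum import IntEnum, auto


# Placeholder members for illustration; the real ForwardMode in
# forward_batch_info.py defines more modes with its own values.
class ForwardMode(IntEnum):
    EXTEND = auto()
    DRAFT_EXTEND = auto()
    MIXED = auto()
    SPLIT_PREFILL = auto()
    DRAFT_EXTEND_V2 = auto()

    def is_extend_or_draft_extend_or_mixed(self) -> bool:
        # Tuple membership reads the same as the chained `or` comparisons
        # in the diff and is easy to extend with new modes.
        return self in (
            ForwardMode.EXTEND,
            ForwardMode.DRAFT_EXTEND,
            ForwardMode.MIXED,
            ForwardMode.SPLIT_PREFILL,
            ForwardMode.DRAFT_EXTEND_V2,
        )


assert ForwardMode.DRAFT_EXTEND_V2.is_extend_or_draft_extend_or_mixed()
```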
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -2241,6 +2241,7 @@ class ModelRunner:
             and self.graph_runner
             and self.graph_runner.can_run(forward_batch)
         )
 
         if can_run_graph:
             ret = self.graph_runner.replay(
                 forward_batch,
```
python/sglang/srt/speculative/eagle_info.py

```diff
@@ -37,7 +37,8 @@ from sglang.srt.speculative.spec_utils import (
     get_src_tgt_cache_loc,
     get_target_cache_loc,
 )
-from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
+from sglang.srt.utils import is_cuda, is_hip, next_power_of_2, get_bool_env_var
+from sgl_kernel.kvcacheio import dcu_create_extend_after_decode_spec_info
 
 if is_cuda():
     from sgl_kernel import (

@@ -615,6 +616,8 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
     new_seq_lens: Optional[torch.Tensor] = None
     verify_done: Optional[torch.cuda.Event] = None
+    use_sglang_create_extend_after_decode_spec_info = get_bool_env_var(
+        "SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO"
+    )
 
     def __post_init__(self):
         super().__init__(SpecInputType.EAGLE_DRAFT)

@@ -679,14 +682,24 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
         self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
         self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
-        create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
-            batch.input_ids,
-            batch.seq_lens,
-            self.accept_length,
-            self.positions,
-            self.verified_id,
-            next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
-        )
+        if self.use_sglang_create_extend_after_decode_spec_info:
+            dcu_create_extend_after_decode_spec_info(
+                verified_id=batch.input_ids,
+                seq_lens=batch.seq_lens,
+                accept_lens=self.accept_length,
+                positions=self.positions,
+                new_verified_id=self.verified_id,
+                bs=max(speculative_num_steps + 1, len(batch.seq_lens)),
+            )
+        else:
+            create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
+                batch.input_ids,
+                batch.seq_lens,
+                self.accept_length,
+                self.positions,
+                self.verified_id,
+                next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
+            )
 
     def generate_attn_arg_prefill(
         self,
```
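The fused-kernel path is gated by an environment variable, and because the flag is a class attribute it is read once when the class body is evaluated rather than per call. A sketch of how such a boolean env-var reader typically behaves (a stand-in, not sglang's actual `get_bool_env_var` implementation):

```python
import os


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Stand-in for sglang.srt.utils.get_bool_env_var; the exact accepted
    # spellings may differ in the real helper.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")


# With the variable set, the EagleDraftInput code above calls
# dcu_create_extend_after_decode_spec_info; unset, it keeps the
# original Triton kernel path.
os.environ["SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO"] = "1"
print(get_bool_env_var("SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO"))  # True
```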
sgl-kernel/csrc/common_extension_rocm.cc

```diff
@@ -131,6 +131,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/kvcacheio
    */
+  m.def(
+      "dcu_create_extend_after_decode_spec_info(Tensor verified_id, Tensor seq_lens, Tensor accept_lens, Tensor positions, Tensor new_verified_id, int bs) -> ()");
+  m.impl("dcu_create_extend_after_decode_spec_info", torch::kCUDA, &dcu_create_extend_after_decode_spec_info);
   m.def(
       "dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
   m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
   m.def(
       "dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
```
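With this `m.def`/`m.impl` pair registered, the operator becomes callable from Python as `torch.ops.sgl_kernel.dcu_create_extend_after_decode_spec_info(...)`, which is exactly what the new wrapper in sgl-kernel/python/sgl_kernel/kvcacheio.py (last file below) invokes.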
sgl-kernel/csrc/kvcacheio/transfer.cu

```diff
@@ -693,6 +693,65 @@ __global__ void launch_alloc_extend_kernel(
     out_indices[output_idx] = start_loc * page_size + offset;
   }
 }
 
+__global__ void launch_create_extend_after_decode_spec_info_int32_kernel(
+    const int32_t* verified_id_ptr,
+    const int64_t* seq_lens_ptr,
+    const int32_t* accept_lens_ptr,
+    int64_t* positions_ptr,
+    int32_t* new_verified_id_ptr,
+    int64_t bs) {
+  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (pid >= bs) return;
+
+  int64_t seq_length = seq_lens_ptr[pid];
+  int32_t accept_length = accept_lens_ptr[pid];
+
+  int32_t accept_len_cumsum = 0;
+  for (int32_t offset = 0; offset < pid; offset++) {
+    accept_len_cumsum += accept_lens_ptr[offset];
+  }
+
+  int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
+  for (int32_t offset = 0; offset < accept_length && offset < bs; offset++) {
+    positions_ptr1[offset] = seq_length - accept_length + offset;
+  }
+
+  int32_t verified_idx = accept_len_cumsum + accept_length - 1;
+  new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
+}
+
+__global__ void launch_create_extend_after_decode_spec_info_int64_kernel(
+    const int32_t* verified_id_ptr,
+    const int64_t* seq_lens_ptr,
+    const int64_t* accept_lens_ptr,
+    int64_t* positions_ptr,
+    int32_t* new_verified_id_ptr,
+    int64_t bs) {
+  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (pid >= bs) return;
+
+  int64_t seq_length = seq_lens_ptr[pid];
+  int64_t accept_length = accept_lens_ptr[pid];
+
+  int64_t accept_len_cumsum = 0;
+  for (int64_t offset = 0; offset < pid; offset++) {
+    accept_len_cumsum += accept_lens_ptr[offset];
+  }
+
+  int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
+  for (int64_t offset = 0; offset < accept_length && offset < bs; offset++) {
+    positions_ptr1[offset] = seq_length - accept_length + offset;
+  }
+
+  int64_t verified_idx = accept_len_cumsum + accept_length - 1;
+  new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
+}
+
 void dcu_alloc_decode_kernel(
     const at::Tensor seq_lens_ptr,
```
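The int32 and int64 kernels differ only in the dtype of `accept_lens`. Each thread handles one request: it serially prefix-sums `accept_lens[0..pid)` to find its output offset (so a launch does O(bs²) reads overall), writes the positions of its accepted tokens, and records the last accepted token as the request's new verified id. A pure-PyTorch reference of that logic (a hypothetical test helper, not part of the commit; it omits the kernels' additional `offset < bs` clamp on the inner loop):

```python
import torch


def create_extend_after_decode_spec_info_ref(verified_id, seq_lens, accept_lens):
    """CPU reference mirroring the CUDA kernels above, for unit tests."""
    positions = torch.empty(int(accept_lens.sum()), dtype=torch.long)
    new_verified_id = torch.empty(seq_lens.numel(), dtype=torch.int32)
    cumsum = 0
    for i in range(seq_lens.numel()):
        seq_len, accept_len = int(seq_lens[i]), int(accept_lens[i])
        # Accepted tokens occupy the tail of sequence i.
        positions[cumsum : cumsum + accept_len] = torch.arange(
            seq_len - accept_len, seq_len
        )
        cumsum += accept_len
        # The last accepted token becomes the new verified id for request i.
        new_verified_id[i] = verified_id[cumsum - 1]
    return positions, new_verified_id


pos, nvid = create_extend_after_decode_spec_info_ref(
    verified_id=torch.arange(3, dtype=torch.int32),
    seq_lens=torch.tensor([10, 12]),
    accept_lens=torch.tensor([1, 2]),
)
print(pos.tolist(), nvid.tolist())  # [9, 10, 11] [0, 2]
```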
```diff
@@ -714,6 +773,49 @@ void dcu_alloc_decode_kernel(
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
+void dcu_create_extend_after_decode_spec_info(
+    const at::Tensor verified_id,
+    const at::Tensor seq_lens,
+    const at::Tensor accept_lens,
+    at::Tensor positions,
+    at::Tensor new_verified_id,
+    int64_t bs) {
+  const int32_t* verified_id_ptr;
+  const int64_t* seq_lens_ptr;
+  const int32_t* accept_lens_ptr_int32;
+  const int64_t* accept_lens_ptr_int64;
+  int64_t* positions_ptr;
+  int32_t* new_verified_id_ptr;
+
+  int64_t block_size = 64;
+  int64_t grid_size = (bs + block_size - 1) / block_size;
+  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
+
+  if (accept_lens.dtype() == torch::kInt32) {
+    verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
+    seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
+    accept_lens_ptr_int32 = static_cast<const int32_t*>(accept_lens.data_ptr());
+    positions_ptr = static_cast<int64_t*>(positions.data_ptr());
+    new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
+    launch_create_extend_after_decode_spec_info_int32_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
+        verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int32, positions_ptr, new_verified_id_ptr, bs);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else {
+    verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
+    seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
+    accept_lens_ptr_int64 = static_cast<const int64_t*>(accept_lens.data_ptr());
+    positions_ptr = static_cast<int64_t*>(positions.data_ptr());
+    new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
+    launch_create_extend_after_decode_spec_info_int64_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
+        verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int64, positions_ptr, new_verified_id_ptr, bs);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+}
+
 void dcu_alloc_extend_kernel(
     const at::Tensor pre_lens_ptr,
     const at::Tensor seq_lens_ptr,
```
sgl-kernel/include/sgl_kernel_ops.h

```diff
@@ -538,6 +538,14 @@ void segment_packbits(
 /*
  * From csrc/kvcacheio
  */
+void dcu_create_extend_after_decode_spec_info(
+    const at::Tensor verified_id,
+    const at::Tensor seq_lens,
+    const at::Tensor accept_lens,
+    at::Tensor positions,
+    at::Tensor new_verified_id,
+    int64_t bs);
+
 void dcu_alloc_extend_kernel(
     const at::Tensor pre_lens_ptr,
     const at::Tensor seq_lens_ptr,
```
sgl-kernel/python/sgl_kernel/kvcacheio.py

```diff
@@ -9,6 +9,22 @@ def is_hip() -> bool:
 _is_hip = is_hip()
 
 
+def dcu_create_extend_after_decode_spec_info(
+    verified_id: torch.Tensor,
+    seq_lens: torch.Tensor,
+    accept_lens: torch.Tensor,
+    positions: torch.Tensor,
+    new_verified_id: torch.Tensor,
+    bs: int,
+):
+    torch.ops.sgl_kernel.dcu_create_extend_after_decode_spec_info(
+        verified_id,
+        seq_lens,
+        accept_lens,
+        positions,
+        new_verified_id,
+        bs,
+    )
+
+
 def dcu_alloc_extend_kernel(
     pre_lens_ptr: torch.Tensor,
```
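A hypothetical end-to-end call with toy shapes, assuming a CUDA device and an sgl-kernel build that includes this commit; the expected outputs follow the kernel semantics described above:

```python
import torch

from sgl_kernel.kvcacheio import dcu_create_extend_after_decode_spec_info

# Toy example: 2 requests, 3 accepted tokens in total.
seq_lens = torch.tensor([8, 11], dtype=torch.int64, device="cuda")
accept_lens = torch.tensor([2, 1], dtype=torch.int32, device="cuda")
verified_id = torch.arange(3, dtype=torch.int32, device="cuda")
positions = torch.empty(3, dtype=torch.int64, device="cuda")  # written in place
new_verified_id = torch.empty(2, dtype=torch.int32, device="cuda")  # written in place

dcu_create_extend_after_decode_spec_info(
    verified_id, seq_lens, accept_lens, positions, new_verified_id, bs=2
)
print(positions.tolist())        # [6, 7, 10]
print(new_verified_id.tolist())  # [1, 2]
```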