sglang / Commits / b6e7f2a9

Commit b6e7f2a9, authored Nov 18, 2025 by lizhigong

    Merge branch 'v0.5.4_dev_linhai' into 'v0.5.4_dev'

    V0.5.4 dev linhai

    See merge request OpenDAS/sglang!28

Parents: de61a992, 4c45697e

Showing 13 changed files with 405 additions and 181 deletions (+405 -181)
Files changed (13):

    python/sglang/srt/layers/attention/dcu_mla_backend.py            +106  -47
    python/sglang/srt/layers/attention/flashattention_backend.py     +125 -107
    python/sglang/srt/layers/attention/flashattention_interface.py    +15  -14
    python/sglang/srt/managers/schedule_batch.py                       +1   -0
    python/sglang/srt/managers/scheduler.py                            +1   -1
    python/sglang/srt/mem_cache/common.py                               +2   -1
    python/sglang/srt/model_executor/forward_batch_info.py             +3   -2
    python/sglang/srt/model_executor/model_runner.py                    +1   -0
    python/sglang/srt/speculative/eagle_info.py                        +22   -9
    sgl-kernel/csrc/common_extension_rocm.cc                            +3   -0
    sgl-kernel/csrc/kvcacheio/transfer.cu                             +102   -0
    sgl-kernel/include/sgl_kernel_ops.h                                 +8   -0
    sgl-kernel/python/sgl_kernel/kvcacheio.py                          +16   -0
python/sglang/srt/layers/attention/dcu_mla_backend.py

@@ -103,54 +103,112 @@ class DCUMLABackend(AttentionBackend):
             skip_prefill=False,
         )

-    def init_forward_metadata(self, forward_batch: ForwardBatch):
-        bs = forward_batch.batch_size
-        if forward_batch.forward_mode.is_decode_or_idle():
-            max_seqlen_pad = triton.cdiv(forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE)
-            block_kv_indices = torch.full(
-                (bs, max_seqlen_pad),
-                -1,
-                dtype=torch.int32,
-                device=forward_batch.seq_lens.device,
-            )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
-            mla_metadata, num_splits = get_mla_metadata(
-                forward_batch.seq_lens.to(torch.int32), self.num_q_heads, 1
-            )
-            self.forward_metadata = VllmMLADecodeMetadata(mla_metadata, num_splits, block_kv_indices)
-        elif forward_batch.forward_mode.is_target_verify():
-            seq_lens_cpu = forward_batch.seq_lens_cpu + self.num_draft_tokens
-            seq_lens = forward_batch.seq_lens + self.num_draft_tokens
-            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
-            block_kv_indices = torch.full(
-                (bs, max_seqlen_pad),
-                -1,
-                dtype=torch.int32,
-                device=seq_lens.device,
-            )
-            create_flashmla_kv_indices_triton[(bs,)](
-                self.req_to_token,
-                forward_batch.req_pool_indices,
-                seq_lens,
-                None,
-                block_kv_indices,
-                self.req_to_token.stride(0),
-                max_seqlen_pad,
-            )
-            mla_metadata, num_splits = get_mla_metadata(
-                seq_lens.to(torch.int32),
-                self.num_draft_tokens * self.num_q_heads,
-                1,
-            )
-            self.forward_metadata = VllmMLADecodeMetadata(mla_metadata, num_splits, block_kv_indices)
-        else:
-            if not self.skip_prefill:
-                self.flashattn_backend.init_forward_metadata(forward_batch)
+    def _build_decode_metadata(
+        self, forward_batch: ForwardBatch, seq_lens: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
+        bs = forward_batch.batch_size
+        max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+        # Paging scheme follows the official vLLM blog post
+        block_kv_indices = torch.full(
+            (bs, max_seqlen_pad),
+            -1,
+            dtype=torch.int32,
+            device=seq_lens.device,
+        )
+        create_flashmla_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            forward_batch.req_pool_indices,
+            seq_lens,
+            None,
+            block_kv_indices,
+            self.req_to_token.stride(0),
+            max_seqlen_pad,
+        )
+        mla_metadata, num_splits = get_mla_metadata(
+            seq_lens.to(torch.int32),
+            self.num_q_heads,
+            1,
+        )
+        return (mla_metadata, num_splits), num_splits, block_kv_indices
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        if forward_batch.forward_mode.is_decode_or_idle():
+            # decode uses FlashMLA
+            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
+                self._build_decode_metadata(forward_batch, forward_batch.seq_lens)
+            )
+            self.forward_metadata = VllmMLADecodeMetadata(
+                mla_metadata, num_splits_t, block_kv_indices
+            )
+        elif forward_batch.forward_mode.is_target_verify():
+            seq_lens = forward_batch.seq_lens + self.num_draft_tokens
+            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
+                self._build_decode_metadata(forward_batch, seq_lens)
+            )
+            self.forward_metadata = VllmMLADecodeMetadata(
+                mla_metadata, num_splits_t, block_kv_indices
+            )
+        else:
+            if not self.skip_prefill:
+                # === DRAFT_EXTEND_V2 MLA metadata === nhb
+                if forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2:
+                    bs = forward_batch.batch_size
+                    seq_lens_cpu = forward_batch.seq_lens_cpu
+                    seq_lens = forward_batch.seq_lens
+                    max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+                    block_kv_indices = torch.full(
+                        (bs, max_seqlen_pad),
+                        -1,
+                        dtype=torch.int32,
+                        device=seq_lens.device,
+                    )
+                    # Call the Triton kernel to build block_kv_indices
+                    create_flashmla_kv_indices_triton[(bs,)](
+                        self.req_to_token,
+                        forward_batch.req_pool_indices,
+                        seq_lens,
+                        None,
+                        block_kv_indices,
+                        self.req_to_token.stride(0),
+                        max_seqlen_pad,
+                    )
+                    # MLA
+                    mla_metadata, num_splits = get_mla_metadata(
+                        seq_lens.to(torch.int32),
+                        self.num_q_heads,
+                        1,
+                    )
+                    # save forward_metadata
+                    self.forward_metadata = VllmMLADecodeMetadata(
+                        mla_metadata,
+                        num_splits,
+                        block_kv_indices,
+                    )
+                self.flashattn_backend.init_forward_metadata(forward_batch)

     def init_cuda_graph_state(

@@ -389,7 +447,7 @@ class DCUMLABackend(AttentionBackend):
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
         k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
+        num_draft_tokens = self.num_draft_tokens if self.num_draft_tokens is not None else 0
         if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
                               torch.float8_e5m2, torch.float8_e5m2fnuz):
             if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):

@@ -401,7 +459,7 @@ class DCUMLABackend(AttentionBackend):
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
-                forward_batch.seq_lens.to(torch.int32),
+                (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
                 layer.scaling,
                 k_scale,
                 kv_cache_dtype=kv_cache_dtype,

@@ -411,7 +469,7 @@ class DCUMLABackend(AttentionBackend):
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
-                forward_batch.seq_lens.to(torch.int32),
+                (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
                 layer.scaling,
             )

@@ -431,12 +489,9 @@ class DCUMLABackend(AttentionBackend):
         k_rope: Optional[torch.Tensor] = None,
         sinks=None,
     ):
-        if save_kv_cache:
-            return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
-        if ((forward_batch.forward_mode == ForwardMode.EXTEND
-            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND)
+        if (
+            forward_batch.forward_mode == ForwardMode.EXTEND
+            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
         ):
             if not self.skip_prefill:
                 return self.flashattn_backend.forward_extend(

@@ -449,14 +504,19 @@ class DCUMLABackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )

         bs = forward_batch.batch_size
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
         k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
+        num_draft_tokens = self.num_draft_tokens if self.num_draft_tokens is not None else 0
         if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
                               torch.float8_e5m2, torch.float8_e5m2fnuz):
             if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):

@@ -468,7 +528,7 @@ class DCUMLABackend(AttentionBackend):
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
-                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
+                (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
                 layer.scaling,
                 k_scale,
                 kv_cache_dtype=kv_cache_dtype,

@@ -478,13 +538,12 @@ class DCUMLABackend(AttentionBackend):
                 reshape_q,
                 k_cache_reshaped,
                 self.forward_metadata.block_kv_indices[:bs],
-                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
+                (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
                 layer.scaling,
             )
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

 class DCUMLAMultiStepDraftBackend:
     """
     Wrap multiple flashmla attention backends as one for multiple consecutive
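The refactor above mainly extracts the duplicated decode/target-verify metadata construction into _build_decode_metadata, whose central product is the padded page table block_kv_indices. The sketch below illustrates only that layout with toy numbers; the PAGE_SIZE value and batch contents are assumptions, and math.ceil stands in for triton.cdiv.

    import math
    import torch

    PAGE_SIZE = 64                       # assumed page size for the sketch
    seq_lens = torch.tensor([100, 37])   # two requests in a toy batch
    bs = seq_lens.numel()

    # Same padding rule as the helper: one row per request, enough page slots
    # to cover the longest sequence, unused slots left as -1.
    max_seqlen_pad = math.ceil(seq_lens.max().item() / PAGE_SIZE)
    block_kv_indices = torch.full((bs, max_seqlen_pad), -1, dtype=torch.int32)
    print(block_kv_indices.shape)        # torch.Size([2, 2])

In the backend, the Triton kernel create_flashmla_kv_indices_triton then fills each row with the request's physical page indices taken from req_to_token.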
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -329,6 +329,8 @@ class FlashAttentionBackend(AttentionBackend):
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill
         self.is_hybrid = model_runner.is_hybrid
+        self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
+        self.v_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
         if self.is_hybrid:
             self.full_to_swa_index_mapping = (
                 model_runner.token_to_kv_pool.full_to_swa_index_mapping

@@ -598,6 +600,7 @@ class FlashAttentionBackend(AttentionBackend):
         if (
             any(forward_batch.extend_prefix_lens_cpu)
             or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2  # nhb
         ):
             extend_seq_lens = forward_batch.extend_seq_lens
             metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)

@@ -608,10 +611,13 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_q = metadata.max_seq_len_k
             metadata.cu_seqlens_q = metadata.cu_seqlens_k

-        # Setup local attention if enabled
-        if forward_batch.forward_mode == ForwardMode.EXTEND:
-            self._init_local_attn_metadata(forward_batch, metadata, device)
+        # # Setup local attention if enabled
+        # if forward_batch.forward_mode == ForwardMode.EXTEND:
+        #     self._init_local_attn_metadata(forward_batch, metadata, device)
+        if forward_batch.forward_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND_V2):
+            self._init_local_attn_metadata(forward_batch, metadata, device)

         # Encoder metadata for cross attention
         if forward_batch.encoder_lens is not None:
             assert (

@@ -668,10 +674,16 @@ class FlashAttentionBackend(AttentionBackend):
                 if not layer.is_cross_attention
                 else forward_batch.encoder_out_cache_loc
             )
-            if not self.use_mla:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
+            # if not self.use_mla:
+            if k_rope is None:
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v
+                    )
             else:
                 forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                     layer,

@@ -690,7 +702,8 @@ class FlashAttentionBackend(AttentionBackend):
             layer.sliding_window_size is not None and layer.sliding_window_size > -1
         )
         window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
-        k_descale, v_descale = None, None
+        # k_descale, v_descale = None, None
+        k_descale, v_descale = self.k_scale, self.v_scale
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
         # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,

@@ -774,55 +787,53 @@ class FlashAttentionBackend(AttentionBackend):
                 cu_seqlens_k = metadata.encoder_cu_seqlens_k
                 window_size = (-1, -1)

-            result = flash_attn_with_kvcache(
-                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                k_cache=key_cache,
-                v_cache=value_cache,
-                page_table=page_table,
-                cache_seqlens=cache_seqlens,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
-                max_seqlen_q=max_seqlen_q,
-                softmax_scale=layer.scaling,
-                causal=False if use_cascade_attn else causal,
-                window_size=window_size,
-                softcap=layer.logit_cap,
-                k_descale=k_descale,
-                v_descale=v_descale,
-                return_softmax_lse=use_cascade_attn,
-                num_splits=self.num_splits,
-                **kwargs,
-            )
-            if use_cascade_attn:
-                o, softmax_lse, *rest = result
-                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
-                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k_cache=key_cache,
-                    v_cache=value_cache,
-                    page_table=self.forward_metadata_spec_decode_expand.page_table,
-                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
-                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
-                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
-                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
-                    softmax_scale=layer.scaling,
-                    causal=False,
-                    window_size=window_size,
-                    softcap=layer.logit_cap,
-                    k_descale=k_descale,
-                    v_descale=v_descale,
-                    return_softmax_lse=True,
-                    num_splits=self.num_splits,
-                    **kwargs,
-                )
-                o, _ = merge_state_v2_wrapper(
-                    o,
-                    softmax_lse.T.contiguous(),
-                    o_expand,
-                    softmax_lse_expand.T.contiguous(),
-                )
-            else:
-                o = result
+            if forward_batch.attn_attend_prefix_cache:
+                assert not get_global_server_args().disable_chunked_prefix_cache
+                # MHA for chunked prefix kv cache when running model with MLA
+                assert forward_batch.prefix_chunk_idx is not None
+                assert forward_batch.prefix_chunk_cu_seq_lens is not None
+                assert forward_batch.prefix_chunk_max_seq_lens is not None

+                chunk_idx = forward_batch.prefix_chunk_idx
+                assert chunk_idx >= 0

+                assert forward_batch.mha_return_lse
+                output = flash_attn_varlen_func(
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                    max_seqlen_q=metadata.max_seq_len_q,
+                    max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                    softmax_scale=layer.scaling,
+                    causal=False,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=True,
+                    **kwargs,
+                )
+            else:
+                output = flash_attn_varlen_func(
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k=metadata.cu_seqlens_q,
+                    max_seqlen_q=metadata.max_seq_len_q,
+                    max_seqlen_k=metadata.max_seq_len_q,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=forward_batch.mha_return_lse,
+                    **kwargs,
+                )
+            if forward_batch.mha_return_lse:
+                output, lse, *rest = output
+                lse = torch.transpose(lse, 0, 1).contiguous()
+                return output, lse
+            return output.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         else:
             if (
                 forward_batch.attn_attend_prefix_cache is not None

@@ -851,6 +862,8 @@ class FlashAttentionBackend(AttentionBackend):
                     max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                     softmax_scale=layer.scaling,
                     causal=False,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
                     return_softmax_lse=True,
                     **kwargs,
                 )

@@ -865,6 +878,8 @@ class FlashAttentionBackend(AttentionBackend):
                     max_seqlen_k=metadata.max_seq_len_q,
                     softmax_scale=layer.scaling,
                     causal=True,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
                     return_softmax_lse=forward_batch.mha_return_lse,
                     **kwargs,
                 )

@@ -974,10 +989,16 @@ class FlashAttentionBackend(AttentionBackend):
                 if not layer.is_cross_attention
                 else forward_batch.encoder_out_cache_loc
             )
-            if not self.use_mla:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
+            # if not self.use_mla:
+            if k_rope is None:
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v
+                    )
             else:
                 forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                     layer,

@@ -1019,7 +1040,8 @@ class FlashAttentionBackend(AttentionBackend):
         if sinks is not None:
             kwargs["sinks"] = sinks

-        k_descale, v_descale = None, None
+        # k_descale, v_descale = None, None
+        k_descale, v_descale = self.k_scale, self.v_scale
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
         # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.

@@ -1033,7 +1055,6 @@ class FlashAttentionBackend(AttentionBackend):
         k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None

         if not self.use_mla:
             # Do multi-head attention
             key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
             )

@@ -1085,65 +1106,62 @@ class FlashAttentionBackend(AttentionBackend):
                 **kwargs,
             )
         else:
-            page_table = metadata.page_table
-            cu_seqlens_k = metadata.cu_seqlens_k
-            max_seqlen_q = metadata.max_seq_len_q
-            cache_seqlens = metadata.cache_seqlens_int32
-            q_reshaped = q.contiguous().view(
-                -1, layer.tp_q_head_num, layer.head_dim
-            )
-            # Default: single-token self-attention
-            result = flash_attn_with_kvcache(
-                q=q_reshaped,
-                k_cache=key_cache,
-                v_cache=value_cache,
-                page_table=page_table,
-                cache_seqlens=cache_seqlens,
-                cu_seqlens_q=metadata.cu_seqlens_q,
-                cu_seqlens_k_new=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_q,
-                softmax_scale=layer.scaling,
-                causal=False if use_cascade_attn else causal,
-                window_size=window_size,
-                softcap=layer.logit_cap,
-                k_descale=k_descale,
-                v_descale=v_descale,
-                return_softmax_lse=use_cascade_attn,
-                num_splits=self.num_splits,
-                **kwargs,
-            )
-            if use_cascade_attn:
-                o, softmax_lse, *rest = result
-                o_expand, softmax_lse_expand, *rest_expand = (
-                    flash_attn_with_kvcache(
-                        q=q_reshaped,
-                        k_cache=key_cache,
-                        v_cache=value_cache,
-                        page_table=self.forward_metadata_spec_decode_expand.page_table,
-                        cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
-                        cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
-                        cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
-                        max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
-                        softmax_scale=layer.scaling,
-                        causal=False,
-                        window_size=window_size,
-                        softcap=layer.logit_cap,
-                        k_descale=k_descale,
-                        v_descale=v_descale,
-                        return_softmax_lse=True,
-                        num_splits=self.num_splits,
-                        **kwargs,
-                    )
-                )
-                o, _ = merge_state_v2(
-                    o,
-                    softmax_lse.T.contiguous(),
-                    o_expand,
-                    softmax_lse_expand.T.contiguous(),
-                )
-            else:
-                o = result
+            cu_seqlens_q = metadata.cu_seqlens_q
+            max_seqlen_q = metadata.max_seq_len_q
+            page_table = metadata.page_table
+            cache_seqlens = metadata.cache_seqlens_int32
+            cu_seqlens_k = metadata.cu_seqlens_k
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )
+            if layer.is_cross_attention:
+                page_table = metadata.encoder_page_table
+                cache_seqlens = metadata.encoder_lens_int32
+                cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                window_size = (-1, -1)
+            if max_seqlen_q > 1:
+                result = flash_attn_varlen_func(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_q,
+                    max_seqlen_k=max_seqlen_q,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    window_size=window_size,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=use_cascade_attn,
+                    num_splits=self.num_splits,
+                    **kwargs,
+                )
+            else:
+                result = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=page_table,
+                    cache_seqlens=cache_seqlens,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                    max_seqlen_q=max_seqlen_q,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    window_size=window_size,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=use_cascade_attn,
+                    num_splits=self.num_splits,
+                    **kwargs,
+                )
+            o = result
         else:
             # Do absorbed multi-latent attention
             kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
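One recurring change in this file is that k_descale/v_descale are now always populated from self.k_scale/self.v_scale, which are initialized to 1.0, rather than left as None. A unit descale is numerically a no-op, which the toy check below demonstrates; it is only an illustration of that identity with made-up tensors, not of the actual FlashAttention kernels.

    import torch

    k = torch.randn(4, 8, dtype=torch.float16)
    unit_descale = torch.tensor([1.0], dtype=torch.float32)

    # Scaling by exactly 1.0 and casting back reproduces the original tensor.
    k_rescaled = (k.float() * unit_descale).to(k.dtype)
    assert torch.equal(k, k_rescaled)

For fp8 KV caches the real descale factors would come from the layer's quantization scales rather than the constant 1.0 tensors.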
python/sglang/srt/layers/attention/flashattention_interface.py

@@ -41,18 +41,18 @@ def flash_attn_with_kvcache(
     ver=3,
 ):
     return flash_attn_with_kvcache_interface(
         q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
-        k_cache=k_cache,
-        v_cache=v_cache,
+        k_cache=k_cache.view(q.dtype),
+        v_cache=v_cache.view(q.dtype),
         block_table=page_table,
         cache_seqlens=cache_seqlens,
         softmax_scale=softmax_scale,
         causal=causal,
         window_size=window_size,
         softcap=softcap,
         return_softmax_lse=return_softmax_lse,
         num_splits=num_splits,
     )

 def flash_attn_varlen_func(
     q,

@@ -83,8 +83,8 @@ def flash_attn_varlen_func(
 ):
     return flash_attn_varlen_func_interface(
         q=q,
-        k=k,
-        v=v,
+        k=k.view(q.dtype),
+        v=v.view(q.dtype),
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,

@@ -92,4 +92,5 @@ def flash_attn_varlen_func(
         softmax_scale=softmax_scale,
         causal=causal,
         return_attn_probs=return_softmax_lse,
+        softcap=softcap,
     )
\ No newline at end of file
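The interface change reinterprets the cache and K/V tensors as the query dtype with Tensor.view(dtype) before handing them to the underlying kernels. view(dtype) reinterprets existing bytes rather than converting values, as the small standalone example below shows; the tensors here are toys, unrelated to the real cache layout.

    import torch

    k_cache = torch.tensor([1.0, 2.0], dtype=torch.float16)
    reinterpreted = k_cache.view(torch.bfloat16)  # same 2-byte elements, bits reinterpreted

    print(reinterpreted.dtype)   # torch.bfloat16
    print(reinterpreted.shape)   # torch.Size([2])
    # The values differ from [1.0, 2.0] because no numeric cast happens.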
python/sglang/srt/managers/schedule_batch.py

@@ -1193,6 +1193,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens = seq_lens_tensor
         self.seq_lens_cpu = seq_lens_cpu
         self.extend_num_tokens = extend_num_tokens
+        self.loc_tensor = torch.tensor([-1], device=self.device)

         # Allocate memory
         out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
python/sglang/srt/managers/scheduler.py

@@ -1940,7 +1940,7 @@ class Scheduler(
             batch.spec_info = batch_result.next_draft_input
             batch.spec_info.future_indices = future_indices
+            batch.sampling_info.is_all_greedy = True  # nhb
             # batch.spec_info = EagleDraftInput(
             #     future_indices=future_indices,
             #     verify_done=batch_result.next_draft_input.verify_done,
python/sglang/srt/mem_cache/common.py

@@ -356,7 +356,8 @@ def alloc_for_extend(
     else:
         # Paged allocation - build last_loc
         last_loc = [
-            (t[-1:] if len(t) > 0 else torch.tensor([-1], device=batch.device))
+            # (t[-1:] if len(t) > 0 else torch.tensor([-1], device=batch.device))
+            (t[-1:] if len(t) > 0 else batch.loc_tensor)
             for t in prefix_tensors
         ]
         out_cache_loc = alloc_paged_token_slots_extend(
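This change works together with the new self.loc_tensor field added in schedule_batch.py above: the batch preallocates a single [-1] tensor once, and the paged-allocation path reuses it for empty prefixes instead of building a fresh device tensor on every call. A minimal sketch of the pattern, using a hypothetical Batch stand-in rather than the real ScheduleBatch class:

    import torch

    class Batch:
        def __init__(self, device="cpu"):
            self.device = device
            self.loc_tensor = torch.tensor([-1], device=device)  # created once per batch

    def build_last_loc(batch, prefix_tensors):
        # Reuse batch.loc_tensor for empty prefixes instead of torch.tensor([-1], ...)
        return [t[-1:] if len(t) > 0 else batch.loc_tensor for t in prefix_tensors]

    batch = Batch()
    print(build_last_loc(batch, [torch.tensor([3, 7]), torch.tensor([], dtype=torch.long)]))
    # [tensor([7]), tensor([-1])]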
python/sglang/srt/model_executor/forward_batch_info.py

@@ -123,12 +123,13 @@ class ForwardMode(IntEnum):
         # For fixed shape logits output in v2 eagle worker
         return self == ForwardMode.DRAFT_EXTEND_V2

     def is_extend_or_draft_extend_or_mixed(self):
+        # nhb
         return (
             self == ForwardMode.EXTEND
             or self == ForwardMode.DRAFT_EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.SPLIT_PREFILL
+            or self == ForwardMode.DRAFT_EXTEND_V2
         )

     def is_cuda_graph(self):
python/sglang/srt/model_executor/model_runner.py

@@ -2241,6 +2241,7 @@ class ModelRunner:
             and self.graph_runner
             and self.graph_runner.can_run(forward_batch)
         )
         if can_run_graph:
             ret = self.graph_runner.replay(
                 forward_batch,
python/sglang/srt/speculative/eagle_info.py

@@ -37,7 +37,8 @@ from sglang.srt.speculative.spec_utils import (
     get_src_tgt_cache_loc,
     get_target_cache_loc,
 )
-from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
+from sglang.srt.utils import is_cuda, is_hip, next_power_of_2, get_bool_env_var
+from sgl_kernel.kvcacheio import dcu_create_extend_after_decode_spec_info

 if is_cuda():
     from sgl_kernel import (

@@ -615,6 +616,8 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
     new_seq_lens: Optional[torch.Tensor] = None
     verify_done: Optional[torch.cuda.Event] = None
+    use_sglang_create_extend_after_decode_spec_info = get_bool_env_var(
+        "SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO"
+    )

     def __post_init__(self):
         super().__init__(SpecInputType.EAGLE_DRAFT)

@@ -679,14 +682,24 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
         self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
         self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
-        create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
-            batch.input_ids,
-            batch.seq_lens,
-            self.accept_length,
-            self.positions,
-            self.verified_id,
-            next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
-        )
+        if self.use_sglang_create_extend_after_decode_spec_info:
+            dcu_create_extend_after_decode_spec_info(
+                verified_id=batch.input_ids,
+                seq_lens=batch.seq_lens,
+                accept_lens=self.accept_length,
+                positions=self.positions,
+                new_verified_id=self.verified_id,
+                bs=max(speculative_num_steps + 1, len(batch.seq_lens)),
+            )
+        else:
+            create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
+                batch.input_ids,
+                batch.seq_lens,
+                self.accept_length,
+                self.positions,
+                self.verified_id,
+                next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
+            )

     def generate_attn_arg_prefill(
         self,
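The new branch is gated by an environment variable read once through get_bool_env_var when the class body is evaluated, so a process that wants the fused DCU op has to set the flag before sglang is imported. A minimal sketch, assuming the flag is set in the launching process:

    import os

    # Assumed to be set before importing sglang / starting the server.
    os.environ["SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO"] = "1"

    # ... then import sglang as usual; with the flag unset or false, the
    # existing Triton kernel create_extend_after_decode_spec_info is used.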
sgl-kernel/csrc/common_extension_rocm.cc

@@ -131,6 +131,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/kvcacheio
    */
+  m.def(
+      "dcu_create_extend_after_decode_spec_info(Tensor verified_id, Tensor seq_lens, Tensor accept_lens, Tensor positions, Tensor new_verified_id, int bs) -> ()");
+  m.impl("dcu_create_extend_after_decode_spec_info", torch::kCUDA, &dcu_create_extend_after_decode_spec_info);
   m.def(
       "dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
   m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
   m.def(
       "dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
sgl-kernel/csrc/kvcacheio/transfer.cu

@@ -693,6 +693,65 @@ __global__ void launch_alloc_extend_kernel(
       out_indices[output_idx] = start_loc * page_size + offset;
     }
 }

+__global__ void launch_create_extend_after_decode_spec_info_int32_kernel(
+    const int32_t* verified_id_ptr,
+    const int64_t* seq_lens_ptr,
+    const int32_t* accept_lens_ptr,
+    int64_t* positions_ptr,
+    int32_t* new_verified_id_ptr,
+    int64_t bs) {
+  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (pid >= bs) return;
+  int64_t seq_length = seq_lens_ptr[pid];
+  int32_t accept_length = accept_lens_ptr[pid];
+  int32_t accept_len_cumsum = 0;
+  for (int32_t offset = 0; offset < pid; offset++) {
+    accept_len_cumsum += accept_lens_ptr[offset];
+  }
+  int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
+  for (int32_t offset = 0; offset < accept_length && offset < bs; offset++) {
+    positions_ptr1[offset] = seq_length - accept_length + offset;
+  }
+  int32_t verified_idx = accept_len_cumsum + accept_length - 1;
+  new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
+}
+
+__global__ void launch_create_extend_after_decode_spec_info_int64_kernel(
+    const int32_t* verified_id_ptr,
+    const int64_t* seq_lens_ptr,
+    const int64_t* accept_lens_ptr,
+    int64_t* positions_ptr,
+    int32_t* new_verified_id_ptr,
+    int64_t bs) {
+  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (pid >= bs) return;
+  int64_t seq_length = seq_lens_ptr[pid];
+  int64_t accept_length = accept_lens_ptr[pid];
+  int64_t accept_len_cumsum = 0;
+  for (int64_t offset = 0; offset < pid; offset++) {
+    accept_len_cumsum += accept_lens_ptr[offset];
+  }
+  int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
+  for (int64_t offset = 0; offset < accept_length && offset < bs; offset++) {
+    positions_ptr1[offset] = seq_length - accept_length + offset;
+  }
+  int64_t verified_idx = accept_len_cumsum + accept_length - 1;
+  new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
+}

 void dcu_alloc_decode_kernel(
     const at::Tensor seq_lens_ptr,

@@ -714,6 +773,49 @@ void dcu_alloc_decode_kernel(
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

+void dcu_create_extend_after_decode_spec_info(
+    const at::Tensor verified_id,
+    const at::Tensor seq_lens,
+    const at::Tensor accept_lens,
+    at::Tensor positions,
+    at::Tensor new_verified_id,
+    int64_t bs) {
+  const int32_t* verified_id_ptr;
+  const int64_t* seq_lens_ptr;
+  const int32_t* accept_lens_ptr_int32;
+  const int64_t* accept_lens_ptr_int64;
+  int64_t* positions_ptr;
+  int32_t* new_verified_id_ptr;
+  int64_t block_size = 64;
+  int64_t grid_size = (bs + block_size - 1) / block_size;
+  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
+  if (accept_lens.dtype() == torch::kInt32) {
+    verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
+    seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
+    accept_lens_ptr_int32 = static_cast<const int32_t*>(accept_lens.data_ptr());
+    positions_ptr = static_cast<int64_t*>(positions.data_ptr());
+    new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
+    launch_create_extend_after_decode_spec_info_int32_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
+        verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int32, positions_ptr, new_verified_id_ptr, bs);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  } else {
+    verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
+    seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
+    accept_lens_ptr_int64 = static_cast<const int64_t*>(accept_lens.data_ptr());
+    positions_ptr = static_cast<int64_t*>(positions.data_ptr());
+    new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
+    launch_create_extend_after_decode_spec_info_int64_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
+        verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int64, positions_ptr, new_verified_id_ptr, bs);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+  }
+};

 void dcu_alloc_extend_kernel(
     const at::Tensor pre_lens_ptr,
     const at::Tensor seq_lens_ptr,
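The two kernels differ only in the integer width of accept_lens; both compute, per request, the positions of the accepted draft tokens and the last accepted token id. A pure-PyTorch reference of that computation is sketched below, intended as a CPU cross-check; it is not code from the repository.

    import torch

    def create_extend_after_decode_spec_info_ref(verified_id, seq_lens, accept_lens):
        bs = seq_lens.numel()
        positions = torch.empty(int(accept_lens.sum()), dtype=torch.int64)
        new_verified_id = torch.empty(bs, dtype=verified_id.dtype)
        cumsum = 0
        for i in range(bs):
            acc = int(accept_lens[i])
            seq = int(seq_lens[i])
            # Accepted tokens occupy the last `acc` positions of the sequence.
            positions[cumsum:cumsum + acc] = torch.arange(seq - acc, seq)
            # The last accepted token id becomes the new verified id for request i.
            new_verified_id[i] = verified_id[cumsum + acc - 1]
            cumsum += acc
        return positions, new_verified_id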
sgl-kernel/include/sgl_kernel_ops.h

@@ -538,6 +538,14 @@ void segment_packbits(
 /*
  * From csrc/kvcacheio
  */
+void dcu_create_extend_after_decode_spec_info(
+    const at::Tensor verified_id,
+    const at::Tensor seq_lens,
+    const at::Tensor accept_lens,
+    at::Tensor positions,
+    at::Tensor new_verified_id,
+    int64_t bs);
+
 void dcu_alloc_extend_kernel(
     const at::Tensor pre_lens_ptr,
     const at::Tensor seq_lens_ptr,
sgl-kernel/python/sgl_kernel/kvcacheio.py

@@ -9,6 +9,22 @@ def is_hip() -> bool:
 _is_hip = is_hip()

+def dcu_create_extend_after_decode_spec_info(
+    verified_id: torch.Tensor,
+    seq_lens: torch.Tensor,
+    accept_lens: torch.Tensor,
+    positions: torch.Tensor,
+    new_verified_id: torch.Tensor,
+    bs: int,
+):
+    torch.ops.sgl_kernel.dcu_create_extend_after_decode_spec_info(
+        verified_id,
+        seq_lens,
+        accept_lens,
+        positions,
+        new_verified_id,
+        bs,
+    )
+
 def dcu_alloc_extend_kernel(
     pre_lens_ptr: torch.Tensor,
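A hedged example of calling the new wrapper, mirroring the eagle_info.py call site. The tensor contents are made up, and bs here is simply the number of requests; the eagle_info.py call site passes max(speculative_num_steps + 1, len(batch.seq_lens)) instead.

    import torch
    from sgl_kernel.kvcacheio import dcu_create_extend_after_decode_spec_info

    device = "cuda"
    verified_id = torch.tensor([7, 8, 9, 4], dtype=torch.int32, device=device)
    seq_lens = torch.tensor([12, 30], dtype=torch.int64, device=device)
    accept_lens = torch.tensor([2, 2], dtype=torch.int32, device=device)
    positions = torch.empty(4, dtype=torch.int64, device=device)        # filled by the op
    new_verified_id = torch.empty(2, dtype=torch.int32, device=device)  # filled by the op

    dcu_create_extend_after_decode_spec_info(
        verified_id, seq_lens, accept_lens, positions, new_verified_id, bs=2
    )
    # Expected per the kernel above: positions -> [10, 11, 28, 29], new_verified_id -> [8, 4]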