Project: change / sglang

Commit 4c45697e, authored Nov 18, 2025 by shangxl
Parent: a55cb8b2

    fa3 support qwen.

Showing 5 changed files with 143 additions and 129 deletions.
Files changed:
  python/sglang/srt/layers/attention/dcu_mla_backend.py           +6    -9
  python/sglang/srt/layers/attention/flashattention_backend.py    +119  -105
  python/sglang/srt/layers/attention/flashattention_interface.py  +15   -14
  python/sglang/srt/managers/schedule_batch.py                     +1    -0
  python/sglang/srt/mem_cache/common.py                            +2    -1
python/sglang/srt/layers/attention/dcu_mla_backend.py

@@ -446,7 +446,7 @@ class DCUMLABackend(AttentionBackend):
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
         k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
+        num_draft_tokens = self.num_draft_tokens if self.num_draft_tokens is not None else 0
         if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
                               torch.float8_e5m2, torch.float8_e5m2fnuz):
             if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):

@@ -458,7 +458,7 @@ class DCUMLABackend(AttentionBackend):
             reshape_q,
             k_cache_reshaped,
             self.forward_metadata.block_kv_indices[:bs],
-            forward_batch.seq_lens.to(torch.int32),
+            (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
             layer.scaling,
             k_scale,
             kv_cache_dtype=kv_cache_dtype,

@@ -468,7 +468,7 @@ class DCUMLABackend(AttentionBackend):
             reshape_q,
             k_cache_reshaped,
             self.forward_metadata.block_kv_indices[:bs],
-            forward_batch.seq_lens.to(torch.int32),
+            (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
             layer.scaling,
         )

@@ -487,9 +487,6 @@ class DCUMLABackend(AttentionBackend):
         k_rope: Optional[torch.Tensor] = None,
         sinks=None,
     ):
-        if save_kv_cache:
-            return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
         if (
             forward_batch.forward_mode == ForwardMode.EXTEND
             or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND

@@ -517,7 +514,7 @@ class DCUMLABackend(AttentionBackend):
         reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
         k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
+        num_draft_tokens = self.num_draft_tokens if self.num_draft_tokens is not None else 0
         if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
                               torch.float8_e5m2, torch.float8_e5m2fnuz):
             if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):

@@ -529,7 +526,7 @@ class DCUMLABackend(AttentionBackend):
             reshape_q,
             k_cache_reshaped,
             self.forward_metadata.block_kv_indices[:bs],
-            (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
+            (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
             layer.scaling,
             k_scale,
             kv_cache_dtype=kv_cache_dtype,

@@ -539,7 +536,7 @@ class DCUMLABackend(AttentionBackend):
             reshape_q,
             k_cache_reshaped,
             self.forward_metadata.block_kv_indices[:bs],
-            (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
+            (forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
             layer.scaling,
         )
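The recurring edit in this file replaces direct uses of `self.num_draft_tokens` with a local value that falls back to 0 when the attribute is None (i.e. when speculative decoding is not active), so decode and draft-verify can share the same `seq_lens + num_draft_tokens` expression. A minimal standalone sketch of that guard, using made-up tensors rather than the backend's real metadata:

```python
import torch

def effective_kv_lens(seq_lens: torch.Tensor, num_draft_tokens=None) -> torch.Tensor:
    # Fall back to 0 when speculative decoding is disabled and the field is None,
    # so the same call site works for both normal decode and draft verification.
    draft = num_draft_tokens if num_draft_tokens is not None else 0
    return (seq_lens + draft).to(torch.int32)

seq_lens = torch.tensor([5, 9, 12])
print(effective_kv_lens(seq_lens))     # tensor([ 5,  9, 12], dtype=torch.int32)
print(effective_kv_lens(seq_lens, 4))  # tensor([ 9, 13, 16], dtype=torch.int32)
```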
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -329,6 +329,8 @@ class FlashAttentionBackend(AttentionBackend):
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill
         self.is_hybrid = model_runner.is_hybrid
+        self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
+        self.v_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
         if self.is_hybrid:
             self.full_to_swa_index_mapping = (
                 model_runner.token_to_kv_pool.full_to_swa_index_mapping

@@ -672,10 +674,16 @@ class FlashAttentionBackend(AttentionBackend):
             if not layer.is_cross_attention
             else forward_batch.encoder_out_cache_loc
         )
-        if not self.use_mla:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-            )
+        # if not self.use_mla:
+        if k_rope is None:
+            if not self.use_mla:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
+            else:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v
+                )
         else:
             forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                 layer,

@@ -694,7 +702,8 @@ class FlashAttentionBackend(AttentionBackend):
             layer.sliding_window_size is not None and layer.sliding_window_size > -1
         )
         window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
-        k_descale, v_descale = None, None
+        # k_descale, v_descale = None, None
+        k_descale, v_descale = self.k_scale, self.v_scale
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
         # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,

@@ -778,55 +787,53 @@ class FlashAttentionBackend(AttentionBackend):
                 cu_seqlens_k = metadata.encoder_cu_seqlens_k
                 window_size = (-1, -1)
-            result = flash_attn_with_kvcache(
-                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                k_cache=key_cache,
-                v_cache=value_cache,
-                page_table=page_table,
-                cache_seqlens=cache_seqlens,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
-                max_seqlen_q=max_seqlen_q,
-                softmax_scale=layer.scaling,
-                causal=False if use_cascade_attn else causal,
-                window_size=window_size,
-                softcap=layer.logit_cap,
-                k_descale=k_descale,
-                v_descale=v_descale,
-                return_softmax_lse=use_cascade_attn,
-                num_splits=self.num_splits,
-                **kwargs,
-            )
-            if use_cascade_attn:
-                o, softmax_lse, *rest = result
-                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
-                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                    k_cache=key_cache,
-                    v_cache=value_cache,
-                    page_table=self.forward_metadata_spec_decode_expand.page_table,
-                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
-                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
-                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
-                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
-                    softmax_scale=layer.scaling,
-                    causal=False,
-                    window_size=window_size,
-                    softcap=layer.logit_cap,
-                    k_descale=k_descale,
-                    v_descale=v_descale,
-                    return_softmax_lse=True,
-                    num_splits=self.num_splits,
-                    **kwargs,
-                )
-                o, _ = merge_state_v2_wrapper(
-                    o,
-                    softmax_lse.T.contiguous(),
-                    o_expand,
-                    softmax_lse_expand.T.contiguous(),
-                )
-            else:
-                o = result
+            if forward_batch.attn_attend_prefix_cache:
+                assert not get_global_server_args().disable_chunked_prefix_cache
+                # MHA for chunked prefix kv cache when running model with MLA
+                assert forward_batch.prefix_chunk_idx is not None
+                assert forward_batch.prefix_chunk_cu_seq_lens is not None
+                assert forward_batch.prefix_chunk_max_seq_lens is not None
+                chunk_idx = forward_batch.prefix_chunk_idx
+                assert chunk_idx >= 0
+                assert forward_batch.mha_return_lse
+                output = flash_attn_varlen_func(
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                    max_seqlen_q=metadata.max_seq_len_q,
+                    max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                    softmax_scale=layer.scaling,
+                    causal=False,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=True,
+                    **kwargs,
+                )
+            else:
+                output = flash_attn_varlen_func(
+                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                    cu_seqlens_q=metadata.cu_seqlens_q,
+                    cu_seqlens_k=metadata.cu_seqlens_q,
+                    max_seqlen_q=metadata.max_seq_len_q,
+                    max_seqlen_k=metadata.max_seq_len_q,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=forward_batch.mha_return_lse,
+                    **kwargs,
+                )
+            if forward_batch.mha_return_lse:
+                output, lse, *rest = output
+                lse = torch.transpose(lse, 0, 1).contiguous()
+                return output, lse
+            return output.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         else:
             if (
                 forward_batch.attn_attend_prefix_cache is not None

@@ -855,6 +862,8 @@ class FlashAttentionBackend(AttentionBackend):
                     max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                     softmax_scale=layer.scaling,
                     causal=False,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
                     return_softmax_lse=True,
                     **kwargs,
                 )

@@ -869,6 +878,8 @@ class FlashAttentionBackend(AttentionBackend):
                     max_seqlen_k=metadata.max_seq_len_q,
                     softmax_scale=layer.scaling,
                     causal=True,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
                     return_softmax_lse=forward_batch.mha_return_lse,
                     **kwargs,
                 )

@@ -978,10 +989,16 @@ class FlashAttentionBackend(AttentionBackend):
             if not layer.is_cross_attention
             else forward_batch.encoder_out_cache_loc
         )
-        if not self.use_mla:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-            )
+        # if not self.use_mla:
+        if k_rope is None:
+            if not self.use_mla:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
+            else:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v
+                )
         else:
             forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                 layer,

@@ -1023,7 +1040,8 @@ class FlashAttentionBackend(AttentionBackend):
         if sinks is not None:
             kwargs["sinks"] = sinks
-        k_descale, v_descale = None, None
+        # k_descale, v_descale = None, None
+        k_descale, v_descale = self.k_scale, self.v_scale
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
         # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.

@@ -1037,7 +1055,6 @@ class FlashAttentionBackend(AttentionBackend):
         k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
         if not self.use_mla:
             # Do multi-head attention
             key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
             )

@@ -1089,26 +1106,33 @@ class FlashAttentionBackend(AttentionBackend):
                     **kwargs,
                 )
             else:
-                page_table = metadata.page_table
-                cache_seqlens = metadata.cache_seqlens_int32
-                cu_seqlens_k = metadata.cu_seqlens_k
-                max_seqlen_q = metadata.max_seq_len_q
-                q_reshaped = q.contiguous().view(
-                    -1, layer.tp_q_head_num, layer.head_dim
-                )
-
-                # Default: single-token self-attention
-                result = flash_attn_with_kvcache(
-                    q=q_reshaped,
-                    k_cache=key_cache,
-                    v_cache=value_cache,
-                    page_table=page_table,
-                    cache_seqlens=cache_seqlens,
-                    cu_seqlens_q=metadata.cu_seqlens_q,
-                    cu_seqlens_k_new=cu_seqlens_k,
-                    max_seqlen_q=max_seqlen_q,
-                    softmax_scale=layer.scaling,
-                    causal=False if use_cascade_attn else causal,
-                    window_size=window_size,
-                    softcap=layer.logit_cap,
-                    k_descale=k_descale,
+                cu_seqlens_q = metadata.cu_seqlens_q
+                max_seqlen_q = metadata.max_seq_len_q
+                page_table = metadata.page_table
+                cu_seqlens_k = metadata.cu_seqlens_k
+                cache_seqlens = metadata.cache_seqlens_int32
+                key_cache = key_cache.view(
+                    -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+                )
+                value_cache = value_cache.view(
+                    -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+                )
+                if layer.is_cross_attention:
+                    page_table = metadata.encoder_page_table
+                    cache_seqlens = metadata.encoder_lens_int32
+                    cu_seqlens_k = metadata.encoder_cu_seqlens_k
+                    window_size = (-1, -1)
+                if max_seqlen_q > 1:
+                    result = flash_attn_varlen_func(
+                        q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
+                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
+                        cu_seqlens_q=cu_seqlens_q,
+                        cu_seqlens_k=cu_seqlens_k,
+                        max_seqlen_q=max_seqlen_q,
+                        max_seqlen_k=max_seqlen_q,
+                        softmax_scale=layer.scaling,
+                        causal=True,
+                        window_size=window_size,
+                        softcap=layer.logit_cap,
+                        k_descale=k_descale,

@@ -1117,36 +1141,26 @@ class FlashAttentionBackend(AttentionBackend):
                         num_splits=self.num_splits,
                         **kwargs,
                     )
-                if use_cascade_attn:
-                    o, softmax_lse, *rest = result
-                    o_expand, softmax_lse_expand, *rest_expand = (
-                        flash_attn_with_kvcache(
-                            q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                            k_cache=key_cache,
-                            v_cache=value_cache,
-                            page_table=self.forward_metadata_spec_decode_expand.page_table,
-                            cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
-                            cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
-                            cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
-                            max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
-                            softmax_scale=layer.scaling,
-                            causal=False,
-                            window_size=window_size,
-                            softcap=layer.logit_cap,
-                            k_descale=k_descale,
-                            v_descale=v_descale,
-                            return_softmax_lse=True,
-                            num_splits=self.num_splits,
-                            **kwargs,
-                        )
-                    )
-                    o, _ = merge_state_v2(
-                        o,
-                        softmax_lse.T.contiguous(),
-                        o_expand,
-                        softmax_lse_expand.T.contiguous(),
-                    )
-                else:
-                    o = result
+                else:
+                    result = flash_attn_with_kvcache(
+                        q=q_reshaped,
+                        k_cache=key_cache,
+                        v_cache=value_cache,
+                        page_table=page_table,
+                        cache_seqlens=cache_seqlens,
+                        cu_seqlens_q=cu_seqlens_q,
+                        cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
+                        max_seqlen_q=max_seqlen_q,
+                        softmax_scale=layer.scaling,
+                        causal=True,
+                        window_size=window_size,
+                        softcap=layer.logit_cap,
+                        k_descale=k_descale,
+                        v_descale=v_descale,
+                        return_softmax_lse=use_cascade_attn,
+                        num_splits=self.num_splits,
+                        **kwargs,
+                    )
+                o = result
         else:
             # Do absorbed multi-latent attention
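The backend now always passes concrete `k_descale`/`v_descale` tensors (initialized to 1.0) instead of `None`. A descale factor is what the attention kernel multiplies cached K/V by to undo the scaling applied when they were quantized, so a value of 1.0 leaves fp16/bf16 caches numerically untouched while keeping the argument a real tensor. A small numerical sketch of that idea, independent of the fa3 kernel and using hypothetical scale values:

```python
import torch

# Standalone illustration only: a per-tensor descale factor undoes the scaling
# applied when K was written into the cache. With descale == 1.0 (the default
# tensors added in this commit) the multiplication is a no-op.
q = torch.randn(4, 8)            # [tokens, head_dim]
k = torch.randn(6, 8)
k_scale = torch.tensor(0.5)      # scale used at quantization time (hypothetical)
k_cached = k / k_scale           # what a scaled cache would hold
k_descale = k_scale              # descale recovers the original values

scores_ref = q @ k.T
scores_descaled = q @ (k_cached * k_descale).T
print(torch.allclose(scores_ref, scores_descaled, atol=1e-5))  # True
```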
python/sglang/srt/layers/attention/flashattention_interface.py

@@ -42,8 +42,8 @@ def flash_attn_with_kvcache(
 ):
     return flash_attn_with_kvcache_interface(
         q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
-        k_cache=k_cache,
-        v_cache=v_cache,
+        k_cache=k_cache.view(q.dtype),
+        v_cache=v_cache.view(q.dtype),
         block_table=page_table,
         cache_seqlens=cache_seqlens,
         softmax_scale=softmax_scale,

@@ -83,8 +83,8 @@ def flash_attn_varlen_func(
 ):
     return flash_attn_varlen_func_interface(
         q=q,
-        k=k,
-        v=v,
+        k=k.view(q.dtype),
+        v=v.view(q.dtype),
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,

@@ -92,4 +92,5 @@ def flash_attn_varlen_func(
         softmax_scale=softmax_scale,
         causal=causal,
         return_attn_probs=return_softmax_lse,
+        softcap=softcap,
     )
\ No newline at end of file
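The wrapper now reinterprets the key/value tensors as the query dtype via `Tensor.view(dtype)` before calling the underlying fa3 interface; this is a zero-copy bit reinterpretation and is effectively a no-op when the buffers are already stored in the query's dtype. For reference, a minimal standalone illustration of `torch.Tensor.view(dtype)` semantics (unrelated to the backend's actual cache layout):

```python
import torch

x = torch.tensor([[1.0, -2.0], [3.0, 4.0]], dtype=torch.float32)

# Same element width: pure bit reinterpretation, same shape, no data copy.
as_int = x.view(torch.int32)
print(as_int.shape)               # torch.Size([2, 2])

# Different element width: the last dimension is rescaled by the size ratio,
# e.g. 2-byte float16 viewed as 1-byte uint8 doubles the trailing dimension.
h = torch.zeros(2, 4, dtype=torch.float16)
print(h.view(torch.uint8).shape)  # torch.Size([2, 8])
```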
python/sglang/srt/managers/schedule_batch.py

@@ -1193,6 +1193,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens = seq_lens_tensor
         self.seq_lens_cpu = seq_lens_cpu
         self.extend_num_tokens = extend_num_tokens
+        self.loc_tensor = torch.tensor([-1], device=self.device)

         # Allocate memory
         out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
python/sglang/srt/mem_cache/common.py

@@ -356,7 +356,8 @@ def alloc_for_extend(
     else:
         # Paged allocation - build last_loc
         last_loc = [
-            (t[-1:] if len(t) > 0 else torch.tensor([-1], device=batch.device))
+            # (t[-1:] if len(t) > 0 else torch.tensor([-1], device=batch.device))
+            (t[-1:] if len(t) > 0 else batch.loc_tensor)
             for t in prefix_tensors
         ]
         out_cache_loc = alloc_paged_token_slots_extend(
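Taken together, the last two files hoist the `[-1]` placeholder out of the per-request list comprehension: `ScheduleBatch` allocates a single `loc_tensor` sentinel once, and `alloc_for_extend` reuses it for requests with an empty prefix instead of constructing a fresh one-element tensor each time. A standalone sketch of the pattern with hypothetical names (`sentinel`, `last_locs_*`), not the scheduler's real data structures:

```python
import torch

device = "cpu"  # stand-in for the batch's device

# Before: a fresh one-element tensor is created for every empty prefix.
def last_locs_per_item(prefix_tensors):
    return [
        (t[-1:] if len(t) > 0 else torch.tensor([-1], device=device))
        for t in prefix_tensors
    ]

# After: one sentinel is allocated up front and shared, avoiding a small
# allocation per request on every batch.
sentinel = torch.tensor([-1], device=device)

def last_locs_shared(prefix_tensors):
    return [(t[-1:] if len(t) > 0 else sentinel) for t in prefix_tensors]

prefixes = [torch.tensor([3, 7]), torch.tensor([], dtype=torch.long), torch.tensor([5])]
print([t.tolist() for t in last_locs_shared(prefixes)])  # [[7], [-1], [5]]
```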