Commit 46da9556, authored Nov 04, 2025 by maxiao1

Merge branch 'v0.5.4_dev_linhai' into 'v0.5.4_dev'

V0.5.4 dev linhai

See merge request OpenDAS/sglang!9

Parents: a9e0e668, ee775772
Showing 10 changed files with 602 additions and 6 deletions (+602 −6)
Changed files:

python/sglang/srt/layers/attention/attention_registry.py                 +5 −1
python/sglang/srt/layers/attention/dcu_mla_backend.py                    +484 −0
python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py  +2 −1
python/sglang/srt/layers/attention/flashattention_backend.py             +2 −1
python/sglang/srt/layers/attention/flashattention_interface.py           +94 −0
python/sglang/srt/layers/attention/nsa_backend.py                        +2 −1
python/sglang/srt/layers/attention/xpu_backend.py                        +2 −1
python/sglang/srt/model_executor/model_runner.py                         +1 −0
python/sglang/srt/models/deepseek_v2.py                                  +5 −0
python/sglang/srt/server_args.py                                         +5 −1

python/sglang/srt/layers/attention/attention_registry.py

@@ -99,7 +99,6 @@ def create_triton_backend(runner):
    return TritonAttnBackend(runner)


@register_attention_backend("torch_native")
def create_torch_native_backend(runner):
    from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend

@@ -120,6 +119,11 @@ def create_flashmla_backend(runner):
    return FlashMLABackend(runner)


@register_attention_backend("dcu_mla")
def create_dcu_mla_backend(runner):
    from sglang.srt.layers.attention.dcu_mla_backend import DCUMLABackend

    return DCUMLABackend(runner)


@register_attention_backend("fa3")
def create_flashattention_v3_backend(runner):
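
For readers new to this registry, the hunk above uses a decorator-based factory pattern: each backend name maps to a factory that builds the backend from the model runner, and "dcu_mla" is simply a new entry in that map. The snippet below is a minimal, self-contained sketch of that pattern, not sglang's actual implementation; _BACKEND_FACTORIES, create_backend, and the demo factory are hypothetical names.

# Minimal sketch of a decorator-based backend registry (hypothetical names).
from typing import Any, Callable, Dict

_BACKEND_FACTORIES: Dict[str, Callable[[Any], Any]] = {}


def register_attention_backend(name: str):
    def decorator(factory: Callable[[Any], Any]):
        _BACKEND_FACTORIES[name] = factory  # remember the factory under `name`
        return factory

    return decorator


@register_attention_backend("dcu_mla_demo")
def create_demo_backend(runner):
    # A real factory would import and return DCUMLABackend(runner).
    return ("demo-backend", runner)


def create_backend(name: str, runner):
    # Resolve the factory registered under `name` and build the backend.
    return _BACKEND_FACTORIES[name](runner)


print(create_backend("dcu_mla_demo", runner=None))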

python/sglang/srt/layers/attention/dcu_mla_backend.py (new file, 0 → 100644)
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Tuple, Union

import torch
import triton

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

try:
    from flash_mla import (
        flash_mla_with_kvcache,
        flash_mla_with_kvcache_quantization,
        get_mla_metadata,
    )

    _has_flash_mla = True
except Exception:
    try:
        from vllm.attention.ops.flashmla import (
            flash_mla_with_kvcache,
            get_mla_metadata,
        )

        _has_flash_mla = False
    except Exception:
        raise ImportError(
            "Cannot import FlashMLA. Please run one of the following to use flashmla:\n"
            "    pip install flash-mla\n"
            "  or\n"
            "    pip install vllm"
        )

PAGE_SIZE = 64  # forced to 64

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInput


@dataclass
class VllmMLADecodeMetadata:
    flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    num_splits: Optional[torch.Tensor] = None
    block_kv_indices: Optional[torch.Tensor] = None


class DCUMLABackend(AttentionBackend):
    def __init__(
        self,
        model_runner: "ModelRunner",
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        kv_last_page_len_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__()

        if model_runner.server_args.page_size != PAGE_SIZE:
            raise ValueError(
                f"dcu_mla backend requires page_size={PAGE_SIZE}, "
                f"but got {model_runner.server_args.page_size}"
            )

        self.num_q_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.v_head_dim = model_runner.model_config.v_head_dim
        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim

        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.device = model_runner.device
        self.max_context_len = model_runner.model_config.context_len
        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

        self.forward_metadata: Union[VllmMLADecodeMetadata] = None

        self.skip_prefill = skip_prefill
        if not skip_prefill:
            # Use the triton backend for now; consider replacing it later.
            # from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
            # self.triton_backend = TritonAttnBackend(
            #     model_runner,
            #     skip_prefill=False,
            #     kv_indptr_buf=kv_indptr_buf,
            # )
            # Prefill now goes through flash attention instead.
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )

            self.flashattn_backend = FlashAttentionBackend(
                model_runner,
                skip_prefill=False,
            )

    def _build_decode_metadata(
        self, forward_batch: ForwardBatch, seq_lens: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
        bs = forward_batch.batch_size
        max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)

        # Paging scheme follows the official vllm blog post.
        block_kv_indices = torch.full(
            (bs, max_seqlen_pad),
            -1,
            dtype=torch.int32,
            device=seq_lens.device,
        )
        create_flashmla_kv_indices_triton[(bs,)](
            self.req_to_token,
            forward_batch.req_pool_indices,
            seq_lens,
            None,
            block_kv_indices,
            self.req_to_token.stride(0),
            max_seqlen_pad,
        )

        mla_metadata, num_splits = get_mla_metadata(
            seq_lens.to(torch.int32),
            self.num_q_heads,
            1,
        )
        return (mla_metadata, num_splits), num_splits, block_kv_indices

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode_or_idle():
            # Decode uses flashmla.
            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
                self._build_decode_metadata(forward_batch, forward_batch.seq_lens)
            )
            self.forward_metadata = VllmMLADecodeMetadata(
                mla_metadata, num_splits_t, block_kv_indices
            )
        elif forward_batch.forward_mode.is_target_verify():
            seq_lens = forward_batch.seq_lens + self.num_draft_tokens
            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
                self._build_decode_metadata(forward_batch, seq_lens)
            )
            self.forward_metadata = VllmMLADecodeMetadata(
                mla_metadata, num_splits_t, block_kv_indices
            )
        else:
            # Prefill/extend used the triton backend -> switched to flash attn.
            if not self.skip_prefill:
                # self.triton_backend.init_forward_metadata(forward_batch)
                self.flashattn_backend.init_forward_metadata(forward_batch)

    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        block_kv_indices: Optional[torch.Tensor] = None,
    ):
        if block_kv_indices is None:
            cuda_graph_kv_indices = torch.full(
                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
                1,
                dtype=torch.int32,
                device="cuda",
            )
        else:
            cuda_graph_kv_indices = block_kv_indices

        if self.num_draft_tokens:
            mla_metadata, num_splits = get_mla_metadata(
                torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
                self.num_draft_tokens * self.num_q_heads,
                1,
            )
        else:
            mla_metadata, num_splits = get_mla_metadata(
                torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
                self.num_q_heads,
                1,
            )
        self.cuda_graph_mla_metadata = mla_metadata
        self.cuda_graph_num_splits = num_splits
        self.cuda_graph_kv_indices = cuda_graph_kv_indices

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional["SpecInput"],
    ):
        if forward_mode.is_decode_or_idle():
            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32),
                num_q_heads,
                1,
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata = VllmMLADecodeMetadata(
                self.cuda_graph_mla_metadata,
                self.cuda_graph_num_splits[: bs + 1],
                self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
            )
        elif forward_mode.is_target_verify():
            seq_lens = seq_lens + self.num_draft_tokens
            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32),
                self.num_draft_tokens * self.num_q_heads,
                1,
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata = VllmMLADecodeMetadata(
                self.cuda_graph_mla_metadata,
                self.cuda_graph_num_splits[: bs + 1],
                self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
            )
        else:
            if not self.skip_prefill:
                # self.triton_backend.init_forward_metadata_capture_cuda_graph(
                #     bs,
                #     num_tokens,
                #     req_pool_indices,
                #     seq_lens,
                #     encoder_lens,
                #     forward_mode,
                #     spec_info,
                # )
                self.flashattn_backend.init_forward_metadata_capture_cuda_graph(
                    bs,
                    num_tokens,
                    req_pool_indices,
                    seq_lens,
                    encoder_lens,
                    forward_mode,
                    spec_info,
                )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional["SpecInput"],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        if forward_mode.is_decode_or_idle():
            assert seq_lens_cpu is not None

            seq_lens = seq_lens[:bs]
            seq_lens_cpu = seq_lens_cpu[:bs]
            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices[:bs],
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32),
                num_q_heads,
                1,
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)

            self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
            self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
                :bs, :max_seqlen_pad
            ]
        elif forward_mode.is_target_verify():
            seq_lens = seq_lens[:bs] + self.num_draft_tokens
            seq_lens_cpu = seq_lens_cpu[:bs] + self.num_draft_tokens
            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices[:bs],
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32),
                self.num_draft_tokens * self.num_q_heads,
                1,
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)

            self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
            self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
                :bs, :max_seqlen_pad
            ]
        else:
            if not self.skip_prefill:
                # self.triton_backend.init_forward_metadata_replay_cuda_graph(
                #     bs,
                #     req_pool_indices,
                #     seq_lens,
                #     seq_lens_sum,
                #     encoder_lens,
                #     forward_mode,
                #     spec_info,
                #     seq_lens_cpu,
                # )
                self.flashattn_backend.init_forward_metadata_replay_cuda_graph(
                    bs,
                    req_pool_indices,
                    seq_lens,
                    seq_lens_sum,
                    encoder_lens,
                    forward_mode,
                    spec_info,
                    seq_lens_cpu,
                )

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def _call_decode(
        self,
        reshape_q: torch.Tensor,
        k_cache_reshaped: torch.Tensor,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        scaling: float,
    ):
        o, _ = flash_mla_with_kvcache(
            q=reshape_q,
            k_cache=k_cache_reshaped,
            block_table=block_table,
            cache_seqlens=cache_seqlens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
            num_splits=self.forward_metadata.num_splits,
            softmax_scale=scaling,
            causal=True,
        )
        return o

    def _call_fp8_decode(
        self,
        reshape_q: torch.Tensor,
        k_cache_reshaped: torch.Tensor,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        scaling: float,
    ):
        assert _has_flash_mla, "FP8 KV cache requires the flash_mla package"
        o, _ = flash_mla_with_kvcache_quantization(
            q=reshape_q,
            k_cache=k_cache_reshaped,
            block_table=block_table,
            cache_seqlens=cache_seqlens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
            num_splits=self.forward_metadata.num_splits,
            softmax_scale=scaling,
            causal=True,
            is_fp8_kvcache=True,
        )
        return o

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: "RadixAttention",
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        cache_loc = forward_batch.out_cache_loc

        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    cache_loc,
                    k,
                    v,
                )

        bs = forward_batch.batch_size
        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
        k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)

        if self.data_type in (
            getattr(torch, "float8_e4m3fn", None),
            getattr(torch, "float8_e4m3fnuz", None),
            getattr(torch, "float8_e5m2", None),
            getattr(torch, "float8_e5m2fnuz", None),
        ):
            o = self._call_fp8_decode(
                reshape_q,
                k_cache_reshaped,
                self.forward_metadata.block_kv_indices[:bs],
                forward_batch.seq_lens.to(torch.int32),
                layer.scaling,
            )
        else:
            o = self._call_decode(
                reshape_q,
                k_cache_reshaped,
                self.forward_metadata.block_kv_indices[:bs],
                forward_batch.seq_lens.to(torch.int32),
                layer.scaling,
            )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: "RadixAttention",
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        sinks=None,
    ):
        if (
            forward_batch.forward_mode == ForwardMode.EXTEND
            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
        ):
            # flash_attn does not support fp8, so extend cannot run with an fp8 KV cache.
            if not self.skip_prefill:
                # return self.triton_backend.forward_extend(
                #     q, k, v, layer, forward_batch, save_kv_cache, sinks
                # )
                return self.flashattn_backend.forward_extend(
                    q, k, v, layer, forward_batch, save_kv_cache, sinks
                )
            else:
                raise RuntimeError("skip prefill but use forward_extend")

        cache_loc = forward_batch.out_cache_loc
        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

        bs = forward_batch.batch_size
        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
        k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)

        if self.data_type in (
            getattr(torch, "float8_e4m3fn", None),
            getattr(torch, "float8_e4m3fnuz", None),
            getattr(torch, "float8_e5m2", None),
            getattr(torch, "float8_e5m2fnuz", None),
        ):
            o = self._call_fp8_decode(
                reshape_q,
                k_cache_reshaped,
                self.forward_metadata.block_kv_indices[:bs],
                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
                layer.scaling,
            )
        else:
            o = self._call_decode(
                reshape_q,
                k_cache_reshaped,
                self.forward_metadata.block_kv_indices[:bs],
                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
                layer.scaling,
            )
        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
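
A note on the PAGE_SIZE = 64 requirement above: the decode path addresses the KV cache in fixed blocks of 64 tokens, so the block table built in _build_decode_metadata has one row per request and ceil(max_seq_len / 64) columns. The snippet below is a standalone sketch of that sizing arithmetic with made-up sequence lengths; it does not call into sglang or triton.

# Standalone sketch of the block-table sizing in _build_decode_metadata
# (illustrative values only).
import math

import torch

PAGE_SIZE = 64
seq_lens = torch.tensor([5, 130, 64], dtype=torch.int32)  # example lengths
bs = seq_lens.numel()

# Same rounding as triton.cdiv(seq_lens.max().item(), PAGE_SIZE).
max_seqlen_pad = math.ceil(seq_lens.max().item() / PAGE_SIZE)  # 130 tokens -> 3 blocks

# One row per request, initialized to -1 for unused block slots; the triton
# kernel then fills in the actual KV block indices for each request.
block_kv_indices = torch.full((bs, max_seqlen_pad), -1, dtype=torch.int32)
print(block_kv_indices.shape)  # torch.Size([3, 3])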

python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py

@@ -9,7 +9,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from sgl_kernel.sparse_flash_attn import (
    convert_vertical_slash_indexes,
    convert_vertical_slash_indexes_mergehead,

python/sglang/srt/layers/attention/flashattention_backend.py

@@ -20,7 +20,8 @@ if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner

from sgl_kernel import merge_state_v2
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache


@dataclass

python/sglang/srt/layers/attention/flashattention_interface.py (new file, 0 → 100644)

from flash_attn import (
    flash_attn_varlen_func as flash_attn_varlen_func_interface,
    flash_attn_with_kvcache as flash_attn_with_kvcache_interface,
)
from typing import Optional, Union

import torch


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk: Optional[int] = None,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    return flash_attn_with_kvcache_interface(
        q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
        k_cache=k_cache,
        v_cache=v_cache,
        block_table=page_table,
        cache_seqlens=cache_seqlens,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        return_softmax_lse=return_softmax_lse,
        num_splits=num_splits,
    )


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=None,
    max_seqlen_k=None,
    seqused_q=None,
    seqused_k=None,
    page_table=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    return flash_attn_varlen_func_interface(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_q,
        softmax_scale=softmax_scale,
        causal=causal,
    )
\ No newline at end of file
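
The point of this new module is that the backends touched in this commit keep their call sites unchanged while the kernels now come from the upstream flash_attn package rather than sgl_kernel.flash_attn; the wrappers accept the sgl_kernel-style signatures but forward only a subset of the arguments. The import switch looks like this (an illustration of the change, not additional repository code):

# Old import, now commented out in the touched backends:
#   from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
# New import via the wrapper module:
from sglang.srt.layers.attention.flashattention_interface import (
    flash_attn_varlen_func,
    flash_attn_with_kvcache,
)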

python/sglang/srt/layers/attention/nsa_backend.py

@@ -45,7 +45,8 @@ if _is_hip:
        "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
    )
else:
    from sgl_kernel.flash_attn import flash_attn_with_kvcache
    # from sgl_kernel.flash_attn import flash_attn_with_kvcache
    from sglang.srt.layers.attention.flashattention_interface import flash_attn_with_kvcache


@dataclass(frozen=True)

python/sglang/srt/layers/attention/xpu_backend.py

@@ -20,7 +20,8 @@ if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner

from sgl_kernel import merge_state_v2
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache


class XPUAttentionBackend(AttentionBackend):

python/sglang/srt/model_executor/model_runner.py

@@ -165,6 +165,7 @@ MLA_ATTENTION_BACKENDS = [
    "triton",
    "flashmla",
    "cutlass_mla",
    "dcu_mla",
    "trtllm_mla",
    "ascend",
    "nsa",
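
Adding "dcu_mla" to MLA_ATTENTION_BACKENDS is what lets the model runner treat it like the other MLA-style backends (flashmla, cutlass_mla, trtllm_mla, and so on). The check below is a hypothetical illustration of how such a list is typically consulted; it is not the actual model_runner code, and the list is truncated to the entries visible in this hunk.

# Hypothetical membership check; the list is truncated to the entries shown above.
MLA_ATTENTION_BACKENDS = [
    "triton",
    "flashmla",
    "cutlass_mla",
    "dcu_mla",
    "trtllm_mla",
    "ascend",
    "nsa",
]


def uses_mla_backend(attention_backend: str) -> bool:
    # Whatever MLA-specific handling the runner performs is gated on this list.
    return attention_backend in MLA_ATTENTION_BACKENDS


print(uses_mla_backend("dcu_mla"))  # True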

python/sglang/srt/models/deepseek_v2.py

@@ -342,6 +342,10 @@ def handle_attention_flashmla(attn, forward_batch):
    return _handle_attention_backend(attn, forward_batch, "flashmla")


def handle_attention_dcu_mla(attn, forward_batch):
    return _handle_attention_backend(attn, forward_batch, "dcu_mla")


def handle_attention_cutlass_mla(attn, forward_batch):
    return _handle_attention_backend(attn, forward_batch, "cutlass_mla")

@@ -3577,6 +3581,7 @@ AttentionBackendRegistry.register("ascend", handle_attention_ascend)
AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
AttentionBackendRegistry.register("fa3", handle_attention_fa3)
AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
AttentionBackendRegistry.register("dcu_mla", handle_attention_dcu_mla)
AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
AttentionBackendRegistry.register("fa4", handle_attention_fa4)
AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)

python/sglang/srt/server_args.py

@@ -102,6 +102,8 @@ ATTENTION_BACKEND_CHOICES = [
    "torch_native",
    "flex_attention",
    "nsa",
    # transplant from vllm
    "dcu_mla",
    # NVIDIA specific
    "cutlass_mla",
    "fa3",

@@ -1077,9 +1079,11 @@ class ServerArgs:
        if (
            self.attention_backend == "flashmla"
            or self.decode_attention_backend == "flashmla"
            or self.attention_backend == "dcu_mla"
            or self.decode_attention_backend == "dcu_mla"
        ):
            logger.warning(
                "FlashMLA only supports a page_size of 64, change page_size to 64."
                "FlashMLA/DCU MLA only supports a page_size of 64, change page_size to 64."
            )
            self.page_size = 64
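
The server_args hunk extends an existing FlashMLA-only check so that selecting dcu_mla, as either the main or the decode attention backend, also forces page_size to 64, matching the PAGE_SIZE check in DCUMLABackend.__init__. Below is a standalone sketch of that normalization, with a hypothetical free function standing in for the logic that actually lives inside ServerArgs:

# Standalone sketch of the page-size normalization added to ServerArgs
# (normalize_page_size is a hypothetical name used for illustration).
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


def normalize_page_size(attention_backend, decode_attention_backend, page_size):
    mla_64_only = ("flashmla", "dcu_mla")
    if attention_backend in mla_64_only or decode_attention_backend in mla_64_only:
        if page_size != 64:
            logger.warning(
                "FlashMLA/DCU MLA only supports a page_size of 64, "
                "changing page_size to 64."
            )
        return 64
    return page_size


print(normalize_page_size("dcu_mla", None, 16))  # -> 64

In practice this means a launch that selects dcu_mla through the attention-backend argument does not also need to set the page size explicitly; it is normalized to 64 automatically.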