change / sglang / Commits

Commit ee775772: "V0.5.4 dev linhai"
Authored Nov 04, 2025 by linhai1; committed Nov 04, 2025 by maxiao1
Parent: a9e0e668
Showing 10 changed files with 602 additions and 6 deletions (+602 -6)
python/sglang/srt/layers/attention/attention_registry.py                   +5   -1
python/sglang/srt/layers/attention/dcu_mla_backend.py                      +484 -0
python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py    +2   -1
python/sglang/srt/layers/attention/flashattention_backend.py               +2   -1
python/sglang/srt/layers/attention/flashattention_interface.py             +94  -0
python/sglang/srt/layers/attention/nsa_backend.py                          +2   -1
python/sglang/srt/layers/attention/xpu_backend.py                          +2   -1
python/sglang/srt/model_executor/model_runner.py                           +1   -0
python/sglang/srt/models/deepseek_v2.py                                    +5   -0
python/sglang/srt/server_args.py                                           +5   -1
python/sglang/srt/layers/attention/attention_registry.py

@@ -99,7 +99,6 @@ def create_triton_backend(runner):
     return TritonAttnBackend(runner)


 @register_attention_backend("torch_native")
 def create_torch_native_backend(runner):
     from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend

@@ -120,6 +119,11 @@ def create_flashmla_backend(runner):
     return FlashMLABackend(runner)


+@register_attention_backend("dcu_mla")
+def create_dcu_mla_backend(runner):
+    from sglang.srt.layers.attention.dcu_mla_backend import DCUMLABackend
+
+    return DCUMLABackend(runner)


 @register_attention_backend("fa3")
 def create_flashattention_v3_backend(runner):
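For context, the hunk above follows the decorator-based factory registry that attention_registry.py already uses for the other backends: a backend name is mapped to a factory, and the model runner later resolves the configured name to a factory call. Below is a minimal, illustrative sketch of that pattern; the dictionary and lookup names are assumptions for illustration, not the real module internals.

    # Illustrative sketch only; the real registry lives in attention_registry.py.
    ATTENTION_BACKENDS = {}  # assumed storage name, chosen for illustration

    def register_attention_backend(name):
        def decorator(factory):
            ATTENTION_BACKENDS[name] = factory
            return factory
        return decorator

    @register_attention_backend("dcu_mla")
    def create_dcu_mla_backend(runner):
        # Import inside the factory so the DCU-specific backend is only
        # loaded when this backend is actually selected.
        from sglang.srt.layers.attention.dcu_mla_backend import DCUMLABackend
        return DCUMLABackend(runner)

    # The runner would then resolve the configured name, e.g.:
    # backend = ATTENTION_BACKENDS["dcu_mla"](model_runner)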
python/sglang/srt/layers/attention/dcu_mla_backend.py (new file, mode 100644)

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Tuple, Union

import torch
import triton

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

try:
    from flash_mla import (
        flash_mla_with_kvcache,
        flash_mla_with_kvcache_quantization,
        get_mla_metadata,
    )

    _has_flash_mla = True
except Exception:
    try:
        from vllm.attention.ops.flashmla import (
            flash_mla_with_kvcache,
            get_mla_metadata,
        )

        _has_flash_mla = False
    except Exception:
        raise ImportError(
            "Cannot import FlashMLA. Please do one of the following to use flashmla:\n"
            "    pip install flash-mla\n"
            "  or\n"
            "    pip install vllm"
        )

PAGE_SIZE = 64  # forced to 64

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInput


@dataclass
class VllmMLADecodeMetadata:
    flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    num_splits: Optional[torch.Tensor] = None
    block_kv_indices: Optional[torch.Tensor] = None


class DCUMLABackend(AttentionBackend):
    def __init__(
        self,
        model_runner: "ModelRunner",
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        kv_last_page_len_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        if model_runner.server_args.page_size != PAGE_SIZE:
            raise ValueError(
                f"dcu_mla backend requires page_size={PAGE_SIZE}, "
                f"but got {model_runner.server_args.page_size}"
            )
        self.num_q_heads = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
        self.v_head_dim = model_runner.model_config.v_head_dim
        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim

        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.device = model_runner.device
        self.max_context_len = model_runner.model_config.context_len
        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

        self.forward_metadata: Union[VllmMLADecodeMetadata] = None

        self.skip_prefill = skip_prefill
        if not skip_prefill:
            # Previously used the triton backend; may replace this later.
            # from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
            # self.triton_backend = TritonAttnBackend(
            #     model_runner,
            #     skip_prefill=False,
            #     kv_indptr_buf=kv_indptr_buf,
            # )
            # Prefill now uses flash attention instead.
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )

            self.flashattn_backend = FlashAttentionBackend(
                model_runner,
                skip_prefill=False,
            )

    def _build_decode_metadata(
        self, forward_batch: ForwardBatch, seq_lens: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
        bs = forward_batch.batch_size
        max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)

        # Paging scheme follows the official vLLM blog post.
        block_kv_indices = torch.full(
            (bs, max_seqlen_pad),
            -1,
            dtype=torch.int32,
            device=seq_lens.device,
        )
        create_flashmla_kv_indices_triton[(bs,)](
            self.req_to_token,
            forward_batch.req_pool_indices,
            seq_lens,
            None,
            block_kv_indices,
            self.req_to_token.stride(0),
            max_seqlen_pad,
        )

        mla_metadata, num_splits = get_mla_metadata(
            seq_lens.to(torch.int32),
            self.num_q_heads,
            1,
        )
        return (mla_metadata, num_splits), num_splits, block_kv_indices

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode_or_idle():
            # Decode uses flashmla.
            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
                self._build_decode_metadata(forward_batch, forward_batch.seq_lens)
            )
            self.forward_metadata = VllmMLADecodeMetadata(
                mla_metadata, num_splits_t, block_kv_indices
            )
        elif forward_batch.forward_mode.is_target_verify():
            seq_lens = forward_batch.seq_lens + self.num_draft_tokens
            (mla_metadata, num_splits), num_splits_t, block_kv_indices = (
                self._build_decode_metadata(forward_batch, seq_lens)
            )
            self.forward_metadata = VllmMLADecodeMetadata(
                mla_metadata, num_splits_t, block_kv_indices
            )
        else:
            # Prefill/extend previously used the triton backend; now uses flash attn.
            if not self.skip_prefill:
                # self.triton_backend.init_forward_metadata(forward_batch)
                self.flashattn_backend.init_forward_metadata(forward_batch)

    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        block_kv_indices: Optional[torch.Tensor] = None,
    ):
        if block_kv_indices is None:
            cuda_graph_kv_indices = torch.full(
                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
                1,
                dtype=torch.int32,
                device="cuda",
            )
        else:
            cuda_graph_kv_indices = block_kv_indices

        if self.num_draft_tokens:
            mla_metadata, num_splits = get_mla_metadata(
                torch.ones(
                    max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
                ),
                self.num_draft_tokens * self.num_q_heads,
                1,
            )
        else:
            mla_metadata, num_splits = get_mla_metadata(
                torch.ones(
                    max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device
                ),
                self.num_q_heads,
                1,
            )
        self.cuda_graph_mla_metadata = mla_metadata
        self.cuda_graph_num_splits = num_splits
        self.cuda_graph_kv_indices = cuda_graph_kv_indices

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional["SpecInput"],
    ):
        if forward_mode.is_decode_or_idle():
            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32),
                num_q_heads,
                1,
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata = VllmMLADecodeMetadata(
                self.cuda_graph_mla_metadata,
                self.cuda_graph_num_splits[: bs + 1],
                self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
            )
        elif forward_mode.is_target_verify():
            seq_lens = seq_lens + self.num_draft_tokens
            max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices,
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32),
                self.num_draft_tokens * self.num_q_heads,
                1,
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata = VllmMLADecodeMetadata(
                self.cuda_graph_mla_metadata,
                self.cuda_graph_num_splits[: bs + 1],
                self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
            )
        else:
            if not self.skip_prefill:
                # self.triton_backend.init_forward_metadata_capture_cuda_graph(
                #     bs,
                #     num_tokens,
                #     req_pool_indices,
                #     seq_lens,
                #     encoder_lens,
                #     forward_mode,
                #     spec_info,
                # )
                self.flashattn_backend.init_forward_metadata_capture_cuda_graph(
                    bs,
                    num_tokens,
                    req_pool_indices,
                    seq_lens,
                    encoder_lens,
                    forward_mode,
                    spec_info,
                )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional["SpecInput"],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        if forward_mode.is_decode_or_idle():
            assert seq_lens_cpu is not None
            seq_lens = seq_lens[:bs]
            seq_lens_cpu = seq_lens_cpu[:bs]
            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices[:bs],
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32),
                num_q_heads,
                1,
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
            self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
                :bs, :max_seqlen_pad
            ]
        elif forward_mode.is_target_verify():
            seq_lens = seq_lens[:bs] + self.num_draft_tokens
            seq_lens_cpu = seq_lens_cpu[:bs] + self.num_draft_tokens
            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
            create_flashmla_kv_indices_triton[(bs,)](
                self.req_to_token,
                req_pool_indices[:bs],
                seq_lens,
                None,
                self.cuda_graph_kv_indices,
                self.req_to_token.stride(0),
                self.cuda_graph_kv_indices.stride(0),
            )
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32),
                self.num_draft_tokens * self.num_q_heads,
                1,
            )
            self.cuda_graph_mla_metadata.copy_(mla_metadata)
            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
            self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
            self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
                :bs, :max_seqlen_pad
            ]
        else:
            if not self.skip_prefill:
                # self.triton_backend.init_forward_metadata_replay_cuda_graph(
                #     bs,
                #     req_pool_indices,
                #     seq_lens,
                #     seq_lens_sum,
                #     encoder_lens,
                #     forward_mode,
                #     spec_info,
                #     seq_lens_cpu,
                # )
                self.flashattn_backend.init_forward_metadata_replay_cuda_graph(
                    bs,
                    req_pool_indices,
                    seq_lens,
                    seq_lens_sum,
                    encoder_lens,
                    forward_mode,
                    spec_info,
                    seq_lens_cpu,
                )

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def _call_decode(
        self,
        reshape_q: torch.Tensor,
        k_cache_reshaped: torch.Tensor,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        scaling: float,
    ):
        o, _ = flash_mla_with_kvcache(
            q=reshape_q,
            k_cache=k_cache_reshaped,
            block_table=block_table,
            cache_seqlens=cache_seqlens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
            num_splits=self.forward_metadata.num_splits,
            softmax_scale=scaling,
            causal=True,
        )
        return o

    def _call_fp8_decode(
        self,
        reshape_q: torch.Tensor,
        k_cache_reshaped: torch.Tensor,
        block_table: torch.Tensor,
        cache_seqlens: torch.Tensor,
        scaling: float,
    ):
        assert _has_flash_mla, "FP8 KV cache requires the flash_mla package"
        o, _ = flash_mla_with_kvcache_quantization(
            q=reshape_q,
            k_cache=k_cache_reshaped,
            block_table=block_table,
            cache_seqlens=cache_seqlens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
            num_splits=self.forward_metadata.num_splits,
            softmax_scale=scaling,
            causal=True,
            is_fp8_kvcache=True,
        )
        return o

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: "RadixAttention",
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        cache_loc = forward_batch.out_cache_loc

        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer,
                    cache_loc,
                    k,
                    v,
                )

        bs = forward_batch.batch_size
        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
        k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)

        if self.data_type in (
            getattr(torch, "float8_e4m3fn", None),
            getattr(torch, "float8_e4m3fnuz", None),
            getattr(torch, "float8_e5m2", None),
            getattr(torch, "float8_e5m2fnuz", None),
        ):
            o = self._call_fp8_decode(
                reshape_q,
                k_cache_reshaped,
                self.forward_metadata.block_kv_indices[:bs],
                forward_batch.seq_lens.to(torch.int32),
                layer.scaling,
            )
        else:
            o = self._call_decode(
                reshape_q,
                k_cache_reshaped,
                self.forward_metadata.block_kv_indices[:bs],
                forward_batch.seq_lens.to(torch.int32),
                layer.scaling,
            )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: "RadixAttention",
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        sinks=None,
    ):
        if (
            forward_batch.forward_mode == ForwardMode.EXTEND
            or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
        ):
            # flash_attn does not support fp8; extend cannot run correctly with fp8.
            if not self.skip_prefill:
                # return self.triton_backend.forward_extend(
                #     q, k, v, layer, forward_batch, save_kv_cache, sinks
                # )
                return self.flashattn_backend.forward_extend(
                    q, k, v, layer, forward_batch, save_kv_cache, sinks
                )
            else:
                raise RuntimeError("skip prefill but use forward_extend")

        cache_loc = forward_batch.out_cache_loc
        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

        bs = forward_batch.batch_size
        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
        k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)

        if self.data_type in (
            getattr(torch, "float8_e4m3fn", None),
            getattr(torch, "float8_e4m3fnuz", None),
            getattr(torch, "float8_e5m2", None),
            getattr(torch, "float8_e5m2fnuz", None),
        ):
            o = self._call_fp8_decode(
                reshape_q,
                k_cache_reshaped,
                self.forward_metadata.block_kv_indices[:bs],
                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
                layer.scaling,
            )
        else:
            o = self._call_decode(
                reshape_q,
                k_cache_reshaped,
                self.forward_metadata.block_kv_indices[:bs],
                (forward_batch.seq_lens + self.num_draft_tokens).to(torch.int32),
                layer.scaling,
            )

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
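Two shape conventions in the new backend are worth noting: the KV cache is viewed as (num_pages, PAGE_SIZE, 1, kv_lora_rank + qk_rope_head_dim), and the per-request page table has one row per request and one column per 64-token page, sized from the longest sequence in the batch. Below is a small standalone sketch of that page-table sizing (pure PyTorch, no Triton; the helper name is illustrative).

    import torch

    PAGE_SIZE = 64  # the backend forces page_size=64

    def page_table_shape(seq_lens: torch.Tensor):
        # Mirrors triton.cdiv(seq_lens.max().item(), PAGE_SIZE) in
        # _build_decode_metadata: one row per request, one column per page.
        bs = seq_lens.numel()
        max_seqlen_pad = (int(seq_lens.max().item()) + PAGE_SIZE - 1) // PAGE_SIZE
        return bs, max_seqlen_pad

    # Three requests with 1, 64, and 130 cached tokens need a (3, 3) page table;
    # unused slots are filled with -1 by create_flashmla_kv_indices_triton.
    print(page_table_shape(torch.tensor([1, 64, 130])))  # (3, 3)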
python/sglang/srt/layers/attention/dual_chunk_flashattention_backend.py

@@ -9,7 +9,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

 import torch
 import torch.nn.functional as F
-from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
 from sgl_kernel.sparse_flash_attn import (
     convert_vertical_slash_indexes,
     convert_vertical_slash_indexes_mergehead,
python/sglang/srt/layers/attention/flashattention_backend.py

@@ -20,7 +20,8 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

 from sgl_kernel import merge_state_v2
-from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache


 @dataclass
python/sglang/srt/layers/attention/flashattention_interface.py (new file, mode 100644)

from flash_attn import (
    flash_attn_varlen_func as flash_attn_varlen_func_interface,
    flash_attn_with_kvcache as flash_attn_with_kvcache_interface,
)
from typing import Optional, Union

import torch


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk: Optional[int] = None,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    return flash_attn_with_kvcache_interface(
        q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
        k_cache=k_cache,
        v_cache=v_cache,
        block_table=page_table,
        cache_seqlens=cache_seqlens,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        return_softmax_lse=return_softmax_lse,
        num_splits=num_splits,
    )


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=None,
    max_seqlen_k=None,
    seqused_q=None,
    seqused_k=None,
    page_table=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    return flash_attn_varlen_func_interface(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_q,
        softmax_scale=softmax_scale,
        causal=causal,
    )
\ No newline at end of file
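The wrapper above is essentially an argument adapter: the sgl_kernel-style page_table keyword is forwarded to upstream flash_attn as block_table, extras that the upstream API does not take (sinks, descales, attention_chunk, and so on) are accepted but dropped, and the packed query of shape (total_tokens, num_heads, head_dim) is reshaped into the (batch, max_seqlen_q, num_heads, head_dim) layout that flash_attn_with_kvcache expects. A shape-only illustration of that reshape, with made-up sizes and no kernel call:

    import torch

    # 4 requests, each contributing max_seqlen_q = 2 query tokens (packed layout).
    total_tokens, num_heads, head_dim, max_seqlen_q = 8, 16, 128, 2
    q = torch.randn(total_tokens, num_heads, head_dim)

    # Same view the wrapper takes before calling flash_attn_with_kvcache_interface.
    q4d = q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1])
    print(q4d.shape)  # torch.Size([4, 2, 16, 128])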
python/sglang/srt/layers/attention/nsa_backend.py

@@ -45,7 +45,8 @@ if _is_hip:
         "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
     )
 else:
-    from sgl_kernel.flash_attn import flash_attn_with_kvcache
+    # from sgl_kernel.flash_attn import flash_attn_with_kvcache
+    from sglang.srt.layers.attention.flashattention_interface import flash_attn_with_kvcache


 @dataclass(frozen=True)
python/sglang/srt/layers/attention/xpu_backend.py

@@ -20,7 +20,8 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

 from sgl_kernel import merge_state_v2
-from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache


 class XPUAttentionBackend(AttentionBackend):
python/sglang/srt/model_executor/model_runner.py

@@ -165,6 +165,7 @@ MLA_ATTENTION_BACKENDS = [
     "triton",
     "flashmla",
     "cutlass_mla",
+    "dcu_mla",
     "trtllm_mla",
     "ascend",
     "nsa",
python/sglang/srt/models/deepseek_v2.py

@@ -342,6 +342,10 @@ def handle_attention_flashmla(attn, forward_batch):
     return _handle_attention_backend(attn, forward_batch, "flashmla")


+def handle_attention_dcu_mla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "dcu_mla")
+
+
 def handle_attention_cutlass_mla(attn, forward_batch):
     return _handle_attention_backend(attn, forward_batch, "cutlass_mla")

@@ -3577,6 +3581,7 @@ AttentionBackendRegistry.register("ascend", handle_attention_ascend)
 AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
 AttentionBackendRegistry.register("fa3", handle_attention_fa3)
 AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
+AttentionBackendRegistry.register("dcu_mla", handle_attention_dcu_mla)
 AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
 AttentionBackendRegistry.register("fa4", handle_attention_fa4)
 AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
python/sglang/srt/server_args.py

@@ -102,6 +102,8 @@ ATTENTION_BACKEND_CHOICES = [
     "torch_native",
     "flex_attention",
     "nsa",
+    # transplanted from vllm
+    "dcu_mla",
     # NVIDIA specific
     "cutlass_mla",
     "fa3",

@@ -1077,9 +1079,11 @@ class ServerArgs:
         if (
             self.attention_backend == "flashmla"
             or self.decode_attention_backend == "flashmla"
+            or self.attention_backend == "dcu_mla"
+            or self.decode_attention_backend == "dcu_mla"
         ):
             logger.warning(
-                "FlashMLA only supports a page_size of 64, change page_size to 64."
+                "FlashMLA/DCU MLA only supports a page_size of 64, change page_size to 64."
             )
             self.page_size = 64
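The net effect of the server_args change: selecting dcu_mla as either the main or the decode attention backend now triggers the same page-size coercion as flashmla. A minimal sketch of that guard with stand-in names (the real logic lives inside ServerArgs):

    # Stand-in sketch, not the real ServerArgs method.
    PAGE_64_BACKENDS = {"flashmla", "dcu_mla"}

    def normalize_page_size(attention_backend, decode_attention_backend, page_size):
        if (
            attention_backend in PAGE_64_BACKENDS
            or decode_attention_backend in PAGE_64_BACKENDS
        ):
            # FlashMLA/DCU MLA kernels only work with 64-token pages.
            return 64
        return page_size

    print(normalize_page_size("dcu_mla", None, 1))  # 64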