sglang / Commits / f92b729d

Commit f92b729d (unverified)
authored Aug 26, 2025 by ZhengdQin, committed by GitHub Aug 25, 2025
parent e2e378ca

[new feat] ascend backend support fia fusion kernel (#8328)

Co-authored-by: Even Zhou <even.y.zhou@outlook.com>
Showing 9 changed files with 511 additions and 130 deletions (+511 -130)
.github/workflows/pr-test-npu.yml                        +3   -3
python/sglang/srt/layers/attention/ascend_backend.py   +218 -111
python/sglang/srt/layers/moe/topk.py                      +1   -1
python/sglang/srt/mem_cache/memory_pool.py               +73  -14
python/sglang/srt/models/deepseek_v2.py                   +9   -1
test/srt/ascend/test_ascend_mla_fia_w8a8int8.py         +103   -0
test/srt/ascend/test_ascend_mla_w8a8int8.py               +1   -0
test/srt/ascend/test_ascend_tp2_fia_bf16.py             +101   -0
test/srt/run_suite.py                                     +2   -0
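Note: the new FIA (fused infer attention) path is opt-in. The backend reads the ASCEND_USE_FIA environment variable through get_bool_env_var("ASCEND_USE_FIA", "False"), and both new tests export it before launching the server. A minimal sketch of that toggle in plain Python (the truthy-string handling below is an assumption, not the exact get_bool_env_var implementation):

import os

# Assumed approximation of get_bool_env_var("ASCEND_USE_FIA", "False"):
# any of "1"/"true"/"yes" (case-insensitive) enables the fused path.
def fia_enabled(default: str = "False") -> bool:
    return os.environ.get("ASCEND_USE_FIA", default).lower() in ("1", "true", "yes")

os.environ["ASCEND_USE_FIA"] = "true"   # what the new tests do before launching
print(fia_enabled())                    # True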
.github/workflows/pr-test-npu.yml

@@ -47,7 +47,7 @@ jobs:
           curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
       - name: Run test
-        timeout-minutes: 30
+        timeout-minutes: 60
         env:
           SGLANG_USE_MODELSCOPE: true
           SGLANG_IS_IN_CI: true

@@ -82,7 +82,7 @@ jobs:
           curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
       - name: Run test
-        timeout-minutes: 30
+        timeout-minutes: 90
         env:
           SGLANG_USE_MODELSCOPE: true
           SGLANG_IS_IN_CI: true

@@ -117,7 +117,7 @@ jobs:
           curl -o /tmp/test.jsonl -L https://gh-proxy.test.osinfra.cn/https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
       - name: Run test
-        timeout-minutes: 30
+        timeout-minutes: 60
         env:
           SGLANG_USE_MODELSCOPE: true
           SGLANG_IS_IN_CI: true
python/sglang/srt/layers/attention/ascend_backend.py
@@ -12,11 +12,16 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import get_bool_env_var

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

+import os
+import numpy as np

 @dataclass
 class ForwardMetadata:
@@ -54,7 +59,6 @@ class AscendAttnBackend(AttentionBackend):
         super().__init__()
         self.forward_metadata = None
         self.device = model_runner.device
-        self.gen_attention_mask(128, model_runner.dtype)
         self.page_size = model_runner.page_size
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         if self.use_mla:
@@ -65,6 +69,17 @@ class AscendAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.graph_mode = False
+        self.use_fia = get_bool_env_var("ASCEND_USE_FIA", "False")
+        if not self.use_fia:
+            self.gen_attention_mask(128, model_runner.dtype)
+        mask_length = 2048
+        self.fia_mask = ~torch.tril(
+            torch.ones(
+                (mask_length, mask_length),
+                dtype=torch.bool,
+                device=model_runner.device,
+            )
+        )

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
@@ -81,6 +96,9 @@ class AscendAttnBackend(AttentionBackend):
                 forward_batch.extend_seq_lens.cpu().int()
             )
             self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
+            self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
+                forward_batch.extend_seq_lens_cpu
+            )

         self.graph_mode = False
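Note: seq_lens_list_cumsum stores the running end offset of each request's extend tokens; the TND-layout fused call later consumes it as actual_seq_lengths / actual_seq_lengths_kv. A small sketch of the bookkeeping with made-up lengths:

import numpy as np

extend_seq_lens_cpu = [5, 3, 7]               # illustrative per-request extend lengths
seq_lens_list_cumsum = np.cumsum(extend_seq_lens_cpu)
print(seq_lens_list_cumsum)                   # [ 5  8 15] -> end offset of each request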
@@ -151,71 +169,89 @@ class AscendAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
-
-        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-        v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
-
-        if not self.use_mla:
-            query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
-            output = torch.empty(
-                (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
-                dtype=query.dtype,
-                device=query.device,
-            )
-            torch_npu._npu_flash_attention_qlens(
-                query=query,
-                key_cache=k_cache,
-                value_cache=v_cache,
-                mask=self.mask,
-                block_table=self.forward_metadata.block_tables,
-                seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
-                context_lens=self.forward_metadata.seq_lens_cpu_int,
-                scale_value=layer.scaling,
-                num_heads=layer.tp_q_head_num,
-                num_kv_heads=layer.tp_k_head_num,
-                out=output,
-            )
-            return output
-        else:
-            if layer.qk_head_dim != layer.v_head_dim:
-                o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
-            else:
-                o = torch.empty_like(q)
-
-            use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
-
-            q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
-            o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
-
-            causal = True
-            if (
-                layer.is_cross_attention
-                or layer.attn_type == AttentionType.ENCODER_ONLY
-            ):
-                causal = False
-
-            self.native_attn._run_sdpa_forward_extend(
-                q_,
-                o_,
-                k_cache.view(
-                    -1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim)
-                ),
-                v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank),
-                forward_batch.req_to_token_pool.req_to_token,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.extend_prefix_lens,
-                forward_batch.extend_seq_lens,
-                scaling=layer.scaling,
-                enable_gqa=use_gqa,
-                causal=causal,
-            )
-            return o
+        if not self.use_mla:
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+            if self.use_fia:
+                """FIA will support multi-bs in the later version of CANN"""
+                q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                attn_output = torch.empty(
+                    (q.size(0), layer.tp_q_head_num, layer.v_head_dim),
+                    device=q.device,
+                    dtype=q.dtype,
+                )
+                q_len_offset = 0
+                for q_len in forward_batch.extend_seq_lens_cpu:
+                    attn_output[q_len_offset : q_len_offset + q_len] = (
+                        torch.ops.npu.npu_fused_infer_attention_score(
+                            q[None, q_len_offset : q_len_offset + q_len],
+                            k[None, q_len_offset : q_len_offset + q_len],
+                            v[None, q_len_offset : q_len_offset + q_len],
+                            num_heads=layer.tp_q_head_num,
+                            num_key_value_heads=layer.tp_k_head_num,
+                            input_layout="BSND",  # todo, TND not supports q_heads!=k_heads
+                            atten_mask=self.fia_mask.unsqueeze(0),
+                            sparse_mode=3,
+                            scale=layer.scaling,
+                            next_tokens=0,
+                        )[0]
+                    )
+                    q_len_offset += q_len
+                attn_output = attn_output.view(
+                    -1, layer.tp_q_head_num * layer.v_head_dim
+                )
+            else:
+                query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
+                attn_output = torch.empty(
+                    (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
+                    dtype=query.dtype,
+                    device=query.device,
+                )
+                torch_npu._npu_flash_attention_qlens(
+                    query=query,
+                    key_cache=k_cache,
+                    value_cache=v_cache,
+                    mask=self.mask,
+                    block_table=self.forward_metadata.block_tables,
+                    seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
+                    context_lens=self.forward_metadata.seq_lens_cpu_int,
+                    scale_value=layer.scaling,
+                    num_heads=layer.tp_q_head_num,
+                    num_kv_heads=layer.tp_k_head_num,
+                    out=attn_output,
+                )
+        else:
+            assert (
+                layer.qk_head_dim != layer.v_head_dim
+            ), "FIA only supports qk_head_dim != v_head_dim"
+            q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
+            k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
+            attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+                q_nope,
+                k_nope,
+                v,
+                query_rope=q_rope,
+                key_rope=k_rope,
+                num_heads=layer.tp_q_head_num,
+                input_layout="TND",
+                atten_mask=self.fia_mask,
+                sparse_mode=3,
+                actual_seq_lengths=self.forward_metadata.seq_lens_list_cumsum,
+                actual_seq_lengths_kv=self.forward_metadata.seq_lens_list_cumsum,
+                scale=layer.scaling,
+                next_tokens=0,
+            )
+
+        return attn_output

     def forward_decode(
         self,
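Note: in the non-MLA FIA extend path the fused kernel is invoked once per request (the in-code note says multi-batch support lands in a later CANN release), so the loop slices q/k/v by running extend-length offsets. A CPU-only sketch of that slicing with a stand-in attention function (the real call is torch.ops.npu.npu_fused_infer_attention_score; shapes and lengths below are made up):

import torch

def fake_attention(q, k, v):
    # stand-in for the NPU fused kernel; returns q so the loop runs anywhere
    return q

extend_seq_lens_cpu = [4, 2, 3]                     # tokens extended per request
num_heads, head_dim = 2, 8
total = sum(extend_seq_lens_cpu)
q = torch.randn(total, num_heads, head_dim)
k = torch.randn(total, num_heads, head_dim)
v = torch.randn(total, num_heads, head_dim)
attn_output = torch.empty_like(q)

q_len_offset = 0
for q_len in extend_seq_lens_cpu:
    sl = slice(q_len_offset, q_len_offset + q_len)
    # the diff adds a leading batch dim (q[None, sl]) for the "BSND" layout
    attn_output[sl] = fake_attention(q[sl], k[sl], v[sl])
    q_len_offset += q_len

print(attn_output.shape)                            # torch.Size([9, 2, 8])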
@@ -224,13 +260,17 @@ class AscendAttnBackend(AttentionBackend):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache=True,
+        save_kv_cache: bool = False,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
-        if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
         if not self.use_mla:
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )
+            num_tokens = q.shape[0]
             if self.graph_mode:
                 k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
                     layer.layer_id
@@ -239,7 +279,6 @@ class AscendAttnBackend(AttentionBackend):
                     layer.layer_id
                 ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
                 query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
-                num_tokens = query.shape[0]
                 workspace = (
                     torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                         query,
@@ -254,7 +293,7 @@ class AscendAttnBackend(AttentionBackend):
                         actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
                     )
                 )
-                output = torch.empty(
+                attn_output = torch.empty(
                     (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
                     dtype=q.dtype,
                     device=q.device,
@@ -272,61 +311,129 @@ class AscendAttnBackend(AttentionBackend):
...
@@ -272,61 +311,129 @@ class AscendAttnBackend(AttentionBackend):
scale
=
layer
.
scaling
,
scale
=
layer
.
scaling
,
actual_seq_lengths_kv
=
self
.
forward_metadata
.
seq_lens_cpu_list
,
actual_seq_lengths_kv
=
self
.
forward_metadata
.
seq_lens_cpu_list
,
workspace
=
workspace
,
workspace
=
workspace
,
out
=
[
output
,
softmax_lse
],
out
=
[
attn_
output
,
softmax_lse
],
)
)
else
:
else
:
k_cache
=
forward_batch
.
token_to_kv_pool
.
get_key_buffer
(
layer
.
layer_id
)
k_cache
=
forward_batch
.
token_to_kv_pool
.
get_key_buffer
(
layer
.
layer_id
)
v_cache
=
forward_batch
.
token_to_kv_pool
.
get_value_buffer
(
v_cache
=
forward_batch
.
token_to_kv_pool
.
get_value_buffer
(
layer
.
layer_id
layer
.
layer_id
)
)
if
self
.
use_fia
:
attn_output
,
_
=
torch
.
ops
.
npu
.
npu_fused_infer_attention_score
(
q
.
view
(
forward_batch
.
batch_size
,
-
1
,
layer
.
tp_q_head_num
,
layer
.
qk_head_dim
,
),
k_cache
.
view
(
-
1
,
self
.
page_size
,
layer
.
tp_k_head_num
*
layer
.
qk_head_dim
),
v_cache
.
view
(
-
1
,
self
.
page_size
,
layer
.
tp_v_head_num
*
layer
.
qk_head_dim
),
num_heads
=
layer
.
tp_q_head_num
,
num_key_value_heads
=
layer
.
tp_k_head_num
,
input_layout
=
"BSND"
,
atten_mask
=
None
,
block_size
=
self
.
page_size
,
block_table
=
self
.
forward_metadata
.
block_tables
,
actual_seq_lengths_kv
=
self
.
forward_metadata
.
seq_lens_cpu_int
,
scale
=
layer
.
scaling
,
)
else
:
query
=
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
qk_head_dim
)
attn_output
=
torch
.
empty
(
(
num_tokens
,
layer
.
tp_q_head_num
,
layer
.
v_head_dim
),
dtype
=
query
.
dtype
,
device
=
query
.
device
,
)
query
=
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
qk_head_dim
)
torch_npu
.
_npu_paged_attention
(
num_tokens
=
query
.
shape
[
0
]
query
=
query
,
output
=
torch
.
empty
(
key_cache
=
k_cache
,
(
num_tokens
,
layer
.
tp_q_head_num
,
layer
.
v_head_dim
),
value_cache
=
v_cache
,
dtype
=
query
.
dtype
,
num_heads
=
layer
.
tp_q_head_num
,
device
=
query
.
device
,
num_kv_heads
=
layer
.
tp_k_head_num
,
scale_value
=
layer
.
scaling
,
block_table
=
self
.
forward_metadata
.
block_tables
,
context_lens
=
self
.
forward_metadata
.
seq_lens_cpu_int
,
out
=
attn_output
,
)
return
attn_output
.
view
(
num_tokens
,
layer
.
tp_q_head_num
*
layer
.
v_head_dim
)
else
:
if
save_kv_cache
:
forward_batch
.
token_to_kv_pool
.
set_kv_buffer
(
layer
,
forward_batch
.
out_cache_loc
,
k
,
k_rope
)
)
num_tokens
=
q
.
shape
[
0
]
torch_npu
.
_npu_paged_attention
(
kv_c
=
forward_batch
.
token_to_kv_pool
.
get_key_buffer
(
layer
.
layer_id
)
query
=
query
,
k_pe
=
forward_batch
.
token_to_kv_pool
.
get_value_buffer
(
layer
.
layer_id
)
key_cache
=
k_cache
,
value_cache
=
v_cache
,
if
(
self
.
graph_mode
or
self
.
use_fia
)
and
(
layer
.
tp_q_head_num
//
layer
.
tp_k_head_num
)
>=
8
:
"""layer.tp_q_head_num // layer.tp_k_head_num < 8 will support in the later version of CANN"""
kv_c
=
kv_c
.
view
(
-
1
,
self
.
page_size
,
layer
.
tp_k_head_num
*
self
.
kv_lora_rank
)
k_pe
=
k_pe
.
view
(
-
1
,
self
.
page_size
,
layer
.
tp_k_head_num
*
self
.
qk_rope_head_dim
)
q
=
q
.
view
(
forward_batch
.
batch_size
,
-
1
,
layer
.
tp_q_head_num
,
self
.
kv_lora_rank
)
q_rope
=
q_rope
.
view
(
forward_batch
.
batch_size
,
-
1
,
layer
.
tp_q_head_num
,
self
.
qk_rope_head_dim
,
)
attn_output
,
_
=
torch
.
ops
.
npu
.
npu_fused_infer_attention_score
(
q
,
kv_c
,
kv_c
,
query_rope
=
q_rope
,
key_rope
=
k_pe
,
num_heads
=
layer
.
tp_q_head_num
,
num_heads
=
layer
.
tp_q_head_num
,
num_key_value_heads
=
layer
.
tp_k_head_num
,
input_layout
=
"BSND"
,
atten_mask
=
None
,
sparse_mode
=
0
,
scale
=
layer
.
scaling
,
antiquant_mode
=
0
,
antiquant_scale
=
None
,
block_table
=
self
.
forward_metadata
.
block_tables
,
block_size
=
self
.
page_size
,
actual_seq_lengths_kv
=
self
.
forward_metadata
.
seq_lens_cpu_int
,
)
else
:
assert
(
self
.
graph_mode
==
False
)
# _npu_paged_attention_mla not support graph mode
q
=
torch
.
cat
([
q
,
q_rope
],
dim
=-
1
)
query
=
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
)
kv_c_and_k_pe_cache
=
torch
.
cat
([
kv_c
,
k_pe
],
dim
=-
1
)
kv_c_and_k_pe_cache
=
kv_c_and_k_pe_cache
.
view
(
-
1
,
self
.
page_size
,
layer
.
tp_k_head_num
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
)
attn_output
=
torch
.
empty
(
[
num_tokens
,
layer
.
tp_q_head_num
,
self
.
kv_lora_rank
],
dtype
=
q
.
dtype
,
device
=
q
.
device
,
)
torch_npu
.
_npu_paged_attention_mla
(
query
=
query
,
key_cache
=
kv_c_and_k_pe_cache
,
num_kv_heads
=
layer
.
tp_k_head_num
,
num_kv_heads
=
layer
.
tp_k_head_num
,
num_heads
=
layer
.
tp_q_head_num
,
scale_value
=
layer
.
scaling
,
scale_value
=
layer
.
scaling
,
block_table
=
self
.
forward_metadata
.
block_tables
,
block_table
=
self
.
forward_metadata
.
block_tables
,
context_lens
=
self
.
forward_metadata
.
seq_lens_cpu_int
,
context_lens
=
self
.
forward_metadata
.
seq_lens_cpu_int
,
out
=
output
,
mla_vheadsize
=
self
.
kv_lora_rank
,
out
=
attn_output
,
)
)
return
output
.
view
(
num_tokens
,
layer
.
tp_q_head_num
*
layer
.
v_head_dim
)
else
:
query
=
q
.
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
)
num_tokens
=
query
.
shape
[
0
]
kv_c_and_k_pe_cache
=
forward_batch
.
token_to_kv_pool
.
get_key_buffer
(
layer
.
layer_id
)
kv_c_and_k_pe_cache
=
kv_c_and_k_pe_cache
.
view
(
-
1
,
self
.
page_size
,
layer
.
tp_k_head_num
,
self
.
kv_lora_rank
+
self
.
qk_rope_head_dim
,
)
attn_output
=
torch
.
empty
(
[
num_tokens
,
layer
.
tp_q_head_num
,
self
.
kv_lora_rank
],
dtype
=
q
.
dtype
,
device
=
q
.
device
,
)
torch_npu
.
_npu_paged_attention_mla
(
query
=
query
,
key_cache
=
kv_c_and_k_pe_cache
,
num_kv_heads
=
layer
.
tp_k_head_num
,
num_heads
=
layer
.
tp_q_head_num
,
scale_value
=
layer
.
scaling
,
block_table
=
self
.
forward_metadata
.
block_tables
,
context_lens
=
self
.
forward_metadata
.
seq_lens_cpu_int
,
mla_vheadsize
=
self
.
kv_lora_rank
,
out
=
attn_output
,
)
return
attn_output
.
view
(
num_tokens
,
layer
.
tp_q_head_num
*
self
.
kv_lora_rank
)
return
attn_output
.
view
(
num_tokens
,
layer
.
tp_q_head_num
*
self
.
kv_lora_rank
)
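Note: for MLA decode the kernel choice depends on graph/FIA mode and on the GQA ratio; per the in-code note, the fused kernel currently needs tp_q_head_num // tp_k_head_num >= 8, otherwise the paged-attention MLA op is used. A plain-Python restatement of just that dispatch (names mirror the diff; this is not the backend itself):

def pick_mla_decode_kernel(graph_mode, use_fia, tp_q_head_num, tp_k_head_num):
    if (graph_mode or use_fia) and (tp_q_head_num // tp_k_head_num) >= 8:
        return "torch.ops.npu.npu_fused_infer_attention_score"
    return "torch_npu._npu_paged_attention_mla"

print(pick_mla_decode_kernel(False, True, 16, 1))    # fused infer attention
print(pick_mla_decode_kernel(False, False, 16, 1))   # paged attention MLA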
python/sglang/srt/layers/moe/topk.py
@@ -304,7 +304,7 @@ class TopK(CustomOp):
         global_num_experts = router_logits.shape[-1]

         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256 and self.topk_config.renormalize is False:
+        if global_num_experts == 256 and self.topk_config.renormalize is True:
             routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
             router_logits = router_logits.to(torch.float32)
python/sglang/srt/mem_cache/memory_pool.py
@@ -36,12 +36,15 @@ import triton.language as tl
 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2

 logger = logging.getLogger(__name__)

 GB = 1024 * 1024 * 1024
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+if _is_npu:
+    import torch_npu


 class ReqToTokenPool:
@@ -624,8 +627,6 @@ class AscendTokenToKVPool(MHATokenToKVPool):
             cache_k = cache_k.view(self.store_dtype)
             cache_v = cache_v.view(self.store_dtype)

-        import torch_npu
-
         torch_npu._npu_reshape_and_cache(
             key=cache_k,
             value=cache_v,
@@ -912,12 +913,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.kv_buffer = torch.zeros(
+            self.k_buffer = torch.zeros(
                 (
                     layer_num,
                     self.size // self.page_size + 1,
                     self.page_size,
-                    self.kv_lora_rank + self.qk_rope_head_dim,
+                    self.kv_lora_rank,
+                ),
+                dtype=self.store_dtype,
+                device=self.device,
+            )
+            self.v_buffer = torch.zeros(
+                (
+                    layer_num,
+                    self.size // self.page_size + 1,
+                    self.page_size,
+                    self.qk_rope_head_dim,
                 ),
                 dtype=self.store_dtype,
                 device=self.device,
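Note: AscendMLAPagedTokenToKVPool now keeps the compressed-KV part and the rope part in two page-shaped buffers (k_buffer and v_buffer) instead of one fused kv_buffer. A shape check of the split with illustrative sizes (layer_num/size/page_size below are made up; 512/64 mirror typical DeepSeek-V2 MLA dimensions):

import torch

layer_num, size, page_size = 2, 64, 16
kv_lora_rank, qk_rope_head_dim = 512, 64
pages = size // page_size + 1                 # +1 keeps the padded slot 0

k_buffer = torch.zeros((layer_num, pages, page_size, kv_lora_rank))
v_buffer = torch.zeros((layer_num, pages, page_size, qk_rope_head_dim))
print(k_buffer.shape, v_buffer.shape)
# torch.Size([2, 5, 16, 512]) torch.Size([2, 5, 16, 64])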
@@ -931,12 +942,52 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             )

         self.mem_usage = kv_size / GB

+    def get_kv_size_bytes(self):
+        assert hasattr(self, "k_buffer")
+        assert hasattr(self, "v_buffer")
+        kv_size_bytes = 0
+        for k_cache in self.k_buffer:
+            kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+        for v_cache in self.v_buffer:
+            kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+        return kv_size_bytes
+
+    def get_kv_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return (
+            self.k_buffer[layer_id - self.start_layer],
+            self.v_buffer[layer_id - self.start_layer],
+        )
+
+    def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        if self.store_dtype != self.dtype:
+            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.k_buffer[layer_id - self.start_layer]
+
+    def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        if self.store_dtype != self.dtype:
+            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.v_buffer[layer_id - self.start_layer]
+
     # for disagg
     def get_contiguous_buf_infos(self):
         # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
-        kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
-        kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
-        kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
+        kv_data_ptrs = [self.k_buffer[i].data_ptr() for i in range(self.layer_num)] + [
+            self.v_buffer[i].data_ptr() for i in range(self.layer_num)
+        ]
+        kv_data_lens = [self.k_buffer[i].nbytes for i in range(self.layer_num)] + [
+            self.v_buffer[i].nbytes for i in range(self.layer_num)
+        ]
+        kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
+            self.v_buffer[i][0].nbytes for i in range(self.layer_num)
+        ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens

     def set_kv_buffer(
@@ -953,14 +1004,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         if self.store_dtype != self.dtype:
             cache_k = cache_k.view(self.store_dtype)

-        import torch_npu
-
-        torch_npu._npu_reshape_and_cache_siso(
-            key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
-            key_cache=self.kv_buffer[layer_id - self.start_layer].view(
-                -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
-            ),
-            slot_indices=loc,
-        )
+        if cache_v is None:
+            cache_k, cache_v = cache_k.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+
+        torch_npu.npu_scatter_nd_update_(
+            self.k_buffer[layer_id - self.start_layer].view(-1, 1, self.kv_lora_rank),
+            loc.view(-1, 1),
+            cache_k.view(-1, 1, self.kv_lora_rank),
+        )
+        torch_npu.npu_scatter_nd_update_(
+            self.v_buffer[layer_id - self.start_layer].view(
+                -1, 1, self.qk_rope_head_dim
+            ),
+            loc.view(-1, 1),
+            cache_v.view(-1, 1, self.qk_rope_head_dim),
+        )
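Note: set_kv_buffer now splits the fused (kv_lora_rank + qk_rope_head_dim) vector when cache_v is None and scatters each half into its buffer by slot index via npu_scatter_nd_update_. A CPU approximation using plain index assignment (assumed here to have the same effect for this layout; npu_scatter_nd_update_ itself is an Ascend op):

import torch

kv_lora_rank, qk_rope_head_dim = 8, 4         # toy sizes
num_slots, num_tokens = 10, 3

k_buffer = torch.zeros(num_slots, 1, kv_lora_rank)
v_buffer = torch.zeros(num_slots, 1, qk_rope_head_dim)

cache_k = torch.randn(num_tokens, kv_lora_rank + qk_rope_head_dim)
loc = torch.tensor([2, 5, 7])                 # target slot index per token

c_part, rope_part = cache_k.split([kv_lora_rank, qk_rope_head_dim], dim=-1)
k_buffer[loc] = c_part.view(-1, 1, kv_lora_rank)
v_buffer[loc] = rope_part.view(-1, 1, qk_rope_head_dim)

print(k_buffer[loc].shape, v_buffer[loc].shape)
# torch.Size([3, 1, 8]) torch.Size([3, 1, 4])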
python/sglang/srt/models/deepseek_v2.py
@@ -994,7 +994,14 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.current_attention_backend = attention_backend
         if attention_backend == "ascend":
-            return AttnForwardMethod.MLA
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
         elif (
             attention_backend == "flashinfer"
             or attention_backend == "fa3"

@@ -1292,6 +1299,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             or self.current_attention_backend == "flashinfer"
             or self.current_attention_backend == "cutlass_mla"
             or self.current_attention_backend == "trtllm_mla"
+            or self.current_attention_backend == "ascend"
         ):
             extra_args = {}
             if self._fuse_rope_for_trtllm_mla(forward_batch):
test/srt/ascend/test_ascend_mla_fia_w8a8int8.py (new file, 0 → 100644)
import os
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
    run_bench_offline_throughput,
)

TEST_MODEL_MATRIX = {
    "/root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-V2-Lite-W8A8": {
        "accuracy": 0.34,
        "latency": 1000,
        "output_throughput": 6,
    },
}


class TestAscendMlaW8A8Int8(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.common_args = [
            "--trust-remote-code",
            "--disable-cuda-graph",
            "--mem-fraction-static",
            0.8,
            "--attention-backend",
            "ascend",
            "--quantization",
            "w8a8_int8",
            "--tp-size",
            2,
            "--disable-radix-cache",
        ]

    def test_a_gsm8k(self):
        os.environ["ASCEND_USE_FIA"] = "true"
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing accuracy: {model} ===##")

                process = popen_launch_server(
                    model,
                    self.base_url,
                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=[
                        *self.common_args,
                    ],
                )

                try:
                    args = SimpleNamespace(
                        num_shots=5,
                        data_path=None,
                        num_questions=1319,
                        max_new_tokens=512,
                        parallel=128,
                        host=f"http://{self.url.hostname}",
                        port=int(self.url.port),
                    )

                    metrics = run_eval_few_shot_gsm8k(args)
                    self.assertGreaterEqual(
                        metrics["accuracy"],
                        TEST_MODEL_MATRIX[model]["accuracy"],
                    )
                finally:
                    kill_process_tree(process.pid)

    def test_b_throughput(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing throughput: {model} ===##")

                output_throughput = run_bench_offline_throughput(
                    model,
                    [
                        *self.common_args,
                    ],
                )

                print(f"##=== {model} throughput: {output_throughput} ===##")

                if is_in_ci():
                    self.assertGreater(
                        output_throughput,
                        TEST_MODEL_MATRIX[model]["output_throughput"],
                    )


if __name__ == "__main__":
    unittest.main()
test/srt/ascend/test_ascend_mla_w8a8int8.py
@@ -40,6 +40,7 @@ class TestAscendMlaW8A8Int8(CustomTestCase):
             "w8a8_int8",
             "--tp-size",
             4,
+            "--disable-radix-cache",
         ]

     def test_a_gsm8k(self):
test/srt/ascend/test_ascend_tp2_fia_bf16.py (new file, 0 → 100644)
import os
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
    run_bench_offline_throughput,
)

TEST_MODEL_MATRIX = {
    "Qwen/Qwen2.5-7B-Instruct": {
        "accuracy": 0.85,
        "latency": 180,
        "output_throughput": 20,
    },
}


class TestAscendTp2Bf16(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.models = TEST_MODEL_MATRIX.keys()
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.common_args = [
            "--trust-remote-code",
            "--disable-cuda-graph",
            "--mem-fraction-static",
            0.8,
            "--attention-backend",
            "ascend",
            "--tp-size",
            2,
            "--disable-radix-cache",
        ]

    def test_a_gsm8k(self):
        os.environ["ASCEND_USE_FIA"] = "true"
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing accuracy: {model} ===##")

                process = popen_launch_server(
                    model,
                    self.base_url,
                    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=[
                        *self.common_args,
                    ],
                )

                try:
                    args = SimpleNamespace(
                        num_shots=5,
                        data_path=None,
                        num_questions=1319,
                        max_new_tokens=512,
                        parallel=128,
                        host=f"http://{self.url.hostname}",
                        port=int(self.url.port),
                    )

                    metrics = run_eval_few_shot_gsm8k(args)
                    self.assertGreaterEqual(
                        metrics["accuracy"],
                        TEST_MODEL_MATRIX[model]["accuracy"],
                    )
                finally:
                    kill_process_tree(process.pid)

    def test_b_throughput(self):
        for model in self.models:
            with self.subTest(model=model):
                print(f"##=== Testing throughput: {model} ===##")

                output_throughput = run_bench_offline_throughput(
                    model,
                    [
                        *self.common_args,
                    ],
                )

                print(f"##=== {model} throughput: {output_throughput} ===##")

                if is_in_ci():
                    self.assertGreater(
                        output_throughput,
                        TEST_MODEL_MATRIX[model]["output_throughput"],
                    )


if __name__ == "__main__":
    unittest.main()
test/srt/run_suite.py
@@ -275,6 +275,8 @@ suite_ascend = {
     "per-commit-2-ascend-npu": [
         TestFile("ascend/test_ascend_tp2_bf16.py", 400),
         TestFile("ascend/test_ascend_graph_tp2_bf16.py", 400),
+        TestFile("ascend/test_ascend_tp2_fia_bf16.py", 400),
+        TestFile("ascend/test_ascend_mla_fia_w8a8int8.py", 400),
     ],
     "per-commit-4-ascend-npu": [
         TestFile("ascend/test_ascend_mla_w8a8int8.py", 400),