change / sglang · Commits · 74dd4249

Unverified commit 74dd4249, authored Aug 29, 2025 by chenxu140, committed by GitHub on Aug 28, 2025
Parent: dc20c22f

[Feature] Support NPUGraph for DeepSeek on Ascend NPU (#9355)

Co-authored-by: Even Zhou <even.y.zhou@outlook.com>
Showing 7 changed files with 307 additions and 105 deletions (+307 −105)
python/sglang/srt/disaggregation/ascend/conn.py        +75  −0
python/sglang/srt/layers/attention/ascend_backend.py   +183 −88
python/sglang/srt/layers/moe/ep_moe/layer.py           +12  −6
python/sglang/srt/layers/moe/topk.py                   +12  −2
python/sglang/srt/layers/quantization/w8a8_int8.py     +7   −3
python/sglang/srt/mem_cache/memory_pool.py             +4   −0
python/sglang/srt/models/deepseek_v2.py                +14  −6
python/sglang/srt/disaggregation/ascend/conn.py (+75 −0)

+import concurrent.futures
 import logging
+from typing import List, Tuple
+
+import numpy as np
+import numpy.typing as npt
+
 from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
+from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
 from sglang.srt.disaggregation.mooncake.conn import (
     MooncakeKVBootstrapServer,
     MooncakeKVManager,
...
@@ -29,6 +35,75 @@ class AscendKVManager(MooncakeKVManager):
             self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
         )
 
+    def send_kvcache(
+        self,
+        mooncake_session_id: str,
+        prefill_kv_indices: npt.NDArray[np.int32],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int32],
+        executor: concurrent.futures.ThreadPoolExecutor,
+    ):
+        # Group by indices
+        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
+            prefill_kv_indices, dst_kv_indices
+        )
+
+        num_layers = len(self.kv_args.kv_data_ptrs)
+        layers_params = [
+            (
+                self.kv_args.kv_data_ptrs[layer_id],
+                dst_kv_ptrs[layer_id],
+                self.kv_args.kv_item_lens[layer_id],
+            )
+            for layer_id in range(num_layers)
+        ]
+
+        def set_transfer_blocks(
+            src_ptr: int, dst_ptr: int, item_len: int
+        ) -> List[Tuple[int, int, int]]:
+            transfer_blocks = []
+            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
+                src_addr = src_ptr + int(prefill_index[0]) * item_len
+                dst_addr = dst_ptr + int(decode_index[0]) * item_len
+                length = item_len * len(prefill_index)
+                transfer_blocks.append((src_addr, dst_addr, length))
+            return transfer_blocks
+
+        # Worker function for processing a single layer
+        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+            transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
+            return self._transfer_data(mooncake_session_id, transfer_blocks)
+
+        # Worker function for processing all layers in a batch
+        def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
+            transfer_blocks = []
+            for src_ptr, dst_ptr, item_len in layers_params:
+                transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
+            return self._transfer_data(mooncake_session_id, transfer_blocks)
+
+        if self.enable_custom_mem_pool:
+            futures = [
+                executor.submit(
+                    process_layer,
+                    src_ptr,
+                    dst_ptr,
+                    item_len,
+                )
+                for (src_ptr, dst_ptr, item_len) in layers_params
+            ]
+            for future in concurrent.futures.as_completed(futures):
+                status = future.result()
+                if status != 0:
+                    for f in futures:
+                        f.cancel()
+                    return status
+        else:
+            # Combining all layers' params in one batch transfer is more efficient
+            # compared to using multiple threads
+            return process_layers(layers_params)
+
+        return 0
+
 
 class AscendKVSender(MooncakeKVSender):
     pass
...
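The new send_kvcache override coalesces the per-token KV indices into contiguous (src_addr, dst_addr, length) blocks for each layer, then either fans out one transfer per layer over the thread pool (when the custom memory pool is enabled) or batches every layer into a single transfer call. A minimal, self-contained sketch of the coalescing idea follows; the group_contiguous helper and the pointer values are illustrative stand-ins, not the sglang API (the real code uses group_concurrent_contiguous and the Mooncake transfer engine).

# Minimal sketch of the block-coalescing idea used by send_kvcache above.
# group_contiguous and the addresses are illustrative, not the sglang API.
from typing import List, Tuple

import numpy as np


def group_contiguous(src: np.ndarray, dst: np.ndarray) -> List[Tuple[np.ndarray, np.ndarray]]:
    # Split both index arrays wherever either side stops being contiguous,
    # so each group can be described by a single (src, dst, length) block.
    if len(src) == 0:
        return []
    breaks = np.where((np.diff(src) != 1) | (np.diff(dst) != 1))[0] + 1
    return list(zip(np.split(src, breaks), np.split(dst, breaks)))


def to_transfer_blocks(
    src_ptr: int, dst_ptr: int, item_len: int, src_idx: np.ndarray, dst_idx: np.ndarray
) -> List[Tuple[int, int, int]]:
    blocks = []
    for s, d in group_contiguous(src_idx, dst_idx):
        # One descriptor per contiguous run of KV items.
        blocks.append(
            (src_ptr + int(s[0]) * item_len, dst_ptr + int(d[0]) * item_len, item_len * len(s))
        )
    return blocks


if __name__ == "__main__":
    src = np.array([4, 5, 6, 10, 11], dtype=np.int32)
    dst = np.array([0, 1, 2, 3, 4], dtype=np.int32)
    # Two blocks: items 4-6 -> slots 0-2, items 10-11 -> slots 3-4.
    print(to_transfer_blocks(0x1000, 0x9000, 256, src, dst))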
python/sglang/srt/layers/attention/ascend_backend.py (+183 −88)

...
@@ -158,7 +158,7 @@ class AscendAttnBackend(AttentionBackend):
             self.graph_mode = True
 
     def get_cuda_graph_seq_len_fill_value(self):
-        return 1
+        return 0
 
     def forward_extend(
         self,
...
@@ -167,7 +167,7 @@ class AscendAttnBackend(AttentionBackend):
         v,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache=True,
+        save_kv_cache: bool = True,
     ):
         if not self.use_mla:
             if save_kv_cache:
...
@@ -253,25 +253,30 @@ class AscendAttnBackend(AttentionBackend):
         return attn_output
 
-    def forward_decode(
+    def forward_decode_graph(
         self,
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache: bool = False,
+        save_kv_cache: bool = True,
+        # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
-        if not self.use_mla:
-            if save_kv_cache:
+        if save_kv_cache:
+            if self.use_mla:
+                k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+                k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, k_rope
+                )
+            else:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
                     layer, forward_batch.out_cache_loc, k, v
                 )
-            num_tokens = q.shape[0]
-            if self.graph_mode:
+
+        if not self.use_mla:
             k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
                 layer.layer_id
             ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim)
...
@@ -279,8 +284,14 @@ class AscendAttnBackend(AttentionBackend):
                 layer.layer_id
             ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
             query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
-            workspace = (
-                torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+            if self.forward_metadata.seq_lens_cpu_int is None:
+                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
+            else:
+                actual_seq_len_kv = (
+                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                )
+            num_tokens = query.shape[0]
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                 query,
                 k_cache,
                 v_cache,
...
@@ -290,10 +301,9 @@ class AscendAttnBackend(AttentionBackend):
                 num_key_value_heads=layer.tp_k_head_num,
                 input_layout="BSH",
                 scale=layer.scaling,
-                actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
+                actual_seq_lengths_kv=actual_seq_len_kv,
             )
-            )
-            attn_output = torch.empty(
+            output = torch.empty(
                 (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
                 dtype=q.dtype,
                 device=q.device,
...
@@ -309,15 +319,102 @@ class AscendAttnBackend(AttentionBackend):
                 num_key_value_heads=layer.tp_k_head_num,
                 input_layout="BSH",
                 scale=layer.scaling,
-                actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
+                actual_seq_lengths_kv=actual_seq_len_kv,
                 workspace=workspace,
-                out=[attn_output, softmax_lse],
+                out=[output, softmax_lse],
             )
+            return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
         else:
-            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
-                layer.layer_id
-            )
+            c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            k_rope_cache = k_rope.view(
+                -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim
+            )
+            c_kv_cache = c_kv.view(
+                -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
+            )
+            q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank)
+            q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
+            if self.forward_metadata.seq_lens_cpu_int is None:
+                actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
+            else:
+                actual_seq_len_kv = (
+                    self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist()
+                )
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                q_nope,
+                c_kv_cache,
+                c_kv_cache,
+                query_rope=q_rope,
+                key_rope=k_rope_cache,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                input_layout="BNSD",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                sparse_mode=0,
+            )
+            output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
+            softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
+            torch_npu.npu_fused_infer_attention_score.out(
+                q_nope,
+                c_kv_cache,
+                c_kv_cache,
+                query_rope=q_rope,
+                key_rope=k_rope_cache,
+                num_heads=layer.tp_q_head_num,
+                num_key_value_heads=layer.tp_k_head_num,
+                block_table=self.forward_metadata.block_tables,
+                block_size=self.page_size,
+                input_layout="BNSD",
+                scale=layer.scaling,
+                actual_seq_lengths_kv=actual_seq_len_kv,
+                antiquant_mode=0,
+                antiquant_scale=None,
+                sparse_mode=0,
+                workspace=workspace,
+                out=[output, softmax_lse],
+            )
+            return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank)
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ):
+        if self.graph_mode:
+            return self.forward_decode_graph(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope=q_rope,
+                k_rope=k_rope,
+            )
+
+        if not self.use_mla:
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )
+            num_tokens = q.shape[0]
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
             if self.use_fia:
                 attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
                     q.view(
...
@@ -370,9 +467,7 @@ class AscendAttnBackend(AttentionBackend):
             kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
             k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
-            if (self.graph_mode or self.use_fia) and (
+            if self.use_fia and (
                 layer.tp_q_head_num // layer.tp_k_head_num
             ) >= 8:
                 """layer.tp_q_head_num // layer.tp_k_head_num < 8 will support in the later version of CANN"""
                 kv_c = kv_c.view(
                     -1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
...
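The backend change splits decode into two paths: forward_decode_graph holds the shape-stable, fused-kernel implementation that can be captured and replayed as an NPU graph, while forward_decode becomes a thin dispatcher on self.graph_mode. Below is a rough sketch of that dispatch pattern, with hypothetical names and a plain matmul standing in for the fused attention kernel; it assumes nothing about the real backend beyond the routing shown in the diff.

# Illustrative sketch of the graph-mode dispatch introduced above; the class and
# the plain matmul are hypothetical stand-ins, not the sglang backend API.
import torch


class DecodeBackend:
    def __init__(self, graph_mode: bool = False):
        # graph_mode is turned on by the runner when decode is captured as a graph.
        self.graph_mode = graph_mode

    def forward_decode_graph(self, q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
        # Graph-friendly path: fixed shapes, preallocated output, no data-dependent
        # host-side control flow, so the kernel sequence can be captured and replayed.
        out = q.new_empty(q.shape[0], q.shape[1], kv.shape[1])
        torch.matmul(q, kv.transpose(-1, -2), out=out)
        return out

    def forward_decode(self, q: torch.Tensor, kv: torch.Tensor) -> torch.Tensor:
        if self.graph_mode:
            # Route captured decode steps to the graph-safe implementation.
            return self.forward_decode_graph(q, kv)
        # Eager path keeps the flexible, non-capturable implementation.
        return q @ kv.transpose(-1, -2)


if __name__ == "__main__":
    backend = DecodeBackend(graph_mode=True)
    q, kv = torch.randn(2, 4, 8), torch.randn(2, 16, 8)
    print(backend.forward_decode(q, kv).shape)  # torch.Size([2, 4, 16])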
python/sglang/srt/layers/moe/ep_moe/layer.py (+12 −6)

...
@@ -746,19 +746,25 @@ class DeepEPMoE(EPMoE):
             hidden_states = torch_npu.npu_grouped_matmul(
                 x=[hidden_states],
                 weight=[self.w13_weight],
-                scale=[self.w13_weight_scale.to(output_dtype)],
-                per_token_scale=[pertoken_scale],
                 split_item=2,
                 group_list_type=group_list_type,
                 group_type=0,
                 group_list=seg_indptr,
-                output_dtype=output_dtype,
+                output_dtype=torch.int32,
             )[0]
 
             # act_fn: swiglu
-            hidden_states = torch_npu.npu_swiglu(hidden_states)
-            hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+            hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=hidden_states,
+                weight_scale=self.w13_weight_scale.to(torch.float32),
+                activation_scale=pertoken_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=seg_indptr,
+                activate_left=True,
+                quant_mode=1,
+            )
 
             # gmm2: down_proj
             hidden_states = torch_npu.npu_grouped_matmul(
...
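The MoE path now keeps the first grouped matmul in int32 and replaces the separate npu_swiglu + npu_dynamic_quant pair with the fused npu_dequant_swiglu_quant call. As a rough reference for what such a fusion computes (dequantize the int32 accumulator, apply SwiGLU, re-quantize per token), here is a plain-PyTorch sketch; it is an assumption about the op's semantics for illustration only, not the torch_npu kernel.

# Reference sketch of a dequant -> SwiGLU -> per-token quant fusion.
# This mirrors the idea conceptually; it is not the torch_npu implementation.
import torch


def dequant_swiglu_quant_reference(
    x_int32: torch.Tensor,           # int32 accumulator from the grouped matmul
    weight_scale: torch.Tensor,      # per-output-channel scale of w13
    activation_scale: torch.Tensor,  # per-token scale of the quantized activations
):
    # 1) Dequantize the int32 GEMM output back to float.
    x = x_int32.float() * weight_scale * activation_scale.unsqueeze(-1)
    # 2) SwiGLU: split gate/up halves and combine them.
    gate, up = x.chunk(2, dim=-1)
    y = torch.nn.functional.silu(gate) * up
    # 3) Per-token dynamic quantization back to int8 for the down_proj matmul.
    out_scale = y.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    y_int8 = torch.clamp(torch.round(y / out_scale), -128, 127).to(torch.int8)
    return y_int8, out_scale.squeeze(-1)


if __name__ == "__main__":
    x = torch.randint(-1000, 1000, (4, 16), dtype=torch.int32)
    w_scale = torch.rand(16) * 0.01
    a_scale = torch.rand(4) * 0.05
    y, s = dequant_swiglu_quant_reference(x, w_scale, a_scale)
    print(y.shape, s.shape)  # torch.Size([4, 8]) torch.Size([4])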
python/sglang/srt/layers/moe/topk.py (+12 −2)

...
@@ -304,12 +304,12 @@ class TopK(CustomOp):
         global_num_experts = router_logits.shape[-1]
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256 and self.topk_config.renormalize is True:
+        if global_num_experts == 256:
             routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
             router_logits = router_logits.to(torch.float32)
-            return torch_npu.npu_moe_gating_top_k(
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=self.topk_config.top_k,
                 bias=self.topk_config.correction_bias.to(torch.float32),
...
@@ -321,6 +321,16 @@ class TopK(CustomOp):
                 routed_scaling_factor=routed_scaling_factor,
                 eps=float(1e-20),
             )
+            if self.topk_config.renormalize:
+                topk_weights_sum = (
+                    topk_weights.sum(dim=-1, keepdim=True)
+                    if self.topk_config.num_fused_shared_experts == 0
+                    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+                )
+                topk_weights = topk_weights / topk_weights_sum
+            return StandardTopKOutput(topk_weights, topk_ids, _)
         else:
             self.topk_config.torch_native = True
             return select_experts(
...
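Because the renormalize condition was dropped from the NPU gating branch, renormalization now happens explicitly after npu_moe_gating_top_k, excluding the fused shared expert (assumed here to occupy the last top-k slot, as the diff's slicing implies) from the normalizing sum. A standalone restatement of that step:

# Sketch of the renormalization step added above; the shared-expert-in-last-slot
# layout is an assumption taken from the [:, :-1] slicing in the diff.
import torch


def renormalize_topk_weights(topk_weights: torch.Tensor, num_fused_shared_experts: int) -> torch.Tensor:
    if num_fused_shared_experts == 0:
        denom = topk_weights.sum(dim=-1, keepdim=True)
    else:
        # Leave the shared-expert weight out of the normalizer.
        denom = topk_weights[:, :-1].sum(dim=-1, keepdim=True)
    return topk_weights / denom


if __name__ == "__main__":
    w = torch.tensor([[0.2, 0.3, 0.5], [0.1, 0.1, 0.8]])
    print(renormalize_topk_weights(w, num_fused_shared_experts=0))
    print(renormalize_topk_weights(w, num_fused_shared_experts=1))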
python/sglang/srt/layers/quantization/w8a8_int8.py (+7 −3)

...
@@ -551,7 +551,7 @@ class NPU_W8A8LinearMethodImpl:
     def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
         params_dict = {}
         params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
-        params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
+        params_dict["input_offset"] = torch.empty(1, dtype=params_dtype)
         return params_dict
 
     @staticmethod
...
@@ -582,11 +582,11 @@ class NPU_W8A8LinearMethodImpl:
         if original_dtype != torch.int8:
             x = torch_npu.npu_quantize(
                 x,
-                layer.aclnn_input_scale,
+                layer.aclnn_input_scale_reciprocal,
                 layer.aclnn_input_offset,
                 torch.qint8,
                 -1,
-                True,
+                False,
             )
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in Attention TP>1 case)
...
@@ -608,6 +608,10 @@ class NPU_W8A8LinearMethodImpl:
             layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
             requires_grad=False,
         )
+        layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
+            layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
+            requires_grad=False,
+        )
         layer.aclnn_input_offset = torch.nn.Parameter(
             layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
             requires_grad=False,
...
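Read together, the new aclnn_input_scale_reciprocal parameter and the True-to-False flip of the last npu_quantize argument suggest the activation quantization now multiplies by a cached reciprocal scale instead of dividing by the scale on every forward pass; that reading of the NPU API is an assumption. The numerical equivalence itself is easy to show in plain PyTorch:

# Plain-PyTorch illustration of dividing by a scale vs multiplying by its cached
# reciprocal; the mapping onto npu_quantize's arguments is an assumption.
import torch


def quantize_divide(x: torch.Tensor, scale: torch.Tensor, offset: torch.Tensor) -> torch.Tensor:
    return torch.clamp(torch.round(x / scale + offset), -128, 127).to(torch.int8)


def quantize_multiply(x: torch.Tensor, scale_reciprocal: torch.Tensor, offset: torch.Tensor) -> torch.Tensor:
    # Same result, but the division is paid once when the reciprocal is cached,
    # which is cheaper inside a decode step that is replayed many times.
    return torch.clamp(torch.round(x * scale_reciprocal + offset), -128, 127).to(torch.int8)


if __name__ == "__main__":
    x = torch.randn(4, 8)
    scale = torch.full((8,), 0.05)
    offset = torch.zeros(8)
    a = quantize_divide(x, scale, offset)
    b = quantize_multiply(x, 1.0 / scale, offset)
    # Agrees except for rare ties at the rounding boundary.
    print((a.to(torch.int16) - b.to(torch.int16)).abs().max())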
python/sglang/srt/mem_cache/memory_pool.py (+4 −0)

...
@@ -918,6 +918,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
                 layer_num,
                 self.size // self.page_size + 1,
                 self.page_size,
+                1,
                 self.kv_lora_rank,
             ),
             dtype=self.store_dtype,
...
@@ -928,6 +929,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
                 layer_num,
                 self.size // self.page_size + 1,
                 self.page_size,
+                1,
                 self.qk_rope_head_dim,
             ),
             dtype=self.store_dtype,
...
@@ -1000,9 +1002,11 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)
         if self.store_dtype != self.dtype:
             cache_k = cache_k.view(self.store_dtype)
+            cache_v = cache_v.view(self.store_dtype)
         if cache_v is None:
             cache_k, cache_v = cache_k.split(
...
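The pool's per-layer buffers gain an explicit singleton dimension between page_size and the head dimension, presumably so the paged cache layout matches what the fused NPU attention kernel consumes, and set_kv_buffer now converts cache_v alongside cache_k when the rope part is passed in separately. A toy allocation with illustrative sizes (not the pool's real configuration):

# Toy allocation showing the paged-buffer shape after the change above.
# All sizes are illustrative, not the pool's real configuration.
import torch

layer_num, size, page_size = 2, 1024, 128
kv_lora_rank, qk_rope_head_dim = 512, 64
num_pages = size // page_size + 1  # one spare page, as in the pool above

kv_buffer = torch.zeros(layer_num, num_pages, page_size, 1, kv_lora_rank)
pe_buffer = torch.zeros(layer_num, num_pages, page_size, 1, qk_rope_head_dim)
print(kv_buffer.shape, pe_buffer.shape)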
python/sglang/srt/models/deepseek_v2.py (+14 −6)

...
@@ -114,6 +114,7 @@ from sglang.srt.utils import (
     is_flashinfer_available,
     is_hip,
     is_non_idle_and_non_empty,
+    is_npu,
     is_sm100_supported,
     log_info_on_rank0,
     make_layers,
...
@@ -122,6 +123,7 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
...
@@ -1181,6 +1183,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         k[..., : self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim :] = k_pe
 
+        if not _is_npu:
             latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
             latent_cache[:, :, self.kv_lora_rank :] = k_pe
...
@@ -1188,6 +1191,11 @@ class DeepseekV2AttentionMLA(nn.Module):
             forward_batch.token_to_kv_pool.set_kv_buffer(
                 self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
             )
+        else:
+            # To reduce a time-costing split operation
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
 
         return q, k, v, forward_batch
...
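On NPU, the MLA prefill path now skips materializing the combined latent_cache tensor and hands kv_a and k_pe to the pool as two separate buffers, avoiding the split that would otherwise be needed on the read side. A small sketch contrasting the two layouts, with illustrative shapes only:

# Sketch of the combined vs. separate latent-cache layouts; shapes are illustrative.
import torch

kv_lora_rank, qk_rope_head_dim, num_tokens = 512, 64, 4
kv_a = torch.randn(num_tokens, kv_lora_rank)
k_pe = torch.randn(num_tokens, 1, qk_rope_head_dim)

# Combined layout: one latent_cache tensor that must be split again at read time.
latent_cache = torch.empty(num_tokens, 1, kv_lora_rank + qk_rope_head_dim)
latent_cache[:, :, :kv_lora_rank] = kv_a.unsqueeze(1)
latent_cache[:, :, kv_lora_rank:] = k_pe
restored_kv_a, restored_k_pe = latent_cache.split([kv_lora_rank, qk_rope_head_dim], dim=-1)

# Separate layout: store the two parts directly and skip the split on the read side.
stored_kv_a, stored_k_pe = kv_a.unsqueeze(1).clone(), k_pe.clone()
print(torch.allclose(restored_kv_a, stored_kv_a), torch.allclose(restored_k_pe, stored_k_pe))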