change / sglang / Commits / d27a6f70

Commit d27a6f70 (Unverified)
[Feature] Add MLAProcess for DeepSeek MLA on NPU (#10130)
Authored Sep 23, 2025 by Even Zhou; committed by GitHub on Sep 22, 2025
Parent: 0753ef83
Showing 7 changed files with 369 additions and 23 deletions (+369, -23):

docs/platforms/ascend_npu.md  (+6, -5)
python/sglang/srt/layers/attention/ascend_backend.py  (+6, -1)
python/sglang/srt/layers/attention/npu_ops/mla_preprocess.py  (+300, -0)
python/sglang/srt/layers/rotary_embedding.py  (+20, -14)
python/sglang/srt/models/deepseek_v2.py  (+32, -3)
python/sglang/srt/utils.py  (+4, -0)
test/srt/ascend/test_ascend_deepep.py  (+1, -0)
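The new path is strictly opt-in: it activates only on Ascend NPU builds with the SGLANG_NPU_USE_MLAPO=1 environment variable set, and the deepseek_v2.py change additionally asserts a w8a8_int8-quantized checkpoint. A minimal sketch of that combined precondition, reading os.environ directly rather than through sglang's get_bool_env_var helper:

```python
import os


def mlapo_effectively_enabled(running_on_npu: bool, quant_name: str) -> bool:
    """Illustrative precondition distilled from this commit's diffs."""
    # mla_preprocess.py gates on an NPU device plus SGLANG_NPU_USE_MLAPO.
    flag = os.environ.get("SGLANG_NPU_USE_MLAPO", "0").lower() in ("1", "true")
    if not (running_on_npu and flag):
        return False
    # deepseek_v2.py: "MLA Preprocess only works with W8A8Int8".
    assert quant_name == "w8a8_int8", "MLA Preprocess only works with W8A8Int8"
    return True
```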
docs/platforms/ascend_npu.md
@@ -118,7 +118,7 @@ git clone https://github.com/sgl-project/sglang.git
 cd sglang/docker
 # Build the docker image
-docker build -t sglang-npu:main -f Dockerfile.npu .
+docker build -t <image_name> -f Dockerfile.npu .
 alias drun='docker run -it --rm --privileged --network=host --ipc=host --shm-size=16g \
     --device=/dev/davinci0 --device=/dev/davinci1 --device=/dev/davinci2 --device=/dev/davinci3 \

@@ -132,7 +132,7 @@ alias drun='docker run -it --rm --privileged --network=host --ipc=host --shm-siz
     --volume /var/queue_schedule:/var/queue_schedule --volume ~/.cache/:/root/.cache/'
 drun --env "HF_TOKEN=<secret>" \
-    sglang-npu:main \
+    <image_name> \
     python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --attention-backend ascend --host 0.0.0.0 --port 30000
 ```

@@ -149,7 +149,7 @@ Prefill:
 export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
 export ASCEND_MF_STORE_URL="tcp://<PREFILL_HOST_IP>:<PORT>"
-drun sglang-npu:main \
+drun <image_name> \
     python3 -m sglang.launch_server --model-path State_Cloud/DeepSeek-R1-bf16-hfd-w8a8 \
     --trust-remote-code \
     --attention-backend ascend \

@@ -174,8 +174,9 @@ export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
 export ASCEND_MF_STORE_URL="tcp://<PREFILL_HOST_IP>:<PORT>"
 export HCCL_BUFFSIZE=200
 export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=24
+export SGLANG_NPU_USE_MLAPO=1
-drun sglang-npu:main \
+drun <image_name> \
     python3 -m sglang.launch_server --model-path State_Cloud/DeepSeek-R1-bf16-hfd-w8a8 \
     --trust-remote-code \
     --attention-backend ascend \

@@ -198,7 +199,7 @@ drun sglang-npu:main \
 Mini_LB:
 ```shell
-drun sglang-npu:main \
+drun <image_name> \
     python -m sglang.srt.disaggregation.launch_lb \
     --prefill http://<PREFILL_HOST_IP>:8000 \
     --decode http://<DECODE_HOST_IP>:8001 \
python/sglang/srt/layers/attention/ascend_backend.py
@@ -9,6 +9,7 @@ from torch.nn.functional import scaled_dot_product_attention
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess_enabled
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType

@@ -401,7 +402,7 @@ class AscendAttnBackend(AttentionBackend):
             antiquant_scale=None,
             sparse_mode=0,
         )
-        output = torch.zeros_like(q_nope, dtype=q.dtype, device=q.device)
+        output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
         softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
         torch_npu.npu_fused_infer_attention_score.out(

@@ -437,6 +438,10 @@ class AscendAttnBackend(AttentionBackend):
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
+        if is_mla_preprocess_enabled():
+            # MLAPO does saving kv_cache
+            save_kv_cache = False
 if self.graph_mode:
             return self.forward_decode_graph(
                 q,
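The zeros_like to empty_like change is a small allocation optimization: the buffer is handed to torch_npu.npu_fused_infer_attention_score.out as a destination and is, presumably, fully overwritten there, so the initial zero-fill is wasted work. A CPU-side sketch of the difference (shapes are illustrative, not the model's real ones):

```python
import torch

q_nope = torch.randn(8, 16, 512)

# zeros_like pays for a memset; empty_like only allocates.
out_zeroed = torch.zeros_like(q_nope, dtype=q_nope.dtype, device=q_nope.device)
out_uninit = torch.empty_like(q_nope, dtype=q_nope.dtype, device=q_nope.device)

# Shape, dtype, and device are identical either way; only the initial
# contents differ, which does not matter for a buffer used purely as an
# out= destination.
assert out_zeroed.shape == out_uninit.shape == q_nope.shape
```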
python/sglang/srt/layers/attention/npu_ops/mla_preprocess.py
new file (mode 0 → 100644)

import torch
import torch.nn.functional as F

from sglang.srt.utils import get_bool_env_var, is_npu

_is_npu = is_npu()
_ENABLE_MLA_PREPROCESS_FLAG = get_bool_env_var("SGLANG_NPU_USE_MLAPO")
_NPU_FORMAT_NZ = 29


def is_mla_preprocess_enabled() -> bool:
    return _is_npu and _ENABLE_MLA_PREPROCESS_FLAG


if is_mla_preprocess_enabled():
    import sgl_kernel_npu
    import torch_npu

    torch.npu.config.allow_internal_format = True
    torch.npu.set_compile_mode(jit_compile=False)


def round_up(val: int, align: int) -> int:
    if align == 0:
        return 0
    return -(val // -align) * align


def transdata(nd_mat, block_size: tuple = (16, 16)):
    r = round_up(nd_mat.shape[0], block_size[0])
    c = round_up(nd_mat.shape[1], block_size[1])
    r_pad = r - nd_mat.shape[0]
    c_pad = c - nd_mat.shape[1]
    nd_mat = F.pad(nd_mat, ((0, r_pad, 0, c_pad)))
    nz_mat = torch.permute(
        torch.reshape(
            nd_mat,
            (r // block_size[0], block_size[0], c // block_size[1], block_size[1]),
        ),
        [2, 0, 1, 3],
    )
    nz_mat = torch.reshape(
        nz_mat,
        (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]),
    )
    return nz_mat


def trans_rope_weight(weight, rope_dim):
    weight_1 = weight[..., -rope_dim::2, :].contiguous()
    weight_2 = weight[..., -rope_dim + 1 :: 2, :].contiguous()
    weight[..., -rope_dim:, :] = torch.cat([weight_1, weight_2], dim=-2)
    return weight.contiguous()


class NPUFusedMLAPreprocess(torch.nn.Module):
    def __init__(
        self,
        fused_qkv_a_proj_with_mqa,
        q_a_layernorm,
        kv_a_layernorm,
        q_b_proj,
        w_kc,
        rotary_emb,
        layer_id,
        num_local_heads,
        qk_nope_head_dim,
        qk_rope_head_dim,
    ):
        super().__init__()
        self.qkv_a_proj = fused_qkv_a_proj_with_mqa
        self.q_a_layernorm = q_a_layernorm
        self.kv_a_layernorm = kv_a_layernorm
        self.q_b_proj = q_b_proj
        self.w_kc = w_kc.contiguous()
        self.rotary_emb = rotary_emb
        self.layer_id = layer_id
        self.has_preprocess_weights = False
        self.q_lora_rank = self.q_b_proj.input_size  # 1536
        self.kv_lora_rank = self.kv_a_layernorm.hidden_size  # 512
        self.num_local_heads = num_local_heads  # tp
        self.qk_nope_head_dim = qk_nope_head_dim  # 128
        self.qk_rope_head_dim = qk_rope_head_dim  # 64

    def preprocess_weights(self, hidden_states):
        self.dummy = torch.empty(
            (hidden_states.shape[-1]),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        self.qkv_a_proj_input_offset = self.qkv_a_proj.input_offset.to(dtype=torch.int8)
        self.q_b_proj_input_offset = self.q_b_proj.input_offset.to(dtype=torch.int8)

        # matmul_0 weight [7168, 2112]
        fused_qkv_a_proj_with_mqa_weight_q = self.qkv_a_proj.weight.data[
            :, : self.q_lora_rank
        ].clone()  # [7168, 1536]
        fused_qkv_a_proj_with_mqa_weight_kv = self.qkv_a_proj.weight.data[
            :, self.q_lora_rank :
        ].clone()  # [7168, 576]
        # rope fit
        fused_qkv_a_proj_with_mqa_weight_kv_t = (
            fused_qkv_a_proj_with_mqa_weight_kv.t().contiguous()
        )
        fused_qkv_a_proj_with_mqa_weight_kv_t = trans_rope_weight(
            fused_qkv_a_proj_with_mqa_weight_kv_t, self.qk_rope_head_dim
        )
        fused_qkv_a_proj_with_mqa_weight_kv = (
            fused_qkv_a_proj_with_mqa_weight_kv_t.t().contiguous()
        )
        # cat nz
        fused_qkv_a_proj_with_mqa_weight_new = torch.cat(
            (fused_qkv_a_proj_with_mqa_weight_kv, fused_qkv_a_proj_with_mqa_weight_q),
            dim=-1,
        )
        fused_qkv_a_proj_with_mqa_weight = (
            fused_qkv_a_proj_with_mqa_weight_new.t().contiguous()
        )
        fused_qkv_a_proj_with_mqa_weight_nz = (
            transdata(fused_qkv_a_proj_with_mqa_weight, block_size=(16, 32))
            .unsqueeze(0)
            .contiguous()
        )
        self.qkv_a_proj_weight_nz = torch_npu.npu_format_cast(
            fused_qkv_a_proj_with_mqa_weight_nz, _NPU_FORMAT_NZ
        )

        # matmul_0 deq_scale [2112]
        fused_qkv_a_proj_with_mqa_deq_scale_q = self.qkv_a_proj.deq_scale.data[
            : self.q_lora_rank
        ].clone()  # [7168, 1536]
        fused_qkv_a_proj_with_mqa_deq_scale_kv = self.qkv_a_proj.deq_scale.data[
            self.q_lora_rank :
        ].clone()  # [7168, 576]
        # rope fit
        fused_qkv_a_proj_with_mqa_deq_scale_kv = (
            fused_qkv_a_proj_with_mqa_deq_scale_kv.reshape(
                self.kv_lora_rank + self.qk_rope_head_dim, -1
            ).contiguous()
        )
        fused_qkv_a_proj_with_mqa_deq_scale_kv = trans_rope_weight(
            fused_qkv_a_proj_with_mqa_deq_scale_kv, self.qk_rope_head_dim
        )
        fused_qkv_a_proj_with_mqa_deq_scale_kv = (
            fused_qkv_a_proj_with_mqa_deq_scale_kv.view(
                self.kv_lora_rank + self.qk_rope_head_dim
            ).contiguous()
        )
        self.qkv_a_proj_deq_scale_kvq = torch.cat(
            (
                fused_qkv_a_proj_with_mqa_deq_scale_kv,
                fused_qkv_a_proj_with_mqa_deq_scale_q,
            ),
            dim=-1,
        )

        # matmul_0 quant_bias [2112]
        fused_qkv_a_proj_with_mqa_quant_bias_q = self.qkv_a_proj.quant_bias.data[
            : self.q_lora_rank
        ].clone()  # [7168, 1536]
        fused_qkv_a_proj_with_mqa_quant_bias_kv = self.qkv_a_proj.quant_bias.data[
            self.q_lora_rank :
        ].clone()  # [7168, 576]
        # rope fit
        fused_qkv_a_proj_with_mqa_quant_bias_kv = (
            fused_qkv_a_proj_with_mqa_quant_bias_kv.reshape(
                self.kv_lora_rank + self.qk_rope_head_dim, -1
            ).contiguous()
        )
        fused_qkv_a_proj_with_mqa_quant_bias_kv = trans_rope_weight(
            fused_qkv_a_proj_with_mqa_quant_bias_kv, self.qk_rope_head_dim
        )
        fused_qkv_a_proj_with_mqa_quant_bias_kv = (
            fused_qkv_a_proj_with_mqa_quant_bias_kv.view(
                self.kv_lora_rank + self.qk_rope_head_dim
            ).contiguous()
        )
        self.qkv_a_proj_quant_bias_kvq = torch.cat(
            (
                fused_qkv_a_proj_with_mqa_quant_bias_kv,
                fused_qkv_a_proj_with_mqa_quant_bias_q,
            ),
            dim=-1,
        )

        # matmul_1 weight [1536, num_head * 192]
        q_b_proj_weight = self.q_b_proj.weight.data.clone()
        q_b_proj_weight = q_b_proj_weight.t().reshape(
            self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1
        )
        q_b_proj_weight = trans_rope_weight(q_b_proj_weight, self.qk_rope_head_dim)
        q_b_proj_weight = q_b_proj_weight.reshape(
            self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), -1
        )
        q_b_proj_weight_nz = (
            transdata(q_b_proj_weight, block_size=(16, 32)).unsqueeze(0).contiguous()
        )
        self.q_b_proj_weight_nz = torch_npu.npu_format_cast(
            q_b_proj_weight_nz, _NPU_FORMAT_NZ
        )

        # matmul_1 deq_scale [num_head * 192]
        q_b_proj_deq_scale = self.q_b_proj.deq_scale.data.clone()
        q_b_proj_deq_scale = q_b_proj_deq_scale.reshape(
            self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1
        )
        q_b_proj_deq_scale = trans_rope_weight(q_b_proj_deq_scale, self.qk_rope_head_dim)
        self.q_b_proj_deq_scale = q_b_proj_deq_scale.reshape(
            self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)
        )

        # matmul_1 quant_bias [num_head * 192]
        q_b_proj_quant_bias = self.q_b_proj.quant_bias.data.clone()
        q_b_proj_quant_bias = q_b_proj_quant_bias.reshape(
            self.num_local_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1
        )
        q_b_proj_quant_bias = trans_rope_weight(
            q_b_proj_quant_bias, self.qk_rope_head_dim
        )
        self.q_b_proj_quant_bias = q_b_proj_quant_bias.reshape(
            self.num_local_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)
        )

    def get_sin_cos(self, positions):
        cos_sin = self.rotary_emb.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = cos.repeat(1, 2)
        sin = sin.repeat(1, 2)
        return cos, sin

    def get_kv_cache_and_cache_idx(self, forward_batch):
        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(self.layer_id)
        slot_mapping = forward_batch.out_cache_loc.to(dtype=torch.int32)
        return k_cache, v_cache, slot_mapping

    def forward(self, positions, hidden_states, forward_batch, zero_allocator):
        input_dtype = hidden_states.dtype
        if not self.has_preprocess_weights:
            self.preprocess_weights(hidden_states)
            self.has_preprocess_weights = True
            self.dtype = hidden_states.dtype

        cos, sin = self.get_sin_cos(positions)
        k_cache, v_cache, slot_mapping = self.get_kv_cache_and_cache_idx(forward_batch)

        q_nope_out = torch.empty(
            (hidden_states.shape[0], self.w_kc.shape[0], k_cache.shape[-1]),
            dtype=input_dtype,
            device=hidden_states.device,
        )
        q_rope_out = torch.empty(
            (hidden_states.shape[0], self.w_kc.shape[0], v_cache.shape[-1]),
            dtype=input_dtype,
            device=hidden_states.device,
        )

        # TODO: dummy inputs to be removed
        # https://github.com/sgl-project/sgl-kernel-npu/issues/78
        torch.ops.npu.mla_preprocess(
            hidden_states,
            self.dummy,
            self.dummy,
            self.qkv_a_proj_weight_nz,
            self.qkv_a_proj_deq_scale_kvq,
            self.q_a_layernorm.weight,
            self.q_a_layernorm.bias,
            self.q_b_proj_weight_nz,
            self.q_b_proj_deq_scale,
            self.kv_a_layernorm.weight,
            cos,
            sin,
            self.w_kc,
            k_cache,
            v_cache,
            slot_mapping,
            quant_scale0=self.qkv_a_proj.input_scale,
            quant_offset0=self.qkv_a_proj_input_offset,
            bias0=self.qkv_a_proj_quant_bias_kvq,
            quant_scale1=self.q_b_proj.input_scale,
            quant_offset1=self.q_b_proj_input_offset,
            bias1=self.q_b_proj_quant_bias,
            cache_mode="krope_ctkv",
            quant_mode="per_tensor_quant_asymm",
            q_out0=q_nope_out,
            kv_cache_out0=k_cache,
            q_out1=q_rope_out,
            kv_cache_out1=v_cache,
        )
        return (
            q_rope_out,
            v_cache,
            q_nope_out,
            k_cache,
            forward_batch,
            zero_allocator,
            positions,
        )
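To make the weight repacking above concrete: for DeepSeek-sized projection weights the dimensions are already multiples of the block size, so the F.pad call is a no-op and transdata simply tiles an ND matrix of shape (R, C) into (C // c0, R, c0), the layout that is then tagged as NZ (format 29) via npu_format_cast. A small CPU-only sketch, assuming sglang is installed so the module imports on a non-NPU host (the torch_npu import above is skipped when the gate is off); the sizes are illustrative, not the real 7168 x 2112 weight:

```python
import torch

from sglang.srt.layers.attention.npu_ops.mla_preprocess import transdata

R, C = 32, 64  # already aligned to the (16, 32) block size, so no padding
nd = torch.arange(R * C, dtype=torch.float32).reshape(R, C)

nz = transdata(nd, block_size=(16, 32))
print(nz.shape)  # torch.Size([2, 32, 32]) == (C // 32, R, 32)

# Block j of the NZ tensor holds columns j*32:(j+1)*32 of the ND matrix.
assert torch.equal(nz[1], nd[:, 32:64])
```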
python/sglang/srt/layers/rotary_embedding.py
@@ -782,27 +782,33 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # NOTE: now npu_mrope can only support `numQHeads*headSize <= 4096` pattern,
         # and generalization to more scenarios will be supported in the future.
         if query.shape[1] * query.shape[2] > 4096:
             return self.forward_native(positions, query, key, offsets)
-        num_tokens = query.shape[0]
-        rotary_mode = "half" if self.is_neox_style else "interleave"
+        num_tokens, num_q_heads, _ = query.shape
+        num_k_heads = key.shape[1]
         self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        cos_sin = self.cos_sin_cache[
+            torch.add(positions, offsets) if offsets is not None else positions
+        ]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        # Reshape to [batchsize, head_dim, seq, rotary_dim]
+        cos = cos.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
+        sin = sin.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
         query_rot = query[..., : self.rotary_dim]
         key_rot = key[..., : self.rotary_dim]
         if self.rotary_dim < self.head_size:
             query_pass = query[..., self.rotary_dim :]
             key_pass = key[..., self.rotary_dim :]
-        query_rot, key_rot = torch_npu.npu_mrope(
-            torch.add(positions, offsets) if offsets is not None else positions,
-            query_rot.reshape(num_tokens, -1),
-            key_rot.reshape(num_tokens, -1),
-            self.cos_sin_cache,
-            self.rotary_dim,
-            mrope_section=[0, 0, 0],
-            rotary_mode=rotary_mode,
-        )
+        query_rot = torch_npu.npu_interleave_rope(
+            query_rot.reshape(num_tokens, num_q_heads, 1, self.rotary_dim),
+            cos,
+            sin,
+        )
+        key_rot = torch_npu.npu_interleave_rope(
+            key_rot.reshape(num_tokens, num_k_heads, 1, self.rotary_dim),
+            cos,
+            sin,
+        )
         query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
         key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
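The guard at the top of this NPU rope path falls back to forward_native whenever num_q_heads * head_size exceeds 4096, the npu_mrope limit cited in the NOTE; with query laid out as [num_tokens, num_heads, head_size], the check is just a product of two dims. A tiny self-contained illustration (head counts here are made up, not taken from any model config):

```python
import torch


def takes_fused_npu_rope_path(query: torch.Tensor) -> bool:
    # query: [num_tokens, num_heads, head_size]
    return query.shape[1] * query.shape[2] <= 4096


small = torch.empty(8, 16, 64)   # 16 * 64 = 1024  -> fused NPU path
large = torch.empty(8, 128, 64)  # 128 * 64 = 8192 -> forward_native fallback
assert takes_fused_npu_rope_path(small) and not takes_fused_npu_rope_path(large)
```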
python/sglang/srt/models/deepseek_v2.py
@@ -43,6 +43,10 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.amx_utils import PackWeightMethod
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
+    NPUFusedMLAPreprocess,
+    is_mla_preprocess_enabled,
+)
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,

@@ -1177,6 +1181,12 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.weight_block_size = (
                 self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
             )
+        self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
+        if self.is_mla_preprocess_enabled:
+            assert (
+                quant_config.get_name() == "w8a8_int8"
+            ), "MLA Preprocess only works with W8A8Int8"
+            self.mla_preprocess = None

     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch

@@ -1263,9 +1273,28 @@ class DeepseekV2AttentionMLA(nn.Module):
                 positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA:
-            inner_state = self.forward_absorb_prepare(
-                positions, hidden_states, forward_batch, zero_allocator
-            )
+            if not self.is_mla_preprocess_enabled:
+                inner_state = self.forward_absorb_prepare(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
+            else:
+                # TODO(iforgetmyname): to be separated as a standalone func
+                if self.mla_preprocess is None:
+                    self.mla_preprocess = NPUFusedMLAPreprocess(
+                        self.fused_qkv_a_proj_with_mqa,
+                        self.q_a_layernorm,
+                        self.kv_a_layernorm,
+                        self.q_b_proj,
+                        self.w_kc,
+                        self.rotary_emb,
+                        self.layer_id,
+                        self.num_local_heads,
+                        self.qk_nope_head_dim,
+                        self.qk_rope_head_dim,
+                    )
+                inner_state = self.mla_preprocess.forward(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
             inner_state = self.forward_absorb_fused_mla_rope_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
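Note that NPUFusedMLAPreprocess is constructed lazily inside the MLA dispatch branch rather than in __init__, presumably so that the quantized projection weights and w_kc are fully loaded before preprocess_weights repacks them. A stripped-down sketch of that construct-on-first-use pattern (the class name and toy helper below are illustrative only):

```python
class LazyHelper:
    """Build an expensive helper object the first time it is needed."""

    def __init__(self, build_fn):
        self._build_fn = build_fn  # e.g. lambda: NPUFusedMLAPreprocess(...)
        self._impl = None

    def __call__(self, *args):
        if self._impl is None:  # first call: materialize the helper
            self._impl = self._build_fn()
        return self._impl(*args)


# Toy usage: a plain function stands in for the real preprocessing module.
lazy = LazyHelper(lambda: (lambda x: x * 2))
assert lazy(3) == 6 and lazy(5) == 10
```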
python/sglang/srt/utils.py
@@ -174,6 +174,8 @@ def is_blackwell():
@lru_cache(maxsize=1)
def is_sm100_supported(device=None) -> bool:
    if not is_cuda_alike():
        return False
    return (torch.cuda.get_device_capability(device)[0] == 10) and (
        torch.version.cuda >= "12.8"
    )

@@ -181,6 +183,8 @@ def is_sm100_supported(device=None) -> bool:
@lru_cache(maxsize=1)
def is_sm90_supported(device=None) -> bool:
    if not is_cuda_alike():
        return False
    return (torch.cuda.get_device_capability(device)[0] == 9) and (
        torch.version.cuda >= "12.3"
    )
test/srt/ascend/test_ascend_deepep.py
@@ -60,6 +60,7 @@ class TestAscendDeepEP(CustomTestCase):
         cls.extra_envs = {
             "HCCL_BUFFSIZE": "500",
             "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "32",
+            "SGLANG_NPU_USE_MLAPO": "1",
         }
         os.environ.update(cls.extra_envs)
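Outside the test harness, the same switches can be exported in the parent environment before the server process starts. A hedged launch sketch that reuses the env values from the test above and the launch flags from the docs diff (everything else, including whether these values suit a given setup, is an assumption):

```python
import os
import subprocess

env = dict(os.environ)
env.update(
    {
        "HCCL_BUFFSIZE": "500",
        "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "32",
        "SGLANG_NPU_USE_MLAPO": "1",  # enable the fused MLA preprocess path
    }
)

# Illustrative launch only; model path and flags are copied from the docs diff.
subprocess.run(
    [
        "python3", "-m", "sglang.launch_server",
        "--model-path", "State_Cloud/DeepSeek-R1-bf16-hfd-w8a8",
        "--trust-remote-code",
        "--attention-backend", "ascend",
    ],
    env=env,
    check=True,
)
```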