sglang · Commit 095093ee (unverified)

[Ascend] optimize Qwen-vl on Ascend (#10556)

Authored Sep 23, 2025 by ronnie_zheng; committed via GitHub Sep 22, 2025
Co-authored-by: wangqihui01 <wangqh10@163.com>
Parent: d27a6f70
Changes: 5 changed files with 99 additions and 14 deletions (+99 −14)

  python/sglang/srt/layers/attention/vision.py                 +58  −0
  python/sglang/srt/layers/rotary_embedding.py                 +28  −11
  python/sglang/srt/model_loader/loader.py                     +4   −1
  python/sglang/srt/multimodal/processors/base_processor.py    +8   −1
  python/sglang/srt/server_args.py                             +1   −1
python/sglang/srt/layers/attention/vision.py

@@ -16,14 +16,19 @@ from sglang.srt.utils import (
     get_device_capability,
     is_blackwell,
     is_cuda,
+    is_npu,
     print_info_once,
 )
 
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 
 if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func
+
+if _is_npu:
+    import torch_npu
 
 from sglang.srt.distributed import (
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,

@@ -331,10 +336,63 @@ class VisionFlash3Attention(nn.Module):
 
         return output
 
 
+class VisionAscendAttention(nn.Module):
+
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        if not _is_npu:
+            raise Exception("VisionAscendAttention is only available for ascend npu")
+        super().__init__()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
+        bsz: int,
+        seq_len: int,
+        **kwargs,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            cu_seqlens: [b]
+        Returns:
+            [b * s, h, head_size]
+        """
+        if cu_seqlens is None:
+            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+        if seq_lens.is_npu:
+            # cu_seqlens must be on cpu because of operator restriction
+            seq_lens = seq_lens.to("cpu")
+        _, num_heads, head_size = q.shape
+        num_kv_heads = k.shape[1]
+        output = torch.empty_like(q)
+
+        # operator requires pta version >= 2.5.1
+        torch_npu._npu_flash_attention_unpad(
+            query=q,
+            key=k,
+            value=v,
+            seq_len=seq_lens.to(torch.int32),
+            scale_value=head_size**-0.5,
+            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            out=output,
+        )
+        return output
+
+
 QKV_BACKEND_IMPL = {
     "triton_attn": VisionTritonAttention,
     "sdpa": VisionSdpaAttention,
     "fa3": VisionFlash3Attention,
+    "ascend_attn": VisionAscendAttention,
 }
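For orientation, the new "ascend_attn" backend computes unpadded variable-length attention over q/k/v packed along the token dimension, with per-sequence boundaries given by cu_seqlens. Below is a minimal CPU reference sketch of those semantics using plain PyTorch SDPA per segment; it is illustrative only and stands in for the torch_npu._npu_flash_attention_unpad kernel (the helper name varlen_attention_reference is invented here, and grouped-query layouts with fewer KV heads are ignored for brevity).

import torch
import torch.nn.functional as F


def varlen_attention_reference(q, k, v, cu_seqlens):
    """Reference semantics for unpadded varlen attention.

    q/k/v: [total_tokens, num_heads, head_size], packed across sequences.
    cu_seqlens: [batch + 1] cumulative sequence lengths.
    Returns [total_tokens, num_heads, head_size].
    """
    output = torch.empty_like(q)
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        # Slice one sequence, move heads to the batch dim for SDPA, then restore.
        qi = q[start:end].transpose(0, 1)
        ki = k[start:end].transpose(0, 1)
        vi = v[start:end].transpose(0, 1)
        oi = F.scaled_dot_product_attention(qi, ki, vi)  # default scale = head_size ** -0.5
        output[start:end] = oi.transpose(0, 1)
    return output


# Two packed sequences of lengths 3 and 5, 4 heads, head_size 16.
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)
q, k, v = (torch.randn(8, 4, 16) for _ in range(3))
out = varlen_attention_reference(q, k, v, cu_seqlens)
print(out.shape)  # torch.Size([8, 4, 16])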
python/sglang/srt/layers/rotary_embedding.py

@@ -12,6 +12,7 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
+    get_compiler_backend,
     is_cpu,
     is_cuda,
     is_hip,

@@ -33,6 +34,9 @@ if _use_aiter:
 if is_npu():
     import torch_npu
 
+NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
+NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896
+
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : x.shape[-1] // 2]

@@ -1035,7 +1039,7 @@ class MRotaryEmbedding(RotaryEmbedding):
             f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
         )
 
-    @torch.compile(dynamic=True)
+    @torch.compile(dynamic=True, backend=get_compiler_backend())
     def forward(
         self,
         positions: torch.Tensor,

@@ -1894,17 +1898,30 @@ def apply_rotary_pos_emb_npu(
     sin: torch.Tensor,
     unsqueeze_dim=1,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    if q.shape[1] != 128:
-        return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
-    cos = cos.unsqueeze(unsqueeze_dim)
-    cos = torch.transpose(cos, 1, 2)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    sin = torch.transpose(sin, 1, 2)
-    q = torch.transpose(q, 1, 2)
-    k = torch.transpose(k, 1, 2)
-    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
-    q_embed = torch.transpose(q_embed, 1, 2)
-    k_embed = torch.transpose(k_embed, 1, 2)
+    """Ascend implementation equivalent to apply_rotary_pos_emb_native.
+    Args:
+        q: [num_tokens, num_heads, head_size]
+        k: [num_tokens, num_kv_heads, head_size]
+        cos: [num_tokens, head_size]
+        sin: [num_tokens, head_size]
+    """
+    if (
+        cos.dim() != 2
+        or q.dim() != 3
+        or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS
+        or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE
+    ):
+        # Note: num_heads and head_size of q must be less than 1000 and 896, respectively
+        return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
+    cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0)
+    sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0)
+    q = q.unsqueeze(0)
+    k = k.unsqueeze(0)
+    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
+    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
+    q_embed = q_embed.squeeze(0)
+    k_embed = k_embed.squeeze(0)
     return q_embed, k_embed
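The rewritten apply_rotary_pos_emb_npu now routes shapes within the NPU limits through torch_npu.npu_rotary_mul and falls back to apply_rotary_pos_emb_native otherwise. As a rough reference for what that fallback computes, here is a standard neox-style rotary application sketch, assuming cos/sin already span the full head_size; the names rotate_half and apply_rotary_reference are invented for this sketch and may differ from the actual native implementation.

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Neox-style rotation: swap the two halves of the last dim and negate the second.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_reference(q, k, cos, sin, unsqueeze_dim=1):
    # q: [num_tokens, num_heads, head_size], k: [num_tokens, num_kv_heads, head_size]
    # cos/sin: [num_tokens, head_size]; unsqueeze to broadcast across the heads dim.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed


# Shapes inside the NPU fast-path limits: num_heads < 1000, head_size < 896.
q = torch.randn(8, 16, 128)
k = torch.randn(8, 2, 128)
cos = torch.randn(8, 128)
sin = torch.randn(8, 128)
q_out, k_out = apply_rotary_reference(q, k, cos, sin)
print(q_out.shape, k_out.shape)  # torch.Size([8, 16, 128]) torch.Size([8, 2, 128])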
python/sglang/srt/model_loader/loader.py

@@ -206,7 +206,10 @@ def _initialize_model(
     if _is_npu:
         packed_modules_mapping.update(
             {
-                "visual": {"qkv_proj": ["qkv"]},
+                "visual": {
+                    "qkv_proj": ["qkv"],
+                    "gate_up_proj": ["gate_proj", "up_proj"],
+                },
                 "vision_model": {
                     "qkv_proj": ["q_proj", "k_proj", "v_proj"],
                     "proj": ["out_proj"],
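The extra "gate_up_proj" entry tells the loader that checkpoint weights named gate_proj and up_proj belong to the fused gate_up_proj module of the visual tower. A minimal sketch of how a mapping of this shape can be consulted to rewrite checkpoint weight names (illustrative only; remap_weight_name is a hypothetical helper, and the real loader also tracks shard indices when stacking the fused weights):

from typing import Dict, List

# Mapping as introduced by this commit: scope -> fused module -> checkpoint shard names.
PACKED_MODULES_MAPPING: Dict[str, Dict[str, List[str]]] = {
    "visual": {
        "qkv_proj": ["qkv"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    },
    "vision_model": {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "proj": ["out_proj"],
    },
}


def remap_weight_name(name: str) -> str:
    # Rewrite a checkpoint weight name to its fused destination, if any.
    for scope, fused_map in PACKED_MODULES_MAPPING.items():
        if not name.startswith(scope + "."):
            continue
        for fused, shards in fused_map.items():
            for shard in shards:
                token = f".{shard}."
                if token in name:
                    return name.replace(token, f".{fused}.")
    return name


print(remap_weight_name("visual.blocks.0.mlp.up_proj.weight"))
# -> visual.blocks.0.mlp.gate_up_proj.weight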
python/sglang/srt/multimodal/processors/base_processor.py

@@ -234,7 +234,14 @@ class BaseMultimodalProcessor(ABC):
             and isinstance(processor.image_processor, BaseImageProcessorFast)
             and not self.server_args.disable_fast_image_processor
         ):
-            kwargs["device"] = "cuda" if not _is_npu else "npu"
+            if not _is_npu:
+                kwargs["device"] = "cuda"
+            elif processor.__class__.__name__ not in {
+                "Qwen2_5_VLProcessor",
+                "Qwen3VLProcessor",
+            }:
+                # Note: for qwen-vl, processor has some reshape issue because of dims restriction on Ascend.
+                kwargs["device"] = "npu"
 
         result = processor.__call__(
             text=[input_text],
             padding=True,
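The device selection for the fast image processor now keeps the Qwen2.5-VL and Qwen3-VL processors off the NPU, since their reshapes hit dimension restrictions on Ascend. The same branch logic, restated as a small standalone helper for clarity (select_processor_device is a hypothetical name used only in this sketch):

from typing import Optional


def select_processor_device(is_npu: bool, processor_name: str) -> Optional[str]:
    # Returns the device to hand to the fast image processor, or None to keep
    # the processor's default (CPU) device.
    if not is_npu:
        return "cuda"
    if processor_name in {"Qwen2_5_VLProcessor", "Qwen3VLProcessor"}:
        # Qwen-VL fast processors have reshape issues under Ascend dims restrictions.
        return None
    return "npu"


assert select_processor_device(False, "Qwen2_5_VLProcessor") == "cuda"
assert select_processor_device(True, "Qwen2_5_VLProcessor") is None
assert select_processor_device(True, "SomeOtherProcessor") == "npu"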
python/sglang/srt/server_args.py

@@ -1840,7 +1840,7 @@ class ServerArgs:
         parser.add_argument(
             "--mm-attention-backend",
             type=str,
-            choices=["sdpa", "fa3", "triton_attn"],
+            choices=["sdpa", "fa3", "triton_attn", "ascend_attn"],
             default=ServerArgs.mm_attention_backend,
             help="Set multimodal attention backend.",
         )
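With the new choice registered, Ascend deployments can opt into the NPU vision attention path at launch, e.g. python -m sglang.launch_server --model-path <qwen-vl-model> --mm-attention-backend ascend_attn (illustrative invocation; the model path is a placeholder).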