change / sglang · Commits · ba861293

Commit ba861293 (unverified), authored Aug 31, 2025 by VDV1985, committed via GitHub on Aug 31, 2025.

[feat] Ascend NPU Gemma-3-12b and Gemma-3-27b support (#8909)

Parent: c112bcc4

Showing 6 changed files with 136 additions and 30 deletions (+136, -30):
- python/sglang/srt/layers/activation.py (+12, -0)
- python/sglang/srt/layers/attention/ascend_backend.py (+59, -23)
- python/sglang/srt/layers/layernorm.py (+28, -3)
- python/sglang/srt/layers/rotary_embedding.py (+28, -1)
- python/sglang/srt/managers/mm_utils.py (+5, -1)
- python/sglang/srt/multimodal/processors/base_processor.py (+4, -2)
python/sglang/srt/layers/activation.py

@@ -103,6 +103,15 @@ class GeluAndMul(CustomOp):
             raise RuntimeError("GeluAndMul only support tanh or none")
         return out
 
+    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
+        y_npu, gelu_npu = torch_npu.npu_geglu(
+            x,
+            dim=-1,
+            approximate=1 if self.approximate == "tanh" else 0,
+            activate_left=True,
+        )
+        return y_npu
+
 
 class NewGELU(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:

@@ -137,6 +146,9 @@ class QuickGELU(CustomOp):
         gelu_quick(x, out)
         return out
 
+    def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
+        return torch_npu.npu_fast_gelu(x)
+
 
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
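
For reference, a minimal CPU-only sketch of the gated-GELU ("GeGLU") computation that the new forward_npu delegates to torch_npu.npu_geglu. torch_npu only exists on Ascend hardware, so this uses plain torch and is an illustration of the expected math, not the kernel itself; the split convention (activate the left half) is an assumption matching GeluAndMul's native path.

import torch
import torch.nn.functional as F

def geglu_reference(x: torch.Tensor, approximate: str = "tanh") -> torch.Tensor:
    # Split the last dimension in half: the first half is the gate, the second
    # half the value (this is what "activate_left=True" is taken to mean above).
    gate, value = x.chunk(2, dim=-1)
    return F.gelu(gate, approximate=approximate) * value

x = torch.randn(4, 16)
print(geglu_reference(x).shape)  # torch.Size([4, 8])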
python/sglang/srt/layers/attention/ascend_backend.py

@@ -180,7 +180,7 @@ class AscendAttnBackend(AttentionBackend):
         if self.use_fia:
             """FIA will support multi-bs in the later version of CANN"""
-            q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
             attn_output = torch.empty(
                 (q.size(0), layer.tp_q_head_num, layer.v_head_dim),
                 device=q.device,

@@ -208,7 +208,8 @@ class AscendAttnBackend(AttentionBackend):
             )
         else:
-            query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
+            if layer.qk_head_dim <= 128:
+                query = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
                 attn_output = torch.empty(
                     (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
                     dtype=query.dtype,

@@ -228,6 +229,40 @@ class AscendAttnBackend(AttentionBackend):
                     num_kv_heads=layer.tp_k_head_num,
                     out=attn_output,
                 )
+            else:
+                if layer.qk_head_dim != layer.v_head_dim:
+                    attn_output = q.new_empty(
+                        (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+                    )
+                else:
+                    attn_output = torch.empty_like(q)
+
+                use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+                q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                o_ = attn_output.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+                causal = True
+                if (
+                    layer.is_cross_attention
+                    or layer.attn_type == AttentionType.ENCODER_ONLY
+                ):
+                    causal = False
+
+                self.native_attn._run_sdpa_forward_extend(
+                    q_,
+                    o_,
+                    k_cache.view(-1, layer.tp_k_head_num, layer.qk_head_dim),
+                    v_cache.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                    forward_batch.req_to_token_pool.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    forward_batch.extend_prefix_lens,
+                    forward_batch.extend_seq_lens,
+                    scaling=layer.scaling,
+                    enable_gqa=use_gqa,
+                    causal=causal,
+                )
         else:
             assert (
                 layer.qk_head_dim != layer.v_head_dim

@@ -283,7 +318,7 @@ class AscendAttnBackend(AttentionBackend):
         v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
             layer.layer_id
         ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
-        query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
+        query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
         if self.forward_metadata.seq_lens_cpu_int is None:
             actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
         else:

@@ -439,7 +474,8 @@ class AscendAttnBackend(AttentionBackend):
                 scale=layer.scaling,
             )
         else:
-            query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
+            num_tokens = query.shape[0]
             attn_output = torch.empty(
                 (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
                 dtype=query.dtype,
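
Most of the one-line edits in this file swap .view() for .reshape(). A small, self-contained illustration of why that matters (independent of the Ascend backend): view() requires the tensor's existing memory layout to support the new shape, while reshape() falls back to a copy when it does not.

import torch

q = torch.randn(8, 4, 32).transpose(0, 1)  # shape (4, 8, 32), non-contiguous
try:
    q.view(-1, 32)                          # fails: layout incompatible with view
except RuntimeError as err:
    print("view() failed:", err)
out = q.reshape(-1, 32)                     # reshape() copies when needed
print(out.shape)                            # torch.Size([32, 32])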
python/sglang/srt/layers/layernorm.py

@@ -53,7 +53,7 @@ elif _is_hip:
 
 logger = logging.getLogger(__name__)
 
-if is_npu():
+if _is_npu:
     import torch_npu

@@ -266,23 +266,48 @@ class GemmaRMSNorm(CustomOp):
             out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
 
+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        if residual is not None:
+            x = x + residual
+            residual = x
 
-class Gemma3RMSNorm(nn.Module):
+        x = x.float()
+        variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
+        x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
+        x = x * (1.0 + self.weight.float())
+        x = x.to(orig_dtype)
+        return x if residual is None else (x, residual)
+
+
+class Gemma3RMSNorm(CustomOp):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.zeros(dim))
 
+    # Re-dispatch
     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
-    def forward(self, x):
+    def forward_native(self, x):
         output = self._norm(x.float())
         # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
         # See https://github.com/huggingface/transformers/pull/29402
         output = output * (1.0 + self.weight.float())
         return output.type_as(x)
 
+    def forward_cuda(self, x):
+        return self.forward_native(x)
+
+    def forward_npu(self, x):
+        output, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.eps)
+        return output
+
     def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
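
A minimal CPU sketch of the Gemma-style RMSNorm that GemmaRMSNorm.forward_npu (explicit mean/rsqrt ops) and Gemma3RMSNorm.forward_npu (the fused npu_gemma_rms_norm kernel) are expected to compute. It uses plain torch, since torch_npu is Ascend-only, and is a numerical reference rather than the backend path itself.

import torch

def gemma_rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Gemma stores the scale as an offset from 1 and casts after scaling,
    # unlike Llama, which casts before multiplying by the weight.
    return (x * (1.0 + weight.float())).to(orig_dtype)

x = torch.randn(2, 5, dtype=torch.float16)
w = torch.zeros(5)  # a freshly initialized Gemma3RMSNorm weight
print(gemma_rms_norm_reference(x, w).dtype)  # torch.float16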
python/sglang/srt/layers/rotary_embedding.py

@@ -1876,7 +1876,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(
+def apply_rotary_pos_emb_native(
     q: torch.Tensor,
     k: torch.Tensor,
     cos: torch.Tensor,

@@ -1899,6 +1899,33 @@ def apply_rotary_pos_emb(
     return q_embed, k_embed
 
 
+def apply_rotary_pos_emb_npu(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    unsqueeze_dim=1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if q.shape[1] != 128:
+        return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    cos = torch.transpose(cos, 1, 2)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    sin = torch.transpose(sin, 1, 2)
+    q = torch.transpose(q, 1, 2)
+    k = torch.transpose(k, 1, 2)
+    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
+    q_embed = torch.transpose(q_embed, 1, 2)
+    k_embed = torch.transpose(k_embed, 1, 2)
+    return q_embed, k_embed
+
+
+if _is_npu:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_npu
+else:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_native
+
+
 def get_rope_cpu(
     head_size: int,
     rotary_dim: int,
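
The final hunk picks the implementation once at import time by rebinding a module-level name. A self-contained sketch of that dispatch pattern (the capability check and function names below are illustrative stand-ins, not sglang's helpers):

import torch

def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_native(q, k, cos, sin):
    # Standard rotary-embedding application.
    return q * cos + _rotate_half(q) * sin, k * cos + _rotate_half(k) * sin

def apply_rope_npu(q, k, cos, sin):
    # On Ascend this would call torch_npu.npu_apply_rotary_pos_emb; fall back
    # to the native math here so the sketch runs anywhere.
    return apply_rope_native(q, k, cos, sin)

# Decide once at import time; every caller uses the same name afterwards.
_HAS_NPU = hasattr(torch, "npu") and torch.npu.is_available()
apply_rope = apply_rope_npu if _HAS_NPU else apply_rope_native

q = k = torch.randn(1, 8, 16, 64)
cos = sin = torch.randn(1, 1, 16, 64)
q_rot, k_rot = apply_rope(q, k, cos, sin)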
python/sglang/srt/managers/mm_utils.py

@@ -20,9 +20,11 @@ from sglang.srt.managers.schedule_batch import (
 )
 from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import flatten_nested_list, print_warning_once
+from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
 from sglang.utils import logger
 
+_is_npu = is_npu()
+
 # NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger
 # to ensure consistent logging behavior across the codebase. This prevents issues with log
 # propagation that can cause some log messages (like 'server is fired up') to not appear

@@ -486,6 +488,8 @@ def get_embedding_and_mask(
     if embedding is None:
         return None, None
     # 2. Get mask
+    if _is_npu:
+        torch.npu.current_stream().synchronize()
     special_multimodal_mask = _get_multimodal_mask(input_ids, placeholder_tensor)
     # 3. Adjust embedding length if needed
     embedding = _adjust_embedding_length(embedding, special_multimodal_mask, logger)
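
The added guard only synchronizes the current device stream before the multimodal mask is derived from input_ids. A minimal sketch of the same pattern written against CUDA, so it runs without torch_npu; the intent on NPU is presumably identical, i.e. ensuring pending asynchronous work on input_ids has completed before it is consumed.

import torch

def build_placeholder_mask(input_ids: torch.Tensor, placeholder_id: int) -> torch.Tensor:
    if input_ids.is_cuda:
        # Wait for any in-flight kernels that write input_ids before reading it.
        torch.cuda.current_stream().synchronize()
    return input_ids == placeholder_id

print(build_placeholder_mask(torch.tensor([1, 2, 3, 2]), placeholder_id=2))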
python/sglang/srt/multimodal/processors/base_processor.py

@@ -13,7 +13,9 @@ from PIL import Image
 from transformers import BaseImageProcessorFast
 
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
-from sglang.srt.utils import load_audio, load_image, load_video, logger
+from sglang.srt.utils import is_npu, load_audio, load_image, load_video, logger
+
+_is_npu = is_npu()
 
 
 @dataclasses.dataclass

@@ -232,7 +234,7 @@ class BaseMultimodalProcessor(ABC):
             and isinstance(processor.image_processor, BaseImageProcessorFast)
             and not self.server_args.disable_fast_image_processor
         ):
-            kwargs["device"] = "cuda"
+            kwargs["device"] = "cuda" if not _is_npu else "npu"
         result = processor.__call__(
             text=[input_text],
             padding=True,
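
The only functional change here is which accelerator the fast image processor is placed on. A tiny stand-alone sketch of that selection (the helper name is illustrative):

def fast_image_processor_device(on_npu: bool) -> str:
    # Mirrors the diff: kwargs["device"] = "cuda" if not _is_npu else "npu"
    return "npu" if on_npu else "cuda"

assert fast_image_processor_device(False) == "cuda"
assert fast_image_processor_device(True) == "npu"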