sglang / Commits / e22f3a5e

Commit e22f3a5e (unverified)
Authored Sep 23, 2025 by ronnie_zheng; committed by GitHub, Sep 22, 2025
Parent: 095093ee

[Ascend] optimize Qwen3 on Ascend (#10574)

Co-authored-by: c30031083 <chenxu140@huawei.com>
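In brief, judging from the diff itself: this commit adds an Ascend NPU weight-prefetch path ("CMO", cache management operation). A dedicated side stream prefetches the next MLP's weights while the pre-MLP all-reduce and layernorm run, W8A8 quantized weights are cast into an NPU-internal storage format after loading, and the NPU runtime is configured to allow internal formats with JIT compilation disabled.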
Showing 6 changed files with 81 additions and 2 deletions:

    python/sglang/srt/layers/communicator.py              +8   -0
    python/sglang/srt/layers/quantization/w8a8_int8.py    +2   -0
    python/sglang/srt/model_executor/model_runner.py      +7   -0
    python/sglang/srt/model_executor/npu_graph_runner.py  +2   -0
    python/sglang/srt/models/qwen3.py                     +18  -2
    python/sglang/srt/utils.py                            +44  -0
python/sglang/srt/layers/communicator.py

@@ -50,6 +50,7 @@ from sglang.srt.utils import (
     is_hip,
     is_sm90_supported,
     is_sm100_supported,
+    prepare_weight_cache,
 )

 _is_flashinfer_available = is_flashinfer_available()

@@ -275,7 +276,11 @@ class LayerCommunicator:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
+        cache=None,
     ):
+        if cache is not None:
+            self._context.cache = cache
+
         return self._communicate_with_all_reduce_and_layer_norm_fn(
             hidden_states=hidden_states,
             residual=residual,

@@ -349,6 +354,7 @@ class CommunicateContext:
     attn_tp_size: int
     attn_dp_size: int
     tp_size: int
+    cache = None

     def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
         return self.process_group_sizes[a] == self.process_group_sizes[b]

@@ -533,6 +539,8 @@ class CommunicateWithAllReduceAndLayerNormFn:
             )
         else:
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+        if context.cache is not None:
+            _ = prepare_weight_cache(hidden_states, context.cache)
         hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual
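For orientation, here is a minimal sketch of the hand-off pattern these hunks implement. All names below are hypothetical stand-ins, not the real sglang classes: the caller passes the next block's weights as cache, the communicator stashes them on its context, and the all-reduce path fires the prefetch so it overlaps the subsequent layernorm and matmuls.

# Sketch of the cache hand-off pattern (hypothetical names).
class _Context:
    cache = None

class LayerCommunicatorSketch:
    def __init__(self, prefetch_fn):
        self._context = _Context()
        self._prefetch_fn = prefetch_fn  # e.g. prepare_weight_cache

    def prepare_mlp(self, hidden_states, residual, cache=None):
        if cache is not None:
            # Stash the weights; the communication step below consumes them.
            self._context.cache = cache
        return self._all_reduce_and_norm(hidden_states, residual)

    def _all_reduce_and_norm(self, hidden_states, residual):
        # ... tensor-parallel all-reduce would run here ...
        if self._context.cache is not None:
            # Start prefetching the stashed weights on a side stream;
            # the main stream proceeds with layernorm in parallel.
            self._prefetch_fn(hidden_states, self._context.cache)
        # ... layernorm would run here ...
        return hidden_states, residual

comm = LayerCommunicatorSketch(prefetch_fn=lambda handle, cache: None)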
python/sglang/srt/layers/quantization/w8a8_int8.py

@@ -638,6 +638,7 @@ class NPU_W8A8LinearMethodImpl:
         layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)


 class NPU_W8A8LinearMethodMTImpl:

@@ -830,6 +831,7 @@ class NPU_W8A8DynamicLinearMethodImpl:
         layer.weight_scale.data = layer.weight_scale.data.flatten()
         layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
         layer.weight_offset.data = layer.weight_offset.data.flatten()
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)


 class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
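Both hunks cast the quantized weight into NPU-internal format 29 once at load time. To the best of my knowledge, format ID 29 in torch_npu corresponds to the Ascend FRACTAL_NZ layout, which the matmul kernels can consume without an on-the-fly relayout. A named constant would make the intent explicit (a hypothetical readability tweak, not part of the commit):

# Hypothetical refactor, not in the commit: name the magic format ID.
ACL_FORMAT_FRACTAL_NZ = 29  # Ascend NZ fractal layout (assumed mapping)
# layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, ACL_FORMAT_FRACTAL_NZ)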
python/sglang/srt/model_executor/model_runner.py

@@ -179,6 +179,13 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 logger = logging.getLogger(__name__)

+if _is_npu:
+    import torch_npu
+
+    torch.npu.config.allow_internal_format = True
+    torch_npu.npu.set_compile_mode(jit_compile=False)
+

 class RankZeroFilter(logging.Filter):
     """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
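These two runtime switches pair with the format cast above: allow_internal_format permits tensors to carry NPU-private layouts such as the one produced by npu_format_cast, and set_compile_mode(jit_compile=False) keeps operators on torch_npu's pre-built binary kernel path instead of JIT-compiling them at first call, which avoids compile stalls during serving.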
python/sglang/srt/model_executor/npu_graph_runner.py

@@ -19,8 +19,10 @@ import logging
 import threading
 from typing import TYPE_CHECKING, Optional, Union

+import numpy as np
 import torch

+from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

 logger = logging.getLogger(__name__)
python/sglang/srt/models/qwen3.py

@@ -30,12 +30,19 @@ from sglang.srt.model_loader.weight_utils import (
 )
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.utils import add_prefix, is_cuda
+from sglang.srt.utils import (
+    add_prefix,
+    get_cmo_stream,
+    is_cuda,
+    is_npu,
+    wait_cmo_stream,
+)

 Qwen3Config = None

 logger = logging.getLogger(__name__)

 _is_cuda = is_cuda()
+_is_npu = is_npu()


 class Qwen3Attention(nn.Module):

@@ -235,9 +242,18 @@ class Qwen3DecoderLayer(nn.Module):
         # Fully Connected
         hidden_states, residual = self.layer_communicator.prepare_mlp(
-            hidden_states, residual, forward_batch
+            hidden_states,
+            residual,
+            forward_batch,
+            cache=(
+                [self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight]
+                if _is_npu
+                else None
+            ),
         )
         hidden_states = self.mlp(hidden_states)
+        if _is_npu and get_cmo_stream():
+            wait_cmo_stream()

         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
         )
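Put together, the decoder layer opens the overlap window inside prepare_mlp and closes it immediately after the MLP. Roughly (an illustrative timeline, not a trace):

main stream: all-reduce -> layernorm -> MLP matmuls -> wait_cmo_stream() -> ...
CMO stream :        npu_prefetch(gate_up_proj.weight, down_proj.weight)

The wait right after self.mlp(...) guarantees that no later kernel can race a prefetch still in flight on the side stream.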
python/sglang/srt/utils.py

@@ -517,6 +517,50 @@ def make_layers(
     return modules, start_layer, end_layer


+cmo_stream = None
+
+
+def get_cmo_stream():
+    """
+    Cache Management Operation (CMO).
+    Launch a new stream to prefetch the weight of matmul when running other
+    AIV or communication kernels, aiming to overlap the memory access time.
+    """
+    global cmo_stream
+    if cmo_stream is None:
+        cmo_stream = torch.get_device_module().Stream()
+    return cmo_stream
+
+
+def prepare_weight_cache(handle, cache):
+    import torch_npu
+
+    NPU_PREFETCH_MAX_SIZE_BYTES = (
+        1000000000  # 1GB, a large value to prefetch entire weight
+    )
+    stream = get_cmo_stream()
+    stream.wait_stream(torch.npu.current_stream())
+    with torch.npu.stream(stream):
+        if isinstance(cache, list):
+            for weight in cache:
+                torch_npu.npu_prefetch(
+                    weight,
+                    handle,
+                    NPU_PREFETCH_MAX_SIZE_BYTES,
+                )
+        else:
+            torch_npu.npu_prefetch(
+                cache,
+                handle,
+                NPU_PREFETCH_MAX_SIZE_BYTES,
+            )
+
+
+def wait_cmo_stream():
+    cur_stream = torch.get_device_module().current_stream()
+    cur_stream.wait_stream(get_cmo_stream())
+
+
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)
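For readers without Ascend hardware, the same producer/consumer stream discipline can be sketched with CUDA streams (the assumption being that torch_npu mirrors torch.cuda's stream semantics, which the code above suggests): wait_stream before launching the side work so it sees up-to-date inputs, and wait_stream again before the main stream consumes its results.

import torch

def overlap_demo():
    # CUDA-stream analogue of the CMO pattern (illustrative, not sglang code).
    if not torch.cuda.is_available():
        return None
    side = torch.cuda.Stream()
    x = torch.randn(1024, 1024, device="cuda")
    w = torch.randn(1024, 1024, device="cuda")
    # 1. Side stream must not run ahead of work already queued on main.
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        prefetched = w.clone()  # stand-in for torch_npu.npu_prefetch
    y = x @ w  # main-stream compute overlaps the side stream
    # 2. Re-join: main stream must not consume side results too early.
    torch.cuda.current_stream().wait_stream(side)
    return y, prefetched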