sglang · Commits · e22f3a5e

Commit e22f3a5e (Unverified)

[Ascend] optimize Qwen3 on Ascend (#10574)

Authored Sep 23, 2025 by ronnie_zheng; committed by GitHub, Sep 22, 2025
Co-authored-by: c30031083 <chenxu140@huawei.com>
Parent: 095093ee
Showing 6 changed files with 81 additions and 2 deletions (+81 −2):

    python/sglang/srt/layers/communicator.py              +8   −0
    python/sglang/srt/layers/quantization/w8a8_int8.py    +2   −0
    python/sglang/srt/model_executor/model_runner.py      +7   −0
    python/sglang/srt/model_executor/npu_graph_runner.py  +2   −0
    python/sglang/srt/models/qwen3.py                     +18  −2
    python/sglang/srt/utils.py                            +44  −0
python/sglang/srt/layers/communicator.py

@@ -50,6 +50,7 @@ from sglang.srt.utils import (
     is_hip,
     is_sm90_supported,
     is_sm100_supported,
+    prepare_weight_cache,
 )

 _is_flashinfer_available = is_flashinfer_available()

@@ -275,7 +276,11 @@ class LayerCommunicator:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
+        cache=None,
     ):
+        if cache is not None:
+            self._context.cache = cache
+
         return self._communicate_with_all_reduce_and_layer_norm_fn(
             hidden_states=hidden_states,
             residual=residual,

@@ -349,6 +354,7 @@ class CommunicateContext:
     attn_tp_size: int
     attn_dp_size: int
     tp_size: int
+    cache = None

     def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
         return self.process_group_sizes[a] == self.process_group_sizes[b]

@@ -533,6 +539,8 @@ class CommunicateWithAllReduceAndLayerNormFn:
             )
         else:
             hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+            if context.cache is not None:
+                _ = prepare_weight_cache(hidden_states, context.cache)
             hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual
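Taken together, these hunks thread an optional weight cache from the layer into the communicator's all-reduce path: the layer passes cache= when it prepares the MLP, the context stores it, and prepare_weight_cache fires right after the tensor-parallel all-reduce so the prefetch hides under communication. A minimal, self-contained sketch of that pattern (the names here are illustrative stand-ins, not the sglang API):

    import torch

    class _Context:
        cache = None  # mirrors the new CommunicateContext.cache field

    def _prepare_weight_cache_stub(handle, cache):
        # Stand-in for sglang.srt.utils.prepare_weight_cache, which on
        # Ascend issues torch_npu.npu_prefetch on a side stream.
        for w in cache:
            _ = w.sum()

    class _Communicator:
        def __init__(self):
            self._context = _Context()

        def prepare_mlp(self, hidden_states, residual, cache=None):
            if cache is not None:
                self._context.cache = cache
            # ... tensor-parallel all-reduce would run here ...
            if self._context.cache is not None:
                _prepare_weight_cache_stub(hidden_states, self._context.cache)
            hidden_states = torch.nn.functional.layer_norm(
                hidden_states, hidden_states.shape[-1:]
            )
            return hidden_states, residual

    h, r = torch.randn(4, 16), torch.zeros(4, 16)
    weights = [torch.randn(16, 64), torch.randn(64, 16)]
    h, r = _Communicator().prepare_mlp(h, r, cache=weights)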
python/sglang/srt/layers/quantization/w8a8_int8.py

@@ -638,6 +638,7 @@ class NPU_W8A8LinearMethodImpl:
         layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)


 class NPU_W8A8LinearMethodMTImpl:

@@ -830,6 +831,7 @@ class NPU_W8A8DynamicLinearMethodImpl:
         layer.weight_scale.data = layer.weight_scale.data.flatten()
         layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
         layer.weight_offset.data = layer.weight_offset.data.flatten()
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)


 class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
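The added line in both weight preprocessors casts the int8 weight into an NPU-private layout once, at load time. For context (our note, not part of the commit): in CANN's format enum, id 29 corresponds to ACL_FORMAT_FRACTAL_NZ, the blocked layout Ascend matmul kernels consume directly, so the cast here avoids a per-forward re-layout. An illustrative sketch, assuming an Ascend device with torch_npu installed:

    import torch
    import torch_npu  # Ascend PyTorch adapter

    w = torch.randint(-128, 128, (4096, 4096), dtype=torch.int8).npu()
    w = torch_npu.npu_format_cast(w, 29)  # 29 = ACL_FORMAT_FRACTAL_NZ
    print(torch_npu.get_npu_format(w))    # expected: 29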
python/sglang/srt/model_executor/model_runner.py

@@ -179,6 +179,13 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 logger = logging.getLogger(__name__)

+if _is_npu:
+    import torch_npu
+
+    torch.npu.config.allow_internal_format = True
+    torch_npu.npu.set_compile_mode(jit_compile=False)
+

 class RankZeroFilter(logging.Filter):
     """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
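These two process-wide settings are what make the format cast above usable: allow_internal_format permits tensors to live in private NPU layouts such as FRACTAL_NZ, and jit_compile=False selects pre-built binary kernels instead of on-the-fly compilation. A hedged sketch of the same setup in isolation (Ascend only; the format query is for illustration):

    import torch
    import torch_npu

    torch.npu.config.allow_internal_format = True      # permit private layouts
    torch_npu.npu.set_compile_mode(jit_compile=False)  # use pre-built kernels

    x = torch.randn(128, 128).npu()
    print(torch_npu.get_npu_format(x))  # plain ND layout until a format cast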
python/sglang/srt/model_executor/npu_graph_runner.py

@@ -19,8 +19,10 @@ import logging
 import threading
 from typing import TYPE_CHECKING, Optional, Union

+import numpy as np
 import torch

+from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

 logger = logging.getLogger(__name__)
python/sglang/srt/models/qwen3.py

@@ -30,12 +30,19 @@ from sglang.srt.model_loader.weight_utils import (
 )
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.utils import add_prefix, is_cuda
+from sglang.srt.utils import (
+    add_prefix,
+    get_cmo_stream,
+    is_cuda,
+    is_npu,
+    wait_cmo_stream,
+)

 Qwen3Config = None

 logger = logging.getLogger(__name__)

 _is_cuda = is_cuda()
+_is_npu = is_npu()


 class Qwen3Attention(nn.Module):

@@ -235,9 +242,18 @@ class Qwen3DecoderLayer(nn.Module):
         # Fully Connected
         hidden_states, residual = self.layer_communicator.prepare_mlp(
-            hidden_states, residual, forward_batch
+            hidden_states,
+            residual,
+            forward_batch,
+            cache=(
+                [self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight]
+                if _is_npu
+                else None
+            ),
         )
         hidden_states = self.mlp(hidden_states)
+        if _is_npu and get_cmo_stream():
+            wait_cmo_stream()
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
         )
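This decoder-layer change is where the overlap actually happens: on NPU, prepare_mlp hands the two MLP weights to the communicator, the CMO stream prefetches them while the attention all-reduce and then the MLP run on the main stream, and wait_cmo_stream() re-synchronizes afterwards. A hedged sketch of the same call pattern pulled out of the layer (the wrapper name is ours; it assumes an sglang build with these hooks):

    from sglang.srt.utils import get_cmo_stream, is_npu, wait_cmo_stream

    _is_npu = is_npu()

    def run_mlp_with_prefetch(communicator, mlp, hidden_states, residual, forward_batch):
        # Hand the MLP weights to the communicator so they are prefetched
        # on the CMO stream while the all-reduce runs (NPU only).
        hidden_states, residual = communicator.prepare_mlp(
            hidden_states,
            residual,
            forward_batch,
            cache=([mlp.gate_up_proj.weight, mlp.down_proj.weight] if _is_npu else None),
        )
        hidden_states = mlp(hidden_states)
        if _is_npu and get_cmo_stream():
            wait_cmo_stream()  # order the prefetch before the stream is reused
        return hidden_states, residual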
python/sglang/srt/utils.py

@@ -517,6 +517,50 @@ def make_layers(
     return modules, start_layer, end_layer


+cmo_stream = None
+
+
+def get_cmo_stream():
+    """
+    Cache Management Operation (CMO).
+    Launch a new stream to prefetch the weights of a matmul while other
+    AIV or communication kernels are running, aiming to overlap the
+    memory access time.
+    """
+    global cmo_stream
+    if cmo_stream is None:
+        cmo_stream = torch.get_device_module().Stream()
+    return cmo_stream
+
+
+def prepare_weight_cache(handle, cache):
+    import torch_npu
+
+    NPU_PREFETCH_MAX_SIZE_BYTES = (
+        1000000000  # 1GB, a large value to prefetch the entire weight
+    )
+    stream = get_cmo_stream()
+    stream.wait_stream(torch.npu.current_stream())
+    with torch.npu.stream(stream):
+        if isinstance(cache, list):
+            for weight in cache:
+                torch_npu.npu_prefetch(
+                    weight,
+                    handle,
+                    NPU_PREFETCH_MAX_SIZE_BYTES,
+                )
+        else:
+            torch_npu.npu_prefetch(
+                cache,
+                handle,
+                NPU_PREFETCH_MAX_SIZE_BYTES,
+            )
+
+
+def wait_cmo_stream():
+    cur_stream = torch.get_device_module().current_stream()
+    cur_stream.wait_stream(get_cmo_stream())
+
+
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)
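The three helpers implement a classic producer / side-stream / consumer ordering. The same pattern can be sketched portably with CUDA streams, where a read of each weight stands in for torch_npu.npu_prefetch (which has no CUDA equivalent); everything below is illustrative, not sglang API:

    import torch

    side_stream = torch.cuda.Stream() if torch.cuda.is_available() else None

    def prefetch_on_side_stream(weights):
        # Mirror prepare_weight_cache: the side stream starts only after the
        # main stream's pending work, then touches each weight.
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            for w in weights:
                _ = w.sum()  # stand-in for torch_npu.npu_prefetch

    def wait_side_stream():
        # Mirror wait_cmo_stream: the main stream waits for the side stream.
        torch.cuda.current_stream().wait_stream(side_stream)

    if torch.cuda.is_available():
        x = torch.randn(1024, 1024, device="cuda")
        w1 = torch.randn(1024, 1024, device="cuda")
        w2 = torch.randn(1024, 1024, device="cuda")
        prefetch_on_side_stream([w1, w2])  # overlaps with the matmul below
        y = x @ x                          # main-stream compute
        wait_side_stream()                 # prefetch ordered before first use
        out = y @ w1 @ w2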