Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
83123f48
Unverified
Commit
83123f48
authored
Aug 12, 2025
by
ichernob
Committed by
GitHub
Aug 12, 2025
Browse files
[Quantization] Supported w8a8 int8 quantized Gemma3 and Qwen-VL models (#8619)
Co-authored-by:
ronnie_zheng
<
zl19940307@163.com
>
parent
48afa8f1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
131 additions
and
9 deletions
+131
-9
python/sglang/srt/layers/quantization/w8a8_int8.py
python/sglang/srt/layers/quantization/w8a8_int8.py
+9
-3
python/sglang/srt/model_loader/loader.py
python/sglang/srt/model_loader/loader.py
+18
-6
test/srt/test_ascend_w8a8_quantization.py
test/srt/test_ascend_w8a8_quantization.py
+104
-0
No files found.
python/sglang/srt/layers/quantization/w8a8_int8.py
View file @
83123f48
...
...
@@ -255,17 +255,23 @@ class W8A8Int8Config(QuantizationConfig):
if
_is_npu
:
if
isinstance
(
layer
,
LinearBase
):
key
=
"model"
if
"vision_model"
in
prefix
:
key
=
"vision_model"
elif
"visual"
in
prefix
:
key
=
"visual"
packed_modules_mapping_subset
=
self
.
packed_modules_mapping
.
get
(
key
,
{})
prefix_in_quant_config
=
prefix
proj_name
=
prefix
.
split
(
"."
)[
-
1
]
if
proj_name
in
self
.
packed_modules_mapping
:
if
proj_name
in
packed_modules_mapping
_subset
:
prefix_in_quant_config
=
prefix
.
replace
(
proj_name
,
self
.
packed_modules_mapping
[
proj_name
][
0
]
proj_name
,
packed_modules_mapping
_subset
[
proj_name
][
0
]
)
self
.
is_dynamic
=
(
self
.
quant_description
[
prefix_in_quant_config
+
".weight"
]
==
"W8A8_DYNAMIC"
)
if
self
.
is_layer_skipped
(
prefix
,
self
.
packed_modules_mapping
):
if
self
.
is_layer_skipped
(
prefix
,
packed_modules_mapping
_subset
):
return
UnquantizedLinearMethod
()
return
(
NPU_W8A8DynamicLinearMethod
(
self
)
...
...
python/sglang/srt/model_loader/loader.py
View file @
83123f48
...
...
@@ -162,12 +162,24 @@ def _initialize_model(
model_class
,
_
=
get_model_architecture
(
model_config
)
packed_modules_mapping
=
getattr
(
model_class
,
"packed_modules_mapping"
,
{})
if
_is_npu
:
packed_modules_mapping
[
"fused_qkv_a_proj_with_mqa"
]
=
[
"q_a_proj"
,
"kv_a_proj_with_mqa"
,
]
packed_modules_mapping
[
"qkv_proj"
]
=
[
"q_proj"
,
"k_proj"
,
"v_proj"
]
packed_modules_mapping
[
"gate_up_proj"
]
=
[
"gate_proj"
,
"up_proj"
]
packed_modules_mapping
.
update
(
{
"visual"
:
{
"qkv_proj"
:
[
"qkv"
]},
"vision_model"
:
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"proj"
:
[
"out_proj"
],
},
"model"
:
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
],
"fused_qkv_a_proj_with_mqa"
:
[
"q_a_proj"
,
"kv_a_proj_with_mqa"
,
],
},
}
)
quant_config
=
_get_quantization_config
(
model_config
,
load_config
,
packed_modules_mapping
)
...
...
test/srt/test_ascend_w8a8_quantization.py
0 → 100644
View file @
83123f48
"""
Usage:
python3 -m unittest test_ascend_w8a8_quantization.TestAscendW8A8.test_gsm8k
"""

import os
import time
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
)

# Pin the test to the first two Ascend NPUs unless the caller already chose.
if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ:
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"

# Derive a per-device-group port so runs on different NPU sets do not collide.
# NOTE(review): only the FIRST character of the device list is used
# ("0,1" -> 0, but "12,13" -> 1) — assumes single-digit device ids; confirm
# this is intentional.
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
    7000 + int(os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "0")[0]) * 100
)
# Deliberately shadows the DEFAULT_URL_FOR_TEST imported above so the server
# launched in this file listens on the derived port, not the shared default.
DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"
class TestAscendW8A8(CustomTestCase):
    """End-to-end checks for a w8a8-int8 quantized model served on Ascend NPU.

    A single server process is launched once for the whole class and shared by
    both tests; it is torn down in ``tearDownClass``.
    """

    @classmethod
    def setUpClass(cls):
        # Launch one shared server for every test in this class.
        cls.model = "vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8"
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_flags = [
            "--trust-remote-code",
            "--disable-cuda-graph",
            "--device",
            "npu",
            "--attention-backend",
            "ascend",
            "--quantization",
            "w8a8_int8",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    @classmethod
    def tearDownClass(cls):
        # Kill the server and any worker processes it spawned.
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot GSM8K eval must clear minimum accuracy/throughput bars."""
        parsed = urlparse(DEFAULT_URL_FOR_TEST)
        eval_args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host=f"http://{parsed.hostname}",
            port=int(parsed.port),
        )
        metrics = run_eval(eval_args)
        print(metrics)
        self.assertGreaterEqual(metrics["accuracy"], 0.25)
        self.assertGreaterEqual(metrics["output_throughput"], 1000)

    def run_decode(self, max_new_tokens):
        # Helper (not a test): one greedy /generate request to the shared server.
        payload = {
            "text": "The capital of France is",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": max_new_tokens,
            },
            "ignore_eos": True,
        }
        response = requests.post(self.base_url + "/generate", json=payload)
        return response.json()

    def test_throughput(self):
        """Decode a fixed token budget and check wall-clock throughput in CI."""
        token_budget = 256
        start = time.perf_counter()
        result = self.run_decode(token_budget)
        elapsed = time.perf_counter() - start
        print(result["text"])
        throughput = token_budget / elapsed
        print(f"Throughput: {throughput} tokens/s")
        if is_in_ci():
            self.assertGreaterEqual(throughput, 25)
# Allow running this file directly: `python3 test_ascend_w8a8_quantization.py`.
if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment