Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
8cd42504
Unverified
Commit
8cd42504
authored
Mar 22, 2025
by
Yun Dai
Committed by
GitHub
Mar 22, 2025
Browse files
[quantization] fix channelwise conversion with scalar weight scale (#4596)
parent
6a384d5c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
58 additions
and
0 deletions
+58
-0
python/sglang/srt/layers/quantization/utils.py
python/sglang/srt/layers/quantization/utils.py
+5
-0
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+9
-0
test/srt/test_eval_fp8_accuracy.py
test/srt/test_eval_fp8_accuracy.py
+44
-0
No files found.
python/sglang/srt/layers/quantization/utils.py
View file @
8cd42504
...
@@ -74,6 +74,11 @@ def convert_to_channelwise(
...
@@ -74,6 +74,11 @@ def convert_to_channelwise(
(
sum
(
logical_widths
),
1
),
dtype
=
torch
.
float32
,
device
=
weight_scale
.
device
(
sum
(
logical_widths
),
1
),
dtype
=
torch
.
float32
,
device
=
weight_scale
.
device
)
)
# Handle scalar tensor case: broadcast same scale to all channels
if
weight_scale
.
dim
()
==
0
:
weight_scale_channel
.
fill_
(
weight_scale
.
item
())
return
weight_scale_channel
# Expand each scale to match the size of each logical matrix.
# Expand each scale to match the size of each logical matrix.
start
=
0
start
=
0
for
idx
,
logical_width
in
enumerate
(
logical_widths
):
for
idx
,
logical_width
in
enumerate
(
logical_widths
):
...
...
python/sglang/test/test_utils.py
View file @
8cd42504
...
@@ -33,6 +33,15 @@ DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST = "neuralmagic/Meta-Llama-3-8B-Instruct
...
@@ -33,6 +33,15 @@ DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST = "neuralmagic/Meta-Llama-3-8B-Instruct
DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST
=
(
DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST
=
(
"neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic"
"neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic"
)
)
# Model id used by the ModelOpt-quantized FP8 accuracy test.
DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST = "nvidia/Llama-3.1-8B-Instruct-FP8"
# TODO(yundai424): pinned to an older revision for now, because the latest one
# carries kv cache quantization, which doesn't work yet.
DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION = "13858565416dbdc0b4e7a4a677fadfbd5b9e5bb9"
DEFAULT_MODEL_NAME_FOR_TEST
=
"meta-llama/Llama-3.1-8B-Instruct"
DEFAULT_MODEL_NAME_FOR_TEST
=
"meta-llama/Llama-3.1-8B-Instruct"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
=
"meta-llama/Llama-3.2-1B-Instruct"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
=
"meta-llama/Llama-3.2-1B-Instruct"
DEFAULT_MOE_MODEL_NAME_FOR_TEST
=
"mistralai/Mixtral-8x7B-Instruct-v0.1"
DEFAULT_MOE_MODEL_NAME_FOR_TEST
=
"mistralai/Mixtral-8x7B-Instruct-v0.1"
...
...
test/srt/test_eval_fp8_accuracy.py
View file @
8cd42504
...
@@ -6,6 +6,8 @@ from sglang.test.run_eval import run_eval
...
@@ -6,6 +6,8 @@ from sglang.test.run_eval import run_eval
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST
,
DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST
,
DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST
,
DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST
,
DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST
,
DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION
,
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
DEFAULT_URL_FOR_TEST
,
...
@@ -105,5 +107,47 @@ class TestEvalFP8DynamicQuantAccuracy(unittest.TestCase):
...
@@ -105,5 +107,47 @@ class TestEvalFP8DynamicQuantAccuracy(unittest.TestCase):
)
)
class TestEvalFP8ModelOptQuantAccuracy(unittest.TestCase):
    """MMLU accuracy check for a model served with ``--quantization modelopt``."""

    def _run_test(self, model, other_args, expected_score):
        """Launch a server for *model*, run the MMLU eval, and check the score.

        Args:
            model: Model identifier passed to ``popen_launch_server``.
            other_args: Extra launch arguments; a falsy value is treated as
                an empty list.
            expected_score: Minimum acceptable value of ``metrics["score"]``.
        """
        url = DEFAULT_URL_FOR_TEST
        launch_args = other_args if other_args else []

        server = popen_launch_server(
            model,
            url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

        try:
            eval_args = SimpleNamespace(
                base_url=url,
                model=model,
                eval_name="mmlu",
                num_examples=64,
                num_threads=32,
                temperature=0.1,
            )
            results = run_eval(eval_args)
            self.assertGreaterEqual(results["score"], expected_score)
        finally:
            # Always tear the server down, even if the eval or assertion fails.
            kill_process_tree(server.pid)

    def test_mmlu_offline_only(self):
        """Test with offline quantization only.

        Uses the pinned checkpoint revision and requires an MMLU score of at
        least 0.64.
        """
        modelopt_args = [
            "--quantization",
            "modelopt",
            "--revision",
            DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION,
        ]
        self._run_test(
            model=DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST,
            other_args=modelopt_args,
            expected_score=0.64,
        )
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment