Commit ef3c2dd0 (unverified)
Authored Mar 17, 2025 by Stefan He; committed by GitHub on Mar 17, 2025
Support Online Quantization for W8A8 (#4485)
Parent: 75b65648

This change lets the w8a8_fp8 path quantize an unquantized checkpoint at load time ("online" quantization), in addition to loading checkpoints that were offline-quantized with compressed-tensors, and extends the FP8 accuracy tests to cover the offline, online, and combined configurations plus an fp16 baseline.

Showing 2 changed files, with 122 additions and 40 deletions:
- python/sglang/srt/layers/quantization/w8a8_fp8.py (+65, -19)
- test/srt/test_eval_fp8_accuracy.py (+57, -21)
python/sglang/srt/layers/quantization/w8a8_fp8.py
@@ -9,9 +9,11 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
+    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.utils import is_hip
@@ -22,12 +24,24 @@ _is_hip = is_hip()
 class W8A8Fp8Config(QuantizationConfig):
     """Config class for W8A8 FP8 Quantization.
 
-    - Weight: static, per-channel, symmetric
-    - Activation: dynamic, per-token, symmetric
+    Weight Quantization:
+    - Method: Static quantization
+    - Granularity: Per-channel
+    - Type: Symmetric
+
+    Activation Quantization:
+    - Method: Dynamic quantization
+    - Granularity: Per-token
+    - Type: Symmetric
+
+    Note:
+    - For models without offline quantization, weights will be quantized during model loading
+    - If CUTLASS is supported: Per-channel weight quantization is used
+    - If CUTLASS is not supported: Falls back to per-token weight quantization
     """
 
-    def __init__(self):
-        pass
+    def __init__(self, is_checkpoint_fp8_serialized: bool = False):
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
 
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
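For readers skimming the diff, the scheme the new docstring describes can be sketched in a few lines of plain PyTorch. This is an illustration only, not code from this commit; FP8_MAX and both helper names are invented here.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def quant_weight_per_channel(w: torch.Tensor):
    # Static, per-channel, symmetric: one scale per output channel (row),
    # computed once from the weight itself.
    scale = (w.abs().amax(dim=-1, keepdim=True) / FP8_MAX).clamp(min=1e-12)
    return (w / scale).to(torch.float8_e4m3fn), scale  # (N, K) fp8, (N, 1) scales

def quant_act_per_token(x: torch.Tensor):
    # Dynamic, per-token, symmetric: one scale per token (row), recomputed
    # for every incoming batch at runtime.
    scale = (x.abs().amax(dim=-1, keepdim=True) / FP8_MAX).clamp(min=1e-12)
    return (x / scale).to(torch.float8_e4m3fn), scale  # (T, K) fp8, (T, 1) scales

w, x = torch.randn(128, 64), torch.randn(4, 64)
qw, w_scale = quant_weight_per_channel(w)
qx, x_scale = quant_act_per_token(x)
# Dequantized matmul approximates the full-precision result x @ w.t():
y_approx = (qx.float() * x_scale) @ (qw.float() * w_scale).t()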
@@ -47,7 +61,9 @@ class W8A8Fp8Config(QuantizationConfig):
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
-        return cls()
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_fp8_serialized = "compressed-tensors" in quant_method
+        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized)
 
     def get_quant_method(
         self,
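Concretely, the two cases from_config now distinguishes look roughly like this; the dict contents below are assumed for illustration and are not taken from the commit.

# Offline-quantized checkpoint: quant_method mentions compressed-tensors,
# so the fp8 weights and scales in the checkpoint are loaded as-is.
offline = W8A8Fp8Config.from_config({"quant_method": "compressed-tensors"})
assert offline.is_checkpoint_fp8_serialized

# Any other quant_method (e.g. when --quantization w8a8_fp8 is applied to an
# unquantized model) takes the online path and quantizes at load time.
online = W8A8Fp8Config.from_config({"quant_method": "w8a8_fp8"})
assert not online.is_checkpoint_fp8_serialized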
@@ -72,13 +88,35 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         weight = layer.weight
-        weight_scale = layer.weight_scale.detach()
-        if _is_hip:
-            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                weight=weight, weight_scale=weight_scale
-            )
-        layer.weight = Parameter(weight.t(), requires_grad=False)
-        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+        if self.quantization_config.is_checkpoint_fp8_serialized:
+            # If the checkpoint was offline-quantized with w8a8_fp8, load the
+            # weight and weight_scale directly.
+            weight_scale = layer.weight_scale.detach()
+            if _is_hip:
+                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight, weight_scale=weight_scale
+                )
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+        else:
+            # If the checkpoint was not offline-quantized, quantize the weights
+            # during loading.
+            if self.cutlass_fp8_supported:
+                # If cutlass is supported, we use cutlass_scaled_mm,
+                # which requires per-channel quantization on the weight.
+                qweight, weight_scale = per_token_group_quant_fp8(
+                    layer.weight, layer.weight.shape[-1]
+                )
+                weight_scale = weight_scale.t().contiguous()
+            else:
+                # If cutlass is not supported, we fall back to torch._scaled_mm,
+                # which requires per-tensor quantization on the weight.
+                qweight, weight_scale = input_to_float8(layer.weight)
+
+            # Update the layer with the new values.
+            layer.weight = Parameter(qweight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            layer.input_scale = None
 
     def create_weights(
         self,
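The cutlass branch is per-channel because passing layer.weight.shape[-1] as the group size makes each output channel (row) a single quantization group. A rough plain-PyTorch contrast of the two granularities used above (the real kernels are per_token_group_quant_fp8 and input_to_float8; this sketch only mirrors the scale shapes):

import torch

w = torch.randn(8, 16)  # (out_features, in_features)
fp8_max = torch.finfo(torch.float8_e4m3fn).max

# cutlass_scaled_mm path: one scale per output channel -> shape (8, 1)
per_channel_scale = w.abs().amax(dim=-1, keepdim=True) / fp8_max

# torch._scaled_mm fallback: one scale for the whole tensor -> shape ()
per_tensor_scale = w.abs().amax() / fp8_max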
@@ -90,6 +128,11 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs
     ):
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quantization_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
         weight_loader = extra_weight_attrs.get("weight_loader")
         self.logical_widths = output_partition_sizes
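The effect of the new weight_dtype, restated (a reading of the change, not commit code): offline-serialized checkpoints already hold fp8 tensors, so the parameter is created as float8_e4m3fn and filled from the checkpoint directly; otherwise it is created in the model's high-precision dtype and quantized later by process_weights_after_loading above.

import torch

def storage_dtype(is_checkpoint_fp8_serialized: bool, params_dtype: torch.dtype) -> torch.dtype:
    # Mirrors the weight_dtype expression in the hunk above.
    return torch.float8_e4m3fn if is_checkpoint_fp8_serialized else params_dtype

assert storage_dtype(True, torch.bfloat16) is torch.float8_e4m3fn  # offline: load fp8 as-is
assert storage_dtype(False, torch.bfloat16) is torch.bfloat16      # online: quantize after load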
@@ -98,7 +141,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
             data=torch.empty(
                 sum(output_partition_sizes),
                 input_size_per_partition,
-                dtype=torch.float8_e4m3fn,
+                dtype=weight_dtype,
             ),
             input_dim=1,
             output_dim=0,
@@ -106,12 +149,15 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         )
         layer.register_parameter("weight", weight)
 
-        weight_scale = ChannelQuantScaleParameter(
-            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
-            output_dim=0,
-            weight_loader=weight_loader,
-        )
-        layer.register_parameter("weight_scale", weight_scale)
+        if self.quantization_config.is_checkpoint_fp8_serialized:
+            weight_scale = ChannelQuantScaleParameter(
+                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("weight_scale", weight_scale)
+        else:
+            layer.weight_scale = None
 
     def apply(
         self,
test/srt/test_eval_fp8_accuracy.py
@@ -6,6 +6,7 @@ from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
     DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST,
     DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST,
+    DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
@@ -40,33 +41,68 @@ class TestEvalFP8Accuracy(unittest.TestCase):
 class TestEvalFP8DynamicQuantAccuracy(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        cls.model = DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--quantization", "w8a8_fp8"],
-        )
-
-    @classmethod
-    def tearDownClass(cls):
-        kill_process_tree(cls.process.pid)
-
-    def test_mmlu(self):
-        args = SimpleNamespace(
-            base_url=self.base_url,
-            model=self.model,
-            eval_name="mmlu",
-            num_examples=64,
-            num_threads=32,
-            temperature=0.1,
-        )
-
-        metrics = run_eval(args)
-        self.assertGreaterEqual(metrics["score"], 0.70)
+    def _run_test(self, model, other_args, expected_score):
+        base_url = DEFAULT_URL_FOR_TEST
+        other_args = other_args or []
+        process = popen_launch_server(
+            model,
+            base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=other_args,
+        )
+
+        try:
+            args = SimpleNamespace(
+                base_url=base_url,
+                model=model,
+                eval_name="mmlu",
+                num_examples=64,
+                num_threads=32,
+                temperature=0.1,
+            )
+            metrics = run_eval(args)
+            self.assertGreaterEqual(metrics["score"], expected_score)
+        finally:
+            kill_process_tree(process.pid)
+
+    def test_mmlu_offline_only(self):
+        """Test with offline quantization only."""
+        self._run_test(
+            model=DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST,
+            other_args=[],
+            expected_score=0.64,
+        )
+
+    def test_mmlu_offline_and_online_override(self):
+        """Test with both offline and online quantization."""
+        self._run_test(
+            model=DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST,
+            other_args=["--quantization", "w8a8_fp8"],
+            # inference will use the sgl kernel w/ the online quant override;
+            # we observed that the accuracy is higher than offline only
+            expected_score=0.64,
+        )
+
+    def test_mmlu_online_only(self):
+        """Test with online quantization only."""
+        self._run_test(
+            model=DEFAULT_MODEL_NAME_FOR_TEST,
+            # inference will use the sgl kernel w/ online quantization only;
+            # we observed that the accuracy is higher than offline only
+            other_args=["--quantization", "w8a8_fp8"],
+            expected_score=0.64,
+        )
+
+    def test_mmlu_fp16_baseline(self):
+        """Test with the unquantized fp16 baseline."""
+        self._run_test(
+            model=DEFAULT_MODEL_NAME_FOR_TEST,
+            other_args=[],
+            expected_score=0.64,
+        )
 
 
 if __name__ == "__main__":
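For completeness, the online path the new tests exercise can be reproduced with the same helper; the model name below is a placeholder, and everything else is taken from the test file above.

from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

# Launch an unquantized model with online W8A8 FP8 quantization enabled,
# mirroring test_mmlu_online_only ("<unquantized-model>" is a placeholder).
process = popen_launch_server(
    "<unquantized-model>",
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=["--quantization", "w8a8_fp8"],
)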