Support Online Quantization for W8A8 (#4485)

Commit ef3c2dd0 (unverified), authored Mar 17, 2025 by Stefan He, committed by GitHub. Parent: 75b65648.
Showing 2 changed files with 122 additions and 40 deletions (+122, −40):
- python/sglang/srt/layers/quantization/w8a8_fp8.py (+65, −19)
- test/srt/test_eval_fp8_accuracy.py (+57, −21)
python/sglang/srt/layers/quantization/w8a8_fp8.py

```diff
@@ -9,9 +9,11 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
+    input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.utils import is_hip
```
```diff
@@ -22,12 +24,24 @@ _is_hip = is_hip()

 class W8A8Fp8Config(QuantizationConfig):
     """Config class for W8A8 FP8 Quantization.

-    - Weight: static, per-channel, symmetric
-    - Activation: dynamic, per-token, symmetric
+    Weight Quantization:
+    - Method: Static quantization
+    - Granularity: Per-channel
+    - Type: Symmetric
+
+    Activation Quantization:
+    - Method: Dynamic quantization
+    - Granularity: Per-token
+    - Type: Symmetric
+
+    Note:
+    - For models without offline quantization, weights are quantized during model loading.
+    - If CUTLASS is supported: per-channel weight quantization is used.
+    - If CUTLASS is not supported: falls back to per-tensor weight quantization.
     """

-    def __init__(self):
-        pass
+    def __init__(self, is_checkpoint_fp8_serialized: bool = False):
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
```
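For intuition: "per-channel, symmetric" means one scale per weight output channel, and "per-token, symmetric" means one scale per activation row, each derived from the slice's max magnitude. A minimal reference sketch of that scheme (not the kernels this file actually calls, which live in `fp8_kernel`/`fp8_utils`):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def quant_symmetric(x: torch.Tensor, dim: int):
    """One symmetric FP8 scale per slice along `dim`."""
    amax = x.abs().amax(dim=dim, keepdim=True).clamp(min=1e-12)
    scale = amax / FP8_MAX                        # dequantization scale
    q = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale                               # dequantize: q.float() * scale

w = torch.randn(128, 256)                # weight: [out_ch, in_ch]
qw, w_scale = quant_symmetric(w, dim=1)  # per-channel: w_scale is [128, 1]

x = torch.randn(4, 256)                  # activations: [tokens, hidden]
qx, x_scale = quant_symmetric(x, dim=1)  # per-token: x_scale is [4, 1]
```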
```diff
@@ -47,7 +61,9 @@ class W8A8Fp8Config(QuantizationConfig):
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
-        return cls()
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_fp8_serialized = "compressed-tensors" in quant_method
+        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized)

     def get_quant_method(
         self,
```
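In other words, a checkpoint whose `quant_method` mentions `compressed-tensors` is treated as offline-serialized FP8, while anything else takes the new online path. A hypothetical illustration of the dispatch, assuming `get_from_keys` simply returns the value stored under `quant_method`:

```python
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config

# Illustrative config fragments (hypothetical values, shown for the dispatch only).
offline = W8A8Fp8Config.from_config({"quant_method": "compressed-tensors"})
online = W8A8Fp8Config.from_config({"quant_method": "w8a8_fp8"})

assert offline.is_checkpoint_fp8_serialized is True   # load scales from checkpoint
assert online.is_checkpoint_fp8_serialized is False   # quantize at load time
```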
```diff
@@ -72,13 +88,35 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         weight = layer.weight
-        weight_scale = layer.weight_scale.detach()
-        if _is_hip:
-            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                weight=weight, weight_scale=weight_scale
-            )
-        layer.weight = Parameter(weight.t(), requires_grad=False)
-        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+        if self.quantization_config.is_checkpoint_fp8_serialized:
+            # If the checkpoint was offline-quantized with w8a8_fp8, load the
+            # weight and weight_scale directly.
+            weight_scale = layer.weight_scale.detach()
+            if _is_hip:
+                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=weight, weight_scale=weight_scale
+                )
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+        else:
+            # If the checkpoint is not offline-quantized, quantize the weights
+            # online during model loading.
+            if self.cutlass_fp8_supported:
+                # If CUTLASS is supported, we use cutlass_scaled_mm,
+                # which requires per-channel quantization on the weight.
+                qweight, weight_scale = per_token_group_quant_fp8(
+                    layer.weight, layer.weight.shape[-1]
+                )
+                weight_scale = weight_scale.t().contiguous()
+            else:
+                # If CUTLASS is not supported, we fall back to torch._scaled_mm,
+                # which requires per-tensor quantization on the weight.
+                qweight, weight_scale = input_to_float8(layer.weight)
+
+            # Update the layer with the new values.
+            layer.weight = Parameter(qweight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            layer.input_scale = None

     def create_weights(
         self,
```
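The two online branches differ only in scale granularity: grouping with `group_size == weight.shape[-1]` yields exactly one group, hence one scale, per output channel, while `input_to_float8` produces a single tensor-wide scale as `torch._scaled_mm` requires. A toy sketch of the assumed group-quantization semantics and the resulting scale layout:

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def toy_group_quant(x: torch.Tensor, group_size: int):
    """Assumed semantics of per_token_group_quant_fp8: one symmetric
    scale per contiguous group of `group_size` elements along the last dim."""
    g = x.reshape(x.shape[0], -1, group_size)
    scale = g.abs().amax(dim=-1, keepdim=True) / FP8_MAX
    q = (g / scale).to(torch.float8_e4m3fn).reshape_as(x)
    return q, scale.squeeze(-1)

w = torch.randn(8, 16)                  # [out_ch, in_ch]
_, s = toy_group_quant(w, w.shape[-1])  # group spans the full row
print(s.shape)                          # [8, 1]: one scale per output channel
print(s.t().contiguous().shape)         # [1, 8]: layout used for cutlass_scaled_mm
```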
```diff
@@ -90,6 +128,11 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs
     ):
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quantization_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
         weight_loader = extra_weight_attrs.get("weight_loader")
         self.logical_widths = output_partition_sizes
@@ -98,7 +141,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
             data=torch.empty(
                 sum(output_partition_sizes),
                 input_size_per_partition,
-                dtype=torch.float8_e4m3fn,
+                dtype=weight_dtype,
             ),
             input_dim=1,
             output_dim=0,
@@ -106,12 +149,15 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         )
         layer.register_parameter("weight", weight)

-        weight_scale = ChannelQuantScaleParameter(
-            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
-            output_dim=0,
-            weight_loader=weight_loader,
-        )
-        layer.register_parameter("weight_scale", weight_scale)
+        if self.quantization_config.is_checkpoint_fp8_serialized:
+            weight_scale = ChannelQuantScaleParameter(
+                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("weight_scale", weight_scale)
+        else:
+            layer.weight_scale = None

     def apply(
         self,
```
test/srt/test_eval_fp8_accuracy.py

```diff
@@ -6,6 +6,7 @@ from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
     DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST,
     DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST,
+    DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
```
```diff
@@ -40,33 +41,68 @@ class TestEvalFP8Accuracy(unittest.TestCase):
 class TestEvalFP8DynamicQuantAccuracy(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        cls.model = DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--quantization", "w8a8_fp8"],
-        )
-
-    @classmethod
-    def tearDownClass(cls):
-        kill_process_tree(cls.process.pid)
-
-    def test_mmlu(self):
-        args = SimpleNamespace(
-            base_url=self.base_url,
-            model=self.model,
-            eval_name="mmlu",
-            num_examples=64,
-            num_threads=32,
-            temperature=0.1,
-        )
-
-        metrics = run_eval(args)
-        self.assertGreaterEqual(metrics["score"], 0.70)
+    def _run_test(self, model, other_args, expected_score):
+        base_url = DEFAULT_URL_FOR_TEST
+        other_args = other_args or []
+        process = popen_launch_server(
+            model,
+            base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=other_args,
+        )
+
+        try:
+            args = SimpleNamespace(
+                base_url=base_url,
+                model=model,
+                eval_name="mmlu",
+                num_examples=64,
+                num_threads=32,
+                temperature=0.1,
+            )
+
+            metrics = run_eval(args)
+            self.assertGreaterEqual(metrics["score"], expected_score)
+        finally:
+            kill_process_tree(process.pid)
+
+    def test_mmlu_offline_only(self):
+        """Test with offline quantization only."""
+        self._run_test(
+            model=DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST,
+            other_args=[],
+            expected_score=0.64,
+        )
+
+    def test_mmlu_offline_and_online_override(self):
+        """Test with both offline and online quantization."""
+        self._run_test(
+            model=DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST,
+            other_args=["--quantization", "w8a8_fp8"],
+            # Inference will use the sgl kernel with the online quant override;
+            # we observed that the accuracy is higher than offline only.
+            expected_score=0.64,
+        )
+
+    def test_mmlu_online_only(self):
+        """Test with online quantization only."""
+        self._run_test(
+            model=DEFAULT_MODEL_NAME_FOR_TEST,
+            # Inference will use the sgl kernel with online quantization only;
+            # we observed that the accuracy is higher than offline only.
+            other_args=["--quantization", "w8a8_fp8"],
+            expected_score=0.64,
+        )
+
+    def test_mmlu_fp16_baseline(self):
+        """Test with the unquantized fp16 baseline."""
+        self._run_test(
+            model=DEFAULT_MODEL_NAME_FOR_TEST,
+            other_args=[],
+            expected_score=0.64,
+        )


 if __name__ == "__main__":
```
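Taken together, the tests cover the four combinations of offline/online quantization against an MMLU score threshold. To use the new online path outside the tests, an unquantized checkpoint is served with the same `--quantization w8a8_fp8` server argument the tests pass through `other_args`; a minimal sketch mirroring how the tests launch the server:

```python
# Launch a server that online-quantizes an unquantized checkpoint to W8A8 FP8,
# using the same test utilities imported in test_eval_fp8_accuracy.py.
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

process = popen_launch_server(
    DEFAULT_MODEL_NAME_FOR_TEST,                # an unquantized checkpoint
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=["--quantization", "w8a8_fp8"],  # weights quantized at load time
)
# ... send requests to DEFAULT_URL_FOR_TEST, then terminate `process`.
```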