sglang · Commit cc0485be (unverified)

Support w8a8 int8 quantization config (#2881)

Authored Jan 14, 2025 by Ke Bao; committed by GitHub on Jan 14, 2025
Parent: b8cd09f2

Showing 4 changed files with 135 additions and 6 deletions (+135 -6)
python/sglang/srt/configs/model_config.py             +15   -6
python/sglang/srt/layers/quantization/__init__.py      +2   -0
python/sglang/srt/layers/quantization/w8a8_int8.py   +117   -0  (new file)
python/sglang/srt/server_args.py                       +1   -0
python/sglang/srt/configs/model_config.py (+15 -6)

@@ -223,7 +223,11 @@ class ModelConfig:
             "compressed_tensors",
             "compressed-tensors",
             "experts_int8",
+            "w8a8_int8",
         ]
+        compatible_quantization_methods = {
+            "w8a8_int8": ["compressed-tensors", "compressed_tensors"]
+        }
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
@@ -247,12 +251,17 @@ class ModelConfig:
             if self.quantization is None:
                 self.quantization = quant_method
             elif self.quantization != quant_method:
-                raise ValueError(
-                    "Quantization method specified in the model config "
-                    f"({quant_method}) does not match the quantization "
-                    f"method specified in the `quantization` argument "
-                    f"({self.quantization})."
-                )
+                if (
+                    self.quantization not in compatible_quantization_methods
+                    or quant_method
+                    not in compatible_quantization_methods[self.quantization]
+                ):
+                    raise ValueError(
+                        "Quantization method specified in the model config "
+                        f"({quant_method}) does not match the quantization "
+                        f"method specified in the `quantization` argument "
+                        f"({self.quantization})."
+                    )
         if self.quantization is not None:
             if self.quantization not in supported_quantization:
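The new compatible_quantization_methods map relaxes the strict equality check: a user may now request w8a8_int8 for a checkpoint whose config declares compressed-tensors (or compressed_tensors) without tripping the mismatch error. A minimal standalone sketch of the check, with illustrative names rather than sglang's exact API:

    # Standalone sketch of the compatibility check above (illustrative names).
    compatible_quantization_methods = {
        "w8a8_int8": ["compressed-tensors", "compressed_tensors"]
    }

    def check_quantization(user_method: str, config_method: str) -> None:
        if user_method != config_method and (
            user_method not in compatible_quantization_methods
            or config_method not in compatible_quantization_methods[user_method]
        ):
            raise ValueError(
                f"Quantization method in the model config ({config_method}) "
                f"does not match the `quantization` argument ({user_method})."
            )

    check_quantization("w8a8_int8", "compressed-tensors")  # passes: listed as compatible
    check_quantization("awq", "awq")                       # passes: exact match
    # check_quantization("w8a8_int8", "awq")               # raises ValueError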
python/sglang/srt/layers/quantization/__init__.py (+2 -0)

@@ -23,6 +23,7 @@ from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -42,6 +43,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
+    "w8a8_int8": W8A8Int8Config,
 }
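With the registry entry in place, the config class can be resolved by name. A sketch of the typical lookup (the consumer code here is illustrative; only the registry and the class methods come from this commit):

    quant_cls = QUANTIZATION_METHODS["w8a8_int8"]   # -> W8A8Int8Config
    quant_config = quant_cls.from_config({})        # w8a8_int8 carries no options
    assert quant_cls.get_name() == "w8a8_int8"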
python/sglang/srt/layers/quantization/w8a8_int8.py (new file, +117 -0)

from typing import Any, Dict, List, Optional

import torch

from sglang.srt.utils import is_cuda_available

is_cuda = is_cuda_available()
if is_cuda:
    from sgl_kernel import int8_scaled_mm

from torch.nn.parameter import Parameter

from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8


class W8A8Int8Config(QuantizationConfig):
    """Config class for W8A8 Int8 Quantization.

    - Weight: static, per-channel, symmetric
    - Activation: dynamic, per-token, symmetric
    """

    def __init__(self):
        pass

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_name(self) -> str:
        return "w8a8_int8"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
        return cls()

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        from vllm.model_executor.layers.linear import LinearBase

        if isinstance(layer, LinearBase):
            return W8A8Int8LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
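For reference, the scheme named in the docstring can be written out in plain PyTorch. This is a sketch of the math only, not sglang's kernels, and the helper names are illustrative:

    import torch

    def quant_weight_per_channel(w: torch.Tensor):
        # Static, per-channel, symmetric: one fp32 scale per output channel (row).
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        w_q = torch.round(w / scale).clamp(-128, 127).to(torch.int8)
        return w_q, scale.to(torch.float32)  # (out, in) int8, (out, 1) fp32

    def quant_act_per_token(x: torch.Tensor):
        # Dynamic, per-token, symmetric: one fp32 scale per row, computed at runtime.
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
        x_q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
        return x_q, scale.to(torch.float32)  # (tokens, in) int8, (tokens, 1) fp32

Weights are quantized once offline (static), while activations get fresh scales per forward pass (dynamic), which is why per_token_quant_int8 runs inside apply below.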
class W8A8Int8LinearMethod(LinearMethodBase):

    def __init__(self, quantization_config: W8A8Int8Config):
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs
    ):
        weight_loader = extra_weight_attrs.get("weight_loader")
        self.logical_widths = output_partition_sizes

        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ):
        x_q, x_scale = per_token_quant_int8(x)

        return int8_scaled_mm(
            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
        )
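apply quantizes activations on the fly with per_token_quant_int8 and hands everything to the fused sgl_kernel GEMM. A pure-PyTorch sketch of what int8_scaled_mm computes, under the shapes set up above (weight is (in, out) after the .t() in process_weights_after_loading; weight_scale is (out, 1)):

    import torch

    def int8_scaled_mm_ref(x_q, weight_q, x_scale, weight_scale, out_dtype, bias=None):
        # Reference math only: (tokens, in) int8 @ (in, out) int8, then rescale:
        # out[t, o] = sum_k x_q[t, k] * w_q[k, o] * x_scale[t] * w_scale[o]  (+ bias)
        out = (x_q.float() @ weight_q.float()) * x_scale * weight_scale.t()
        if bias is not None:
            out = out + bias
        return out.to(out_dtype)

The actual kernel presumably accumulates in int32 and fuses the rescale and bias into the GEMM epilogue; this reference only shows the arithmetic being performed.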
python/sglang/srt/server_args.py (+1 -0)

@@ -378,6 +378,7 @@ class ServerArgs:
             "bitsandbytes",
             "gguf",
             "modelopt",
+            "w8a8_int8",
         ],
         help="The quantization method.",
     )
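With the choice registered, the method can be selected at server launch; a typical invocation might look like the following (model path illustrative, and the flag spelling is assumed from the ServerArgs field above):

    python -m sglang.launch_server --model-path <model> --quantization w8a8_int8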