change / sglang · Commits · cc0485be
"vscode:/vscode.git/clone" did not exist on "86a14cbad46f6f026ffcee7f504ffaca8da33929"
Commit cc0485be (Unverified)
Authored Jan 14, 2025 by Ke Bao; committed by GitHub, Jan 14, 2025

Support w8a8 int8 quantization config (#2881)
parent b8cd09f2
Showing 4 changed files with 135 additions and 6 deletions.
python/sglang/srt/configs/model_config.py            +15 -6
python/sglang/srt/layers/quantization/__init__.py     +2 -0
python/sglang/srt/layers/quantization/w8a8_int8.py  +117 -0
python/sglang/srt/server_args.py                       +1 -0
python/sglang/srt/configs/model_config.py (view file @ cc0485be)
@@ -223,7 +223,11 @@ class ModelConfig:
             "compressed_tensors",
             "compressed-tensors",
             "experts_int8",
+            "w8a8_int8",
         ]
+        compatible_quantization_methods = {
+            "w8a8_int8": ["compressed-tensors", "compressed_tensors"]
+        }
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
@@ -247,12 +251,17 @@
             if self.quantization is None:
                 self.quantization = quant_method
             elif self.quantization != quant_method:
-                raise ValueError(
-                    "Quantization method specified in the model config "
-                    f"({quant_method}) does not match the quantization "
-                    f"method specified in the `quantization` argument "
-                    f"({self.quantization})."
-                )
+                if (
+                    self.quantization not in compatible_quantization_methods
+                    or quant_method
+                    not in compatible_quantization_methods[self.quantization]
+                ):
+                    raise ValueError(
+                        "Quantization method specified in the model config "
+                        f"({quant_method}) does not match the quantization "
+                        f"method specified in the `quantization` argument "
+                        f"({self.quantization})."
+                    )

         if self.quantization is not None:
             if self.quantization not in supported_quantization:
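Taken together, the two hunks above let a checkpoint whose HF config reports compressed-tensors be served with the w8a8_int8 method instead of failing the old equality check. A minimal standalone sketch of the new logic (not the real ModelConfig class; quant_method would normally be parsed from the checkpoint's quantization config):

compatible_quantization_methods = {
    "w8a8_int8": ["compressed-tensors", "compressed_tensors"]
}

def verify(requested: str, quant_method: str) -> None:
    # requested: the `quantization` server argument; quant_method: method found in the HF config
    if requested != quant_method and (
        requested not in compatible_quantization_methods
        or quant_method not in compatible_quantization_methods[requested]
    ):
        raise ValueError(
            f"Quantization method specified in the model config ({quant_method}) "
            f"does not match the `quantization` argument ({requested})."
        )

verify("w8a8_int8", "compressed-tensors")  # accepted: listed as compatible
try:
    verify("fp8", "compressed-tensors")    # still rejected, as before this change
except ValueError as exc:
    print(exc)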
python/sglang/srt/layers/quantization/__init__.py (view file @ cc0485be)
@@ -23,6 +23,7 @@ from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -42,6 +43,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
+    "w8a8_int8": W8A8Int8Config,
 }
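Registering the class here is what makes the string "w8a8_int8" resolvable to a config type. A rough, self-contained sketch of how such a registry is consumed (the lookup helper below is illustrative; the stand-in classes replace sglang's real ones):

from typing import Dict, Type

class QuantizationConfig:                  # stand-in for base_config.QuantizationConfig
    pass

class W8A8Int8Config(QuantizationConfig):  # stand-in for the new class added in w8a8_int8.py
    pass

QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "w8a8_int8": W8A8Int8Config,
}

def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    # Illustrative lookup: map the `quantization` string to its config class.
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {quantization}")
    return QUANTIZATION_METHODS[quantization]

print(get_quantization_config("w8a8_int8"))  # <class '__main__.W8A8Int8Config'>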
python/sglang/srt/layers/quantization/w8a8_int8.py (new file, mode 100644, view file @ cc0485be)
from typing import Any, Dict, List, Optional

import torch

from sglang.srt.utils import is_cuda_available

is_cuda = is_cuda_available()
if is_cuda:
    from sgl_kernel import int8_scaled_mm

from torch.nn.parameter import Parameter

from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8


class W8A8Int8Config(QuantizationConfig):
    """Config class for W8A8 Int8 Quantization.

    - Weight: static, per-channel, symmetric
    - Activation: dynamic, per-token, symmetric
    """

    def __init__(self):
        pass

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_name(self) -> str:
        return "w8a8_int8"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
        return cls()

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        from vllm.model_executor.layers.linear import LinearBase

        if isinstance(layer, LinearBase):
            return W8A8Int8LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class W8A8Int8LinearMethod(LinearMethodBase):

    def __init__(self, quantization_config: W8A8Int8Config):
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs
    ):
        weight_loader = extra_weight_attrs.get("weight_loader")
        self.logical_widths = output_partition_sizes

        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ):
        x_q, x_scale = per_token_quant_int8(x)

        return int8_scaled_mm(
            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
        )
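The numerics behind this file: weights are quantized offline, per output channel, to int8 with a float32 scale, activations are quantized per token at runtime by per_token_quant_int8, and int8_scaled_mm performs the int8 GEMM and rescales by both factors. The plain-PyTorch sketch below is a reference for that computation only (it is not the fused sgl_kernel, and the helper names are made up for illustration):

import torch

def quantize_weight_per_channel(w: torch.Tensor):
    # Static, per-output-channel, symmetric int8 quantization of a [out, in] weight.
    w_scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0    # [out, 1]
    w_q = torch.clamp(torch.round(w / w_scale), -128, 127).to(torch.int8)  # [out, in]
    return w_q, w_scale.to(torch.float32)

def per_token_quant_int8_ref(x: torch.Tensor):
    # Dynamic, per-token, symmetric int8 quantization (reference for per_token_quant_int8).
    x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0   # [tokens, 1]
    x_q = torch.clamp(torch.round(x / x_scale), -128, 127).to(torch.int8)
    return x_q, x_scale.to(torch.float32)

def int8_scaled_mm_ref(x_q, w_q_t, x_scale, w_scale, out_dtype, bias=None):
    # y = (x_q @ w_q_t) * x_scale * w_scale^T, cast to out_dtype.
    # w_q_t is the int8 weight already transposed to [in, out], matching
    # process_weights_after_loading above; w_scale is [out, 1] float32.
    acc = torch.matmul(x_q.float(), w_q_t.float())  # int accumulation emulated in float here
    y = acc * x_scale * w_scale.t()
    if bias is not None:
        y = y + bias
    return y.to(out_dtype)

x = torch.randn(4, 64)
w = torch.randn(128, 64)
w_q, w_scale = quantize_weight_per_channel(w)
x_q, x_scale = per_token_quant_int8_ref(x)
y = int8_scaled_mm_ref(x_q, w_q.t(), x_scale, w_scale, out_dtype=torch.float16)
print(y.shape, y.dtype)  # torch.Size([4, 128]) torch.float16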
python/sglang/srt/server_args.py (view file @ cc0485be)
@@ -378,6 +378,7 @@ class ServerArgs:
                 "bitsandbytes",
                 "gguf",
                 "modelopt",
+                "w8a8_int8",
             ],
             help="The quantization method.",
         )
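With the new choice registered, the method can be requested at launch time, for example (the model path below is only a placeholder for an int8-quantized checkpoint):

python -m sglang.launch_server --model-path <path-to-int8-checkpoint> --quantization w8a8_int8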