Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
32eb0da8
Unverified
Commit
32eb0da8
authored
Jan 19, 2025
by
yancong
Committed by
GitHub
Jan 18, 2025
Browse files
[Misc] Support register quantization method out-of-tree (#11969)
parent
6d0e3d37
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
158 additions
and
0 deletions
+158
-0
tests/quantization/test_register_quantization_config.py
tests/quantization/test_register_quantization_config.py
+117
-0
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+41
-0
No files found.
tests/quantization/test_register_quantization_config.py
0 → 100644
View file @
32eb0da8
"""Tests register custom quantization config.
See https://github.com/vllm-project/vllm/issues/11926 for more details.
Run `pytest tests/quantization/test_register_quantization_config.py`.
"""
from
typing
import
Any
,
Dict
,
List
,
Optional
import
pytest
import
torch
import
torch.nn.functional
as
F
from
vllm.model_executor.layers.linear
import
LinearBase
# noqa: E501
from
vllm.model_executor.layers.linear
import
UnquantizedLinearMethod
from
vllm.model_executor.layers.quantization
import
(
get_quantization_config
,
register_quantization_config
)
from
vllm.model_executor.layers.quantization.base_config
import
(
# noqa: E501
QuantizationConfig
)
class
FakeQuantLinearMethod
(
UnquantizedLinearMethod
):
"""Fake quantization linear method for per-token dynamic quantization."""
def
__init__
(
self
,
num_bits
:
int
=
8
)
->
None
:
"""Initialize the quantization method."""
super
().
__init__
()
self
.
num_bits
=
num_bits
def
apply
(
self
,
layer
:
"torch.nn.Module"
,
x
:
"torch.Tensor"
,
bias
:
Optional
[
"torch.Tensor"
]
=
None
)
->
"torch.Tensor"
:
"""Perform fake quantization before the linear layer."""
# Calculate the scales dynamically
max_val
=
torch
.
amax
(
x
,
dim
=
(
0
,
-
1
),
keepdims
=
True
)
min_val
=
torch
.
amin
(
x
,
dim
=
(
0
,
-
1
),
keepdims
=
True
)
scales
=
(
max_val
-
min_val
)
/
(
2
**
self
.
num_bits
-
1
)
# Fake quantize the input
quant_x
=
torch
.
clamp
(
torch
.
round
(
x
/
scales
),
-
2
**
(
self
.
num_bits
-
1
),
2
**
(
self
.
num_bits
-
1
)
-
1
)
dequant_x
=
quant_x
*
scales
return
F
.
linear
(
dequant_x
,
layer
.
weight
,
bias
)
@
register_quantization_config
(
"custom_quant"
)
class
CustomQuantConfig
(
QuantizationConfig
):
"""Custom quantization config for per-token dynamic fake quantization."""
def
__init__
(
self
,
num_bits
:
int
=
8
)
->
None
:
"""Initialize the quantization config."""
self
.
num_bits
=
num_bits
def
get_name
(
self
)
->
str
:
"""Name of the quantization method."""
return
"custom_quant"
def
get_supported_act_dtypes
(
self
)
->
List
[
"torch.dtype"
]:
"""List of supported activation dtypes."""
return
[
torch
.
float16
,
torch
.
bfloat16
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
"""Minimum GPU capability to support the quantization method."""
return
-
1
@
staticmethod
def
get_config_filenames
()
->
List
[
str
]:
"""List of filenames to search for in the model directory."""
return
[]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"CustomQuantConfig"
:
"""Create a config class from the model's quantization config."""
return
CustomQuantConfig
(
num_bits
=
config
.
get
(
"num_bits"
,
8
))
def
get_quant_method
(
self
,
layer
:
"torch.nn.Module"
,
prefix
:
str
)
->
Optional
[
"FakeQuantLinearMethod"
]:
"""Get the quantize method to use for the quantized layer."""
if
isinstance
(
layer
,
LinearBase
):
return
FakeQuantLinearMethod
(
num_bits
=
self
.
num_bits
)
return
None
def
test_register_quantization_config
():
"""Test register custom quantization config."""
# The quantization method `custom_quant` should be registered.
assert
get_quantization_config
(
"custom_quant"
)
==
CustomQuantConfig
# The quantization method `custom_quant` is already exists,
# should raise an error.
with
pytest
.
raises
(
ValueError
):
register_quantization_config
(
"custom_quant"
)(
CustomQuantConfig
)
@
pytest
.
mark
.
parametrize
(
argnames
=
"model"
,
argvalues
=
[
"meta-llama/Meta-Llama-3-8B-Instruct"
,
])
def
test_custom_quant
(
vllm_runner
,
model
):
"""Test infer with the custom quantization method."""
with
vllm_runner
(
model_name
=
model
,
quantization
=
"custom_quant"
,
enforce_eager
=
True
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
# Check the quantization method is FakeQuantLinearMethod
assert
isinstance
(
qkv_proj
.
quant_method
,
FakeQuantLinearMethod
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
vllm/model_executor/layers/quantization/__init__.py
View file @
32eb0da8
...
...
@@ -29,6 +29,45 @@ QUANTIZATION_METHODS: List[str] = [
"quark"
]
# The customized quantization methods which will be added to this dict.
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG
=
{}
def
register_quantization_config
(
quantization
:
str
):
"""Register a customized vllm quantization config.
When a quantization method is not supported by vllm, you can register a customized
quantization config to support it.
Args:
quantization (str): The quantization method name.
Examples:
>>> from vllm.model_executor.layers.quantization import register_quantization_config
>>> from vllm.model_executor.layers.quantization import get_quantization_config
>>> from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
>>>
>>> @register_quantization_config("my_quant")
... class MyQuantConfig(QuantizationConfig):
... pass
>>>
>>> get_quantization_config("my_quant")
<class 'MyQuantConfig'>
"""
# noqa: E501
def
_wrapper
(
quant_config_cls
):
if
quantization
in
QUANTIZATION_METHODS
:
raise
ValueError
(
f
"The quantization method `
{
quantization
}
` is already exists."
)
if
not
issubclass
(
quant_config_cls
,
QuantizationConfig
):
raise
ValueError
(
"The quantization config must be a subclass of "
"`QuantizationConfig`."
)
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG
[
quantization
]
=
quant_config_cls
QUANTIZATION_METHODS
.
append
(
quantization
)
return
quant_config_cls
return
_wrapper
def
get_quantization_config
(
quantization
:
str
)
->
Type
[
QuantizationConfig
]:
if
quantization
not
in
QUANTIZATION_METHODS
:
...
...
@@ -84,6 +123,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"ipex"
:
IPEXConfig
,
"quark"
:
QuarkConfig
}
# Update the `method_to_config` with customized quantization methods.
method_to_config
.
update
(
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG
)
return
method_to_config
[
quantization
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment