Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
53b018ed
Unverified
Commit
53b018ed
authored
Apr 18, 2024
by
Michael Goin
Committed by
GitHub
Apr 18, 2024
Browse files
[Bugfix] Get available quantization methods from quantization registry (#4098)
parent
66ded030
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
18 additions
and
13 deletions
+18
-13
benchmarks/benchmark_latency.py
benchmarks/benchmark_latency.py
+2
-1
benchmarks/benchmark_throughput.py
benchmarks/benchmark_throughput.py
+3
-1
tests/models/test_marlin.py
tests/models/test_marlin.py
+3
-4
vllm/config.py
vllm/config.py
+4
-3
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+2
-1
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+4
-3
No files found.
benchmarks/benchmark_latency.py
View file @
53b018ed
...
@@ -9,6 +9,7 @@ import torch
...
@@ -9,6 +9,7 @@ import torch
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
def
main
(
args
:
argparse
.
Namespace
):
def
main
(
args
:
argparse
.
Namespace
):
...
@@ -101,7 +102,7 @@ if __name__ == '__main__':
...
@@ -101,7 +102,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--tokenizer'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--tokenizer'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--quantization'
,
parser
.
add_argument
(
'--quantization'
,
'-q'
,
'-q'
,
choices
=
[
'awq'
,
'gptq'
,
'squeezellm'
,
None
],
choices
=
[
*
QUANTIZATION_METHODS
,
None
],
default
=
None
)
default
=
None
)
parser
.
add_argument
(
'--tensor-parallel-size'
,
'-tp'
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
'--tensor-parallel-size'
,
'-tp'
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
'--input-len'
,
type
=
int
,
default
=
32
)
parser
.
add_argument
(
'--input-len'
,
type
=
int
,
default
=
32
)
...
...
benchmarks/benchmark_throughput.py
View file @
53b018ed
...
@@ -10,6 +10,8 @@ from tqdm import tqdm
...
@@ -10,6 +10,8 @@ from tqdm import tqdm
from
transformers
import
(
AutoModelForCausalLM
,
AutoTokenizer
,
from
transformers
import
(
AutoModelForCausalLM
,
AutoTokenizer
,
PreTrainedTokenizerBase
)
PreTrainedTokenizerBase
)
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
def
sample_requests
(
def
sample_requests
(
dataset_path
:
str
,
dataset_path
:
str
,
...
@@ -267,7 +269,7 @@ if __name__ == "__main__":
...
@@ -267,7 +269,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--tokenizer"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--tokenizer"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--quantization'
,
parser
.
add_argument
(
'--quantization'
,
'-q'
,
'-q'
,
choices
=
[
'awq'
,
'gptq'
,
'squeezellm'
,
None
],
choices
=
[
*
QUANTIZATION_METHODS
,
None
],
default
=
None
)
default
=
None
)
parser
.
add_argument
(
"--tensor-parallel-size"
,
"-tp"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--tensor-parallel-size"
,
"-tp"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--n"
,
parser
.
add_argument
(
"--n"
,
...
...
tests/models/test_marlin.py
View file @
53b018ed
...
@@ -16,13 +16,12 @@ from dataclasses import dataclass
...
@@ -16,13 +16,12 @@ from dataclasses import dataclass
import
pytest
import
pytest
import
torch
import
torch
from
vllm.model_executor.layers.quantization
import
(
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
_QUANTIZATION_CONFIG_REGISTRY
)
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
marlin_not_supported
=
(
marlin_not_supported
=
(
capability
<
capability
<
_QUANTIZATION_CONFIG_REGISTRY
[
"marlin"
].
get_min_capability
())
QUANTIZATION_METHODS
[
"marlin"
].
get_min_capability
())
@
dataclass
@
dataclass
...
...
vllm/config.py
View file @
53b018ed
...
@@ -9,6 +9,7 @@ from packaging.version import Version
...
@@ -9,6 +9,7 @@ from packaging.version import Version
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.utils
import
(
get_cpu_memory
,
get_nvcc_cuda_version
,
is_cpu
,
is_hip
,
from
vllm.utils
import
(
get_cpu_memory
,
get_nvcc_cuda_version
,
is_cpu
,
is_hip
,
is_neuron
)
is_neuron
)
...
@@ -118,8 +119,8 @@ class ModelConfig:
...
@@ -118,8 +119,8 @@ class ModelConfig:
self
.
tokenizer_mode
=
tokenizer_mode
self
.
tokenizer_mode
=
tokenizer_mode
def
_verify_quantization
(
self
)
->
None
:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
"awq"
,
"gptq"
,
"squeezellm"
,
"marlin"
]
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
rocm_
not_
supported_quantization
=
[
"
aw
q"
,
"
marlin
"
]
rocm_supported_quantization
=
[
"
gpt
q"
,
"
squeezellm
"
]
if
self
.
quantization
is
not
None
:
if
self
.
quantization
is
not
None
:
self
.
quantization
=
self
.
quantization
.
lower
()
self
.
quantization
=
self
.
quantization
.
lower
()
...
@@ -155,7 +156,7 @@ class ModelConfig:
...
@@ -155,7 +156,7 @@ class ModelConfig:
f
"Unknown quantization method:
{
self
.
quantization
}
. Must "
f
"Unknown quantization method:
{
self
.
quantization
}
. Must "
f
"be one of
{
supported_quantization
}
."
)
f
"be one of
{
supported_quantization
}
."
)
if
is_hip
(
if
is_hip
(
)
and
self
.
quantization
in
rocm_
not_
supported_quantization
:
)
and
self
.
quantization
not
in
rocm_supported_quantization
:
raise
ValueError
(
raise
ValueError
(
f
"
{
self
.
quantization
}
quantization is currently not "
f
"
{
self
.
quantization
}
quantization is currently not "
f
"supported in ROCm."
)
f
"supported in ROCm."
)
...
...
vllm/engine/arg_utils.py
View file @
53b018ed
...
@@ -7,6 +7,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
...
@@ -7,6 +7,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TokenizerPoolConfig
,
VisionLanguageConfig
)
TokenizerPoolConfig
,
VisionLanguageConfig
)
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.utils
import
str_to_int_tuple
from
vllm.utils
import
str_to_int_tuple
...
@@ -286,7 +287,7 @@ class EngineArgs:
...
@@ -286,7 +287,7 @@ class EngineArgs:
parser
.
add_argument
(
'--quantization'
,
parser
.
add_argument
(
'--quantization'
,
'-q'
,
'-q'
,
type
=
str
,
type
=
str
,
choices
=
[
'awq'
,
'gptq'
,
'squeezellm'
,
None
],
choices
=
[
*
QUANTIZATION_METHODS
,
None
],
default
=
EngineArgs
.
quantization
,
default
=
EngineArgs
.
quantization
,
help
=
'Method used to quantize the weights. If '
help
=
'Method used to quantize the weights. If '
'None, we first check the `quantization_config` '
'None, we first check the `quantization_config` '
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
53b018ed
...
@@ -7,7 +7,7 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig
...
@@ -7,7 +7,7 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
_
QUANTIZATION_
CONFIG_REGISTRY
=
{
QUANTIZATION_
METHODS
=
{
"awq"
:
AWQConfig
,
"awq"
:
AWQConfig
,
"gptq"
:
GPTQConfig
,
"gptq"
:
GPTQConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
...
@@ -16,12 +16,13 @@ _QUANTIZATION_CONFIG_REGISTRY = {
...
@@ -16,12 +16,13 @@ _QUANTIZATION_CONFIG_REGISTRY = {
def
get_quantization_config
(
quantization
:
str
)
->
Type
[
QuantizationConfig
]:
def
get_quantization_config
(
quantization
:
str
)
->
Type
[
QuantizationConfig
]:
if
quantization
not
in
_
QUANTIZATION_
CONFIG_REGISTRY
:
if
quantization
not
in
QUANTIZATION_
METHODS
:
raise
ValueError
(
f
"Invalid quantization method:
{
quantization
}
"
)
raise
ValueError
(
f
"Invalid quantization method:
{
quantization
}
"
)
return
_
QUANTIZATION_
CONFIG_REGISTRY
[
quantization
]
return
QUANTIZATION_
METHODS
[
quantization
]
__all__
=
[
__all__
=
[
"QuantizationConfig"
,
"QuantizationConfig"
,
"get_quantization_config"
,
"get_quantization_config"
,
"QUANTIZATION_METHODS"
,
]
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment