Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
bb00f66e
Unverified
Commit
bb00f66e
authored
Nov 17, 2023
by
Woosuk Kwon
Committed by
GitHub
Nov 17, 2023
Browse files
Use `quantization_config` in hf config (#1695)
parent
e87557b0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
10 deletions
+34
-10
vllm/config.py
vllm/config.py
+24
-8
vllm/model_executor/model_loader.py
vllm/model_executor/model_loader.py
+1
-0
vllm/model_executor/weight_utils.py
vllm/model_executor/weight_utils.py
+9
-2
No files found.
vllm/config.py
View file @
bb00f66e
...
...
@@ -104,14 +104,30 @@ class ModelConfig:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
"awq"
,
"squeezellm"
]
if
self
.
quantization
is
None
:
return
quantization
=
self
.
quantization
.
lower
()
if
quantization
not
in
supported_quantization
:
raise
ValueError
(
f
"Unknown quantization:
{
self
.
quantization
}
. Must be one of "
f
"
{
supported_quantization
}
."
)
self
.
quantization
=
quantization
if
self
.
quantization
is
not
None
:
self
.
quantization
=
self
.
quantization
.
lower
()
# Parse quantization method from the HF model config, if available.
hf_quant_config
=
getattr
(
self
.
hf_config
,
"quantization_config"
,
None
)
if
hf_quant_config
is
not
None
:
hf_quant_method
=
str
(
hf_quant_config
[
"quant_method"
]).
lower
()
if
self
.
quantization
is
None
:
self
.
quantization
=
hf_quant_method
elif
self
.
quantization
!=
hf_quant_method
:
raise
ValueError
(
"Quantization method specified in the model config "
f
"(
{
hf_quant_method
}
) does not match the quantization "
f
"method specified in the `quantization` argument "
f
"(
{
self
.
quantization
}
)."
)
if
self
.
quantization
is
not
None
:
if
self
.
quantization
not
in
supported_quantization
:
raise
ValueError
(
f
"Unknown quantization method:
{
self
.
quantization
}
. Must "
f
"be one of
{
supported_quantization
}
."
)
logger
.
warning
(
f
"
{
self
.
quantization
}
quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models."
)
def
verify_with_parallel_config
(
self
,
...
...
vllm/model_executor/model_loader.py
View file @
bb00f66e
...
...
@@ -66,6 +66,7 @@ def get_model(model_config: ModelConfig) -> nn.Module:
if
model_config
.
quantization
is
not
None
:
quant_config
=
get_quant_config
(
model_config
.
quantization
,
model_config
.
model
,
model_config
.
hf_config
,
model_config
.
download_dir
)
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
...
...
vllm/model_executor/weight_utils.py
View file @
bb00f66e
...
...
@@ -7,9 +7,10 @@ from collections import defaultdict
from
typing
import
Any
,
Iterator
,
List
,
Optional
,
Tuple
from
huggingface_hub
import
snapshot_download
from
safetensors.torch
import
load_file
,
save_file
,
safe_open
import
numpy
as
np
from
safetensors.torch
import
load_file
,
save_file
,
safe_open
import
torch
from
transformers
import
PretrainedConfig
from
tqdm.auto
import
tqdm
from
vllm.logger
import
init_logger
...
...
@@ -84,8 +85,15 @@ def convert_bin_to_safetensor_file(
def
get_quant_config
(
quantization
:
str
,
model_name_or_path
:
str
,
hf_config
:
PretrainedConfig
,
cache_dir
:
Optional
[
str
]
=
None
,
)
->
QuantizationConfig
:
quant_cls
=
get_quantization_config
(
quantization
)
# Read the quantization config from the HF model config, if available.
hf_quant_config
=
getattr
(
hf_config
,
"quantization_config"
,
None
)
if
hf_quant_config
is
not
None
:
return
quant_cls
.
from_config
(
hf_quant_config
)
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
if
not
is_local
:
# Download the config files.
...
...
@@ -98,7 +106,6 @@ def get_quant_config(
hf_folder
=
model_name_or_path
config_files
=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
"*.json"
))
quant_cls
=
get_quantization_config
(
quantization
)
quant_config_files
=
[
f
for
f
in
config_files
if
any
(
f
.
endswith
(
x
)
for
x
in
quant_cls
.
get_config_filenames
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment