Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d72bb20
Unverified
Commit
8d72bb20
authored
Nov 04, 2024
by
youkaichao
Committed by
GitHub
Nov 04, 2024
Browse files
[4/N] make quant config first-class citizen (#9978)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
ac6b8f19
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
31 deletions
+41
-31
vllm/config.py
vllm/config.py
+38
-0
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+3
-31
No files found.
vllm/config.py
View file @
8d72bb20
...
@@ -23,9 +23,13 @@ if TYPE_CHECKING:
...
@@ -23,9 +23,13 @@ if TYPE_CHECKING:
from
ray.util.placement_group
import
PlacementGroup
from
ray.util.placement_group
import
PlacementGroup
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.model_loader.loader
import
BaseModelLoader
from
vllm.model_executor.model_loader.loader
import
BaseModelLoader
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
from
vllm.transformers_utils.tokenizer_group.base_tokenizer_group
import
(
BaseTokenizerGroup
)
BaseTokenizerGroup
)
else
:
QuantizationConfig
=
None
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -1966,6 +1970,35 @@ class VllmConfig:
...
@@ -1966,6 +1970,35 @@ class VllmConfig:
decoding_config
:
Optional
[
DecodingConfig
]
=
None
decoding_config
:
Optional
[
DecodingConfig
]
=
None
observability_config
:
Optional
[
ObservabilityConfig
]
=
None
observability_config
:
Optional
[
ObservabilityConfig
]
=
None
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
quant_config
:
Optional
[
QuantizationConfig
]
=
None
@
staticmethod
def
_get_quantization_config
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
)
->
Optional
[
QuantizationConfig
]:
"""Get the quantization config."""
if
model_config
.
quantization
is
not
None
:
from
vllm.model_executor.model_loader.weight_utils
import
(
get_quant_config
)
quant_config
=
get_quant_config
(
model_config
,
load_config
)
capability_tuple
=
current_platform
.
get_device_capability
()
if
capability_tuple
is
not
None
:
capability
=
capability_tuple
.
to_int
()
if
capability
<
quant_config
.
get_min_capability
():
raise
ValueError
(
f
"The quantization method
{
model_config
.
quantization
}
"
"is not supported for the current GPU. Minimum "
f
"capability:
{
quant_config
.
get_min_capability
()
}
. "
f
"Current capability:
{
capability
}
."
)
supported_dtypes
=
quant_config
.
get_supported_act_dtypes
()
if
model_config
.
dtype
not
in
supported_dtypes
:
raise
ValueError
(
f
"
{
model_config
.
dtype
}
is not supported for quantization "
f
"method
{
model_config
.
quantization
}
. Supported dtypes: "
f
"
{
supported_dtypes
}
"
)
return
quant_config
return
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
"""Verify configs are valid & consistent with each other.
"""Verify configs are valid & consistent with each other.
...
@@ -1983,3 +2016,8 @@ class VllmConfig:
...
@@ -1983,3 +2016,8 @@ class VllmConfig:
if
self
.
prompt_adapter_config
:
if
self
.
prompt_adapter_config
:
self
.
prompt_adapter_config
.
verify_with_model_config
(
self
.
prompt_adapter_config
.
verify_with_model_config
(
self
.
model_config
)
self
.
model_config
)
if
self
.
quant_config
is
None
and
\
self
.
model_config
is
not
None
and
self
.
load_config
is
not
None
:
self
.
quant_config
=
VllmConfig
.
_get_quantization_config
(
self
.
model_config
,
self
.
load_config
)
vllm/model_executor/model_loader/loader.py
View file @
8d72bb20
...
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.utils import (get_model_architecture,
...
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.utils import (get_model_architecture,
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
download_safetensors_index_file_from_hf
,
download_weights_from_hf
,
download_safetensors_index_file_from_hf
,
download_weights_from_hf
,
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
get_gguf_extra_tensor_names
,
get_quant_config
,
gguf_quant_weights_iterator
,
get_gguf_extra_tensor_names
,
gguf_quant_weights_iterator
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
safetensors_weights_iterator
)
safetensors_weights_iterator
)
from
vllm.model_executor.models
import
(
has_inner_state
,
supports_lora
,
from
vllm.model_executor.models
import
(
has_inner_state
,
supports_lora
,
...
@@ -93,32 +93,6 @@ def device_loading_context(module: torch.nn.Module,
...
@@ -93,32 +93,6 @@ def device_loading_context(module: torch.nn.Module,
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
def
_get_quantization_config
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
)
->
Optional
[
QuantizationConfig
]:
"""Get the quantization config."""
if
model_config
.
quantization
is
not
None
:
quant_config
=
get_quant_config
(
model_config
,
load_config
)
capability_tuple
=
current_platform
.
get_device_capability
()
if
capability_tuple
is
not
None
:
capability
=
capability_tuple
.
to_int
()
if
capability
<
quant_config
.
get_min_capability
():
raise
ValueError
(
f
"The quantization method
{
model_config
.
quantization
}
"
"is not supported for the current GPU. "
f
"Minimum capability:
{
quant_config
.
get_min_capability
()
}
. "
f
"Current capability:
{
capability
}
."
)
supported_dtypes
=
quant_config
.
get_supported_act_dtypes
()
if
model_config
.
dtype
not
in
supported_dtypes
:
raise
ValueError
(
f
"
{
model_config
.
dtype
}
is not supported for quantization "
f
"method
{
model_config
.
quantization
}
. Supported dtypes: "
f
"
{
supported_dtypes
}
"
)
return
quant_config
return
None
def
_get_model_initialization_kwargs
(
def
_get_model_initialization_kwargs
(
model_class
:
Type
[
nn
.
Module
],
model_class
:
Type
[
nn
.
Module
],
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
...
@@ -185,7 +159,6 @@ def _initialize_model(vllm_config: VllmConfig) -> nn.Module:
...
@@ -185,7 +159,6 @@ def _initialize_model(vllm_config: VllmConfig) -> nn.Module:
lora_config
=
vllm_config
.
lora_config
lora_config
=
vllm_config
.
lora_config
scheduler_config
=
vllm_config
.
scheduler_config
scheduler_config
=
vllm_config
.
scheduler_config
cache_config
=
vllm_config
.
cache_config
cache_config
=
vllm_config
.
cache_config
load_config
=
vllm_config
.
load_config
model_class
,
_
=
get_model_architecture
(
model_config
)
model_class
,
_
=
get_model_architecture
(
model_config
)
return
build_model
(
return
build_model
(
...
@@ -193,7 +166,7 @@ def _initialize_model(vllm_config: VllmConfig) -> nn.Module:
...
@@ -193,7 +166,7 @@ def _initialize_model(vllm_config: VllmConfig) -> nn.Module:
vllm_config
,
vllm_config
,
model_config
.
hf_config
,
model_config
.
hf_config
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
_get_quantization_config
(
model
_config
,
load
_config
)
,
quant_config
=
vllm
_config
.
quant
_config
,
lora_config
=
lora_config
,
lora_config
=
lora_config
,
multimodal_config
=
model_config
.
multimodal_config
,
multimodal_config
=
model_config
.
multimodal_config
,
scheduler_config
=
scheduler_config
,
scheduler_config
=
scheduler_config
,
...
@@ -518,8 +491,7 @@ class TensorizerLoader(BaseModelLoader):
...
@@ -518,8 +491,7 @@ class TensorizerLoader(BaseModelLoader):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
torch
.
device
(
device_config
.
device
):
with
torch
.
device
(
device_config
.
device
):
model_class
=
get_model_architecture
(
model_config
)[
0
]
model_class
=
get_model_architecture
(
model_config
)[
0
]
quant_config
=
_get_quantization_config
(
quant_config
=
vllm_config
.
quant_config
model_config
,
self
.
load_config
)
extra_kwargs
=
_get_model_initialization_kwargs
(
extra_kwargs
=
_get_model_initialization_kwargs
(
model_class
,
lora_config
,
model_config
.
multimodal_config
)
model_class
,
lora_config
,
model_config
.
multimodal_config
)
extra_kwargs
[
"quant_config"
]
=
quant_config
extra_kwargs
[
"quant_config"
]
=
quant_config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment