Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9f68e00d
Unverified
Commit
9f68e00d
authored
Sep 07, 2024
by
Cyrus Leung
Committed by
GitHub
Sep 07, 2024
Browse files
[Bugfix] Fix broken OpenAI tensorizer test (#8258)
parent
ce2702a9
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
81 additions
and
40 deletions
+81
-40
tests/utils.py
tests/utils.py
+6
-6
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+39
-33
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+29
-1
vllm/model_executor/model_loader/tensorizer.py
vllm/model_executor/model_loader/tensorizer.py
+7
-0
No files found.
tests/utils.py
View file @
9f68e00d
...
...
@@ -20,7 +20,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment
)
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.model_executor.model_loader.loader
import
DefaultM
odel
L
oader
from
vllm.model_executor.model_loader.loader
import
get_m
odel
_l
oader
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
,
get_open_port
,
is_hip
...
...
@@ -89,11 +89,11 @@ class RemoteOpenAIServer:
is_local
=
os
.
path
.
isdir
(
model
)
if
not
is_local
:
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine
_config
=
engine_args
.
create_
engine
_config
()
dummy_loader
=
DefaultModelLoader
(
engine_config
.
load_config
)
dummy_loader
.
_prepare_weights
(
engine_config
.
model_config
.
model
,
engine_config
.
model_config
.
revision
,
fall_back_to_pt
=
True
)
model
_config
=
engine_args
.
create_
model
_config
()
load_config
=
engine_args
.
create_
load_config
(
)
model_loader
=
get_model_loader
(
load_config
)
model_loader
.
download_model
(
model_config
)
env
=
os
.
environ
.
copy
()
# the current process might initialize cuda,
...
...
vllm/engine/arg_utils.py
View file @
9f68e00d
...
...
@@ -771,33 +771,8 @@ class EngineArgs:
engine_args
=
cls
(
**
{
attr
:
getattr
(
args
,
attr
)
for
attr
in
attrs
})
return
engine_args
def
create_engine_config
(
self
)
->
EngineConfig
:
# gguf file needs a specific model loader and doesn't use hf_repo
if
check_gguf_file
(
self
.
model
):
self
.
quantization
=
self
.
load_format
=
"gguf"
# bitsandbytes quantization needs a specific model loader
# so we make sure the quant method and the load format are consistent
if
(
self
.
quantization
==
"bitsandbytes"
or
self
.
qlora_adapter_name_or_path
is
not
None
)
and
\
self
.
load_format
!=
"bitsandbytes"
:
raise
ValueError
(
"BitsAndBytes quantization and QLoRA adapter only support "
f
"'bitsandbytes' load format, but got
{
self
.
load_format
}
"
)
if
(
self
.
load_format
==
"bitsandbytes"
or
self
.
qlora_adapter_name_or_path
is
not
None
)
and
\
self
.
quantization
!=
"bitsandbytes"
:
raise
ValueError
(
"BitsAndBytes load format and QLoRA adapter only support "
f
"'bitsandbytes' quantization, but got
{
self
.
quantization
}
"
)
assert
self
.
cpu_offload_gb
>=
0
,
(
"CPU offload space must be non-negative"
f
", but got
{
self
.
cpu_offload_gb
}
"
)
device_config
=
DeviceConfig
(
device
=
self
.
device
)
model_config
=
ModelConfig
(
def
create_model_config
(
self
)
->
ModelConfig
:
return
ModelConfig
(
model
=
self
.
model
,
tokenizer
=
self
.
tokenizer
,
tokenizer_mode
=
self
.
tokenizer_mode
,
...
...
@@ -825,6 +800,42 @@ class EngineArgs:
config_format
=
self
.
config_format
,
)
def
create_load_config
(
self
)
->
LoadConfig
:
return
LoadConfig
(
load_format
=
self
.
load_format
,
download_dir
=
self
.
download_dir
,
model_loader_extra_config
=
self
.
model_loader_extra_config
,
ignore_patterns
=
self
.
ignore_patterns
,
)
def
create_engine_config
(
self
)
->
EngineConfig
:
# gguf file needs a specific model loader and doesn't use hf_repo
if
check_gguf_file
(
self
.
model
):
self
.
quantization
=
self
.
load_format
=
"gguf"
# bitsandbytes quantization needs a specific model loader
# so we make sure the quant method and the load format are consistent
if
(
self
.
quantization
==
"bitsandbytes"
or
self
.
qlora_adapter_name_or_path
is
not
None
)
and
\
self
.
load_format
!=
"bitsandbytes"
:
raise
ValueError
(
"BitsAndBytes quantization and QLoRA adapter only support "
f
"'bitsandbytes' load format, but got
{
self
.
load_format
}
"
)
if
(
self
.
load_format
==
"bitsandbytes"
or
self
.
qlora_adapter_name_or_path
is
not
None
)
and
\
self
.
quantization
!=
"bitsandbytes"
:
raise
ValueError
(
"BitsAndBytes load format and QLoRA adapter only support "
f
"'bitsandbytes' quantization, but got
{
self
.
quantization
}
"
)
assert
self
.
cpu_offload_gb
>=
0
,
(
"CPU offload space must be non-negative"
f
", but got
{
self
.
cpu_offload_gb
}
"
)
device_config
=
DeviceConfig
(
device
=
self
.
device
)
model_config
=
self
.
create_model_config
()
cache_config
=
CacheConfig
(
block_size
=
self
.
block_size
if
self
.
device
!=
"neuron"
else
self
.
max_model_len
,
# neuron needs block_size = max_model_len
...
...
@@ -967,12 +978,7 @@ class EngineArgs:
self
.
model_loader_extra_config
[
"qlora_adapter_name_or_path"
]
=
self
.
qlora_adapter_name_or_path
load_config
=
LoadConfig
(
load_format
=
self
.
load_format
,
download_dir
=
self
.
download_dir
,
model_loader_extra_config
=
self
.
model_loader_extra_config
,
ignore_patterns
=
self
.
ignore_patterns
,
)
load_config
=
self
.
create_load_config
()
prompt_adapter_config
=
PromptAdapterConfig
(
max_prompt_adapters
=
self
.
max_prompt_adapters
,
...
...
vllm/model_executor/model_loader/loader.py
View file @
9f68e00d
...
...
@@ -185,6 +185,11 @@ class BaseModelLoader(ABC):
def
__init__
(
self
,
load_config
:
LoadConfig
):
self
.
load_config
=
load_config
@
abstractmethod
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
"""Download a model so that it can be immediately loaded."""
raise
NotImplementedError
@
abstractmethod
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
...
...
@@ -193,7 +198,7 @@ class BaseModelLoader(ABC):
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
)
->
nn
.
Module
:
"""Load a model with the given configurations."""
...
raise
NotImplementedError
class
DefaultModelLoader
(
BaseModelLoader
):
...
...
@@ -335,6 +340,11 @@ class DefaultModelLoader(BaseModelLoader):
weights_iterator
=
_xla_weights_iterator
(
weights_iterator
)
return
weights_iterator
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
,
fall_back_to_pt
=
True
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
...
...
@@ -377,6 +387,9 @@ class DummyModelLoader(BaseModelLoader):
raise
ValueError
(
f
"Model loader extra config is not supported for "
f
"load format
{
load_config
.
load_format
}
"
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
pass
# Nothing to download
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
...
...
@@ -467,6 +480,12 @@ class TensorizerLoader(BaseModelLoader):
model
=
load_with_tensorizer
(
tensorizer_config
,
**
extra_kwargs
)
return
model
.
eval
()
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
tensorizer_config
.
verify_with_model_config
(
model_config
)
with
self
.
tensorizer_config
.
open_stream
():
pass
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
...
...
@@ -568,6 +587,9 @@ class ShardedStateLoader(BaseModelLoader):
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
...
...
@@ -995,6 +1017,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
set_weight_attrs
(
param
,
{
"matmul_state"
:
[
None
]
*
len
(
quant_states
)})
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
...
...
@@ -1070,6 +1095,9 @@ class GGUFModelLoader(BaseModelLoader):
return
gguf_quant_weights_iterator
(
model_name_or_path
,
gguf_to_hf_name_map
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
)
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
...
...
vllm/model_executor/model_loader/tensorizer.py
View file @
9f68e00d
...
...
@@ -99,6 +99,13 @@ class TensorizerConfig:
"Loading a model using Tensorizer with quantization on vLLM"
" is unstable and may lead to errors."
)
def
open_stream
(
self
,
tensorizer_args
:
Optional
[
"TensorizerArgs"
]
=
None
):
if
tensorizer_args
is
None
:
tensorizer_args
=
self
.
_construct_tensorizer_args
()
return
open_stream
(
self
.
tensorizer_uri
,
**
tensorizer_args
.
stream_params
)
def
load_with_tensorizer
(
tensorizer_config
:
TensorizerConfig
,
**
extra_kwargs
)
->
nn
.
Module
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment