Unverified Commit ba7703e6 authored by Jee Jee Li, committed by GitHub
Browse files

[Misc] Remove qlora_adapter_name_or_path (#17699)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent f80ae5bd
...@@ -75,20 +75,10 @@ def initialize_engine(model: str, quantization: str, ...@@ -75,20 +75,10 @@ def initialize_engine(model: str, quantization: str,
lora_repo: Optional[str]) -> LLMEngine: lora_repo: Optional[str]) -> LLMEngine:
"""Initialize the LLMEngine.""" """Initialize the LLMEngine."""
if quantization == "bitsandbytes":
# QLoRA (https://arxiv.org/abs/2305.14314) is a quantization technique.
# It quantizes the model when loading, with some config info from the
# LoRA adapter repo. So need to set the parameter of load_format and
# qlora_adapter_name_or_path as below.
engine_args = EngineArgs(model=model,
quantization=quantization,
qlora_adapter_name_or_path=lora_repo,
enable_lora=True,
max_lora_rank=64)
else:
engine_args = EngineArgs(model=model, engine_args = EngineArgs(model=model,
quantization=quantization, quantization=quantization,
enable_lora=True, enable_lora=True,
max_lora_rank=64,
max_loras=4) max_loras=4)
return LLMEngine.from_engine_args(engine_args) return LLMEngine.from_engine_args(engine_args)
...@@ -96,22 +86,27 @@ def initialize_engine(model: str, quantization: str, ...@@ -96,22 +86,27 @@ def initialize_engine(model: str, quantization: str,
def main(): def main():
"""Main function that sets up and runs the prompt processing.""" """Main function that sets up and runs the prompt processing."""
test_configs = [{ test_configs = [
# QLoRA (https://arxiv.org/abs/2305.14314)
{
"name": "qlora_inference_example", "name": "qlora_inference_example",
'model': "huggyllama/llama-7b", 'model': "huggyllama/llama-7b",
'quantization': "bitsandbytes", 'quantization': "bitsandbytes",
'lora_repo': 'timdettmers/qlora-flan-7b' 'lora_repo': 'timdettmers/qlora-flan-7b'
}, { },
{
"name": "AWQ_inference_with_lora_example", "name": "AWQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ', 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ',
'quantization': "awq", 'quantization': "awq",
'lora_repo': 'jashing/tinyllama-colorist-lora' 'lora_repo': 'jashing/tinyllama-colorist-lora'
}, { },
{
"name": "GPTQ_inference_with_lora_example", "name": "GPTQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ', 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ',
'quantization': "gptq", 'quantization': "gptq",
'lora_repo': 'jashing/tinyllama-colorist-lora' 'lora_repo': 'jashing/tinyllama-colorist-lora'
}] }
]
for test_config in test_configs: for test_config in test_configs:
print( print(
......
...@@ -6,6 +6,7 @@ import dataclasses ...@@ -6,6 +6,7 @@ import dataclasses
import json import json
import re import re
import threading import threading
import warnings
from dataclasses import MISSING, dataclass, fields from dataclasses import MISSING, dataclass, fields
from itertools import permutations from itertools import permutations
from typing import (Any, Callable, Dict, List, Literal, Optional, Type, from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
...@@ -394,7 +395,13 @@ class EngineArgs: ...@@ -394,7 +395,13 @@ class EngineArgs:
if isinstance(self.compilation_config, (int, dict)): if isinstance(self.compilation_config, (int, dict)):
self.compilation_config = CompilationConfig.from_cli( self.compilation_config = CompilationConfig.from_cli(
str(self.compilation_config)) str(self.compilation_config))
if self.qlora_adapter_name_or_path is not None:
warnings.warn(
"The `qlora_adapter_name_or_path` is deprecated "
"and will be removed in v0.10.0. ",
DeprecationWarning,
stacklevel=2,
)
# Setup plugins # Setup plugins
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
load_general_plugins() load_general_plugins()
...@@ -504,10 +511,14 @@ class EngineArgs: ...@@ -504,10 +511,14 @@ class EngineArgs:
**load_kwargs["ignore_patterns"]) **load_kwargs["ignore_patterns"])
load_group.add_argument("--use-tqdm-on-load", load_group.add_argument("--use-tqdm-on-load",
**load_kwargs["use_tqdm_on_load"]) **load_kwargs["use_tqdm_on_load"])
load_group.add_argument('--qlora-adapter-name-or-path', load_group.add_argument(
"--qlora-adapter-name-or-path",
type=str, type=str,
default=None, default=None,
help='Name or path of the QLoRA adapter.') help="The `--qlora-adapter-name-or-path` has no effect, do not set"
" it, and it will be removed in v0.10.0.",
deprecated=True,
)
load_group.add_argument('--pt-load-map-location', load_group.add_argument('--pt-load-map-location',
**load_kwargs["pt_load_map_location"]) **load_kwargs["pt_load_map_location"])
...@@ -534,7 +545,7 @@ class EngineArgs: ...@@ -534,7 +545,7 @@ class EngineArgs:
deprecated=True, deprecated=True,
help="[DEPRECATED] The `--enable-reasoning` flag is deprecated as " help="[DEPRECATED] The `--enable-reasoning` flag is deprecated as "
"of v0.8.6. Use `--reasoning-parser` to specify the reasoning " "of v0.8.6. Use `--reasoning-parser` to specify the reasoning "
"parser backend insteadThis flag (`--enable-reasoning`) will be " "parser backend instead. This flag (`--enable-reasoning`) will be "
"removed in v0.10.0. When `--reasoning-parser` is specified, " "removed in v0.10.0. When `--reasoning-parser` is specified, "
"reasoning mode is automatically enabled.") "reasoning mode is automatically enabled.")
guided_decoding_group.add_argument( guided_decoding_group.add_argument(
...@@ -896,12 +907,6 @@ class EngineArgs: ...@@ -896,12 +907,6 @@ class EngineArgs:
def create_load_config(self) -> LoadConfig: def create_load_config(self) -> LoadConfig:
if(self.qlora_adapter_name_or_path is not None) and \
self.quantization != "bitsandbytes":
raise ValueError(
"QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}")
if self.quantization == "bitsandbytes": if self.quantization == "bitsandbytes":
self.load_format = "bitsandbytes" self.load_format = "bitsandbytes"
...@@ -1098,11 +1103,6 @@ class EngineArgs: ...@@ -1098,11 +1103,6 @@ class EngineArgs:
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None and self.max_cpu_loras > 0 else None) if self.enable_lora else None
if self.qlora_adapter_name_or_path is not None and \
self.qlora_adapter_name_or_path != "":
self.model_loader_extra_config[
"qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
# bitsandbytes pre-quantized model need a specific model loader # bitsandbytes pre-quantized model need a specific model loader
if model_config.quantization == "bitsandbytes": if model_config.quantization == "bitsandbytes":
self.quantization = self.load_format = "bitsandbytes" self.quantization = self.load_format = "bitsandbytes"
......
...@@ -162,23 +162,15 @@ def get_quant_config(model_config: ModelConfig, ...@@ -162,23 +162,15 @@ def get_quant_config(model_config: ModelConfig,
None) None)
if hf_quant_config is not None: if hf_quant_config is not None:
return quant_cls.from_config(hf_quant_config) return quant_cls.from_config(hf_quant_config)
# In case of bitsandbytes/QLoRA, get quant config from the adapter model. # Inflight BNB quantization
if model_config.quantization == "bitsandbytes": if model_config.quantization == "bitsandbytes":
if (not load_config.model_loader_extra_config return quant_cls.from_config({})
or "qlora_adapter_name_or_path" is_local = os.path.isdir(model_config.model)
not in load_config.model_loader_extra_config):
return quant_cls.from_config({"adapter_name_or_path": ""})
model_name_or_path = load_config.model_loader_extra_config[
"qlora_adapter_name_or_path"]
else:
model_name_or_path = model_config.model
is_local = os.path.isdir(model_name_or_path)
if not is_local: if not is_local:
# Download the config files. # Download the config files.
with get_lock(model_name_or_path, load_config.download_dir): with get_lock(model_config.model, load_config.download_dir):
hf_folder = snapshot_download( hf_folder = snapshot_download(
model_name_or_path, model_config.model,
revision=model_config.revision, revision=model_config.revision,
allow_patterns="*.json", allow_patterns="*.json",
cache_dir=load_config.download_dir, cache_dir=load_config.download_dir,
...@@ -186,7 +178,7 @@ def get_quant_config(model_config: ModelConfig, ...@@ -186,7 +178,7 @@ def get_quant_config(model_config: ModelConfig,
tqdm_class=DisabledTqdm, tqdm_class=DisabledTqdm,
) )
else: else:
hf_folder = model_name_or_path hf_folder = model_config.model
possible_config_filenames = quant_cls.get_config_filenames() possible_config_filenames = quant_cls.get_config_filenames()
...@@ -213,7 +205,7 @@ def get_quant_config(model_config: ModelConfig, ...@@ -213,7 +205,7 @@ def get_quant_config(model_config: ModelConfig,
config = json.load(f) config = json.load(f)
if model_config.quantization == "bitsandbytes": if model_config.quantization == "bitsandbytes":
config["adapter_name_or_path"] = model_name_or_path config["adapter_name_or_path"] = model_config.model
elif model_config.quantization == "modelopt": elif model_config.quantization == "modelopt":
if config["producer"]["name"] == "modelopt": if config["producer"]["name"] == "modelopt":
return quant_cls.from_config(config) return quant_cls.from_config(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment