Unverified Commit ba7703e6 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Misc] Remove qlora_adapter_name_or_path (#17699)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent f80ae5bd
......@@ -75,43 +75,38 @@ def initialize_engine(model: str, quantization: str,
lora_repo: Optional[str]) -> LLMEngine:
"""Initialize the LLMEngine."""
if quantization == "bitsandbytes":
# QLoRA (https://arxiv.org/abs/2305.14314) is a quantization technique.
# It quantizes the model when loading, with some config info from the
# LoRA adapter repo. So need to set the parameter of load_format and
# qlora_adapter_name_or_path as below.
engine_args = EngineArgs(model=model,
quantization=quantization,
qlora_adapter_name_or_path=lora_repo,
enable_lora=True,
max_lora_rank=64)
else:
engine_args = EngineArgs(model=model,
quantization=quantization,
enable_lora=True,
max_loras=4)
engine_args = EngineArgs(model=model,
quantization=quantization,
enable_lora=True,
max_lora_rank=64,
max_loras=4)
return LLMEngine.from_engine_args(engine_args)
def main():
"""Main function that sets up and runs the prompt processing."""
test_configs = [{
"name": "qlora_inference_example",
'model': "huggyllama/llama-7b",
'quantization': "bitsandbytes",
'lora_repo': 'timdettmers/qlora-flan-7b'
}, {
"name": "AWQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ',
'quantization': "awq",
'lora_repo': 'jashing/tinyllama-colorist-lora'
}, {
"name": "GPTQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ',
'quantization': "gptq",
'lora_repo': 'jashing/tinyllama-colorist-lora'
}]
test_configs = [
# QLoRA (https://arxiv.org/abs/2305.14314)
{
"name": "qlora_inference_example",
'model': "huggyllama/llama-7b",
'quantization': "bitsandbytes",
'lora_repo': 'timdettmers/qlora-flan-7b'
},
{
"name": "AWQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ',
'quantization': "awq",
'lora_repo': 'jashing/tinyllama-colorist-lora'
},
{
"name": "GPTQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ',
'quantization': "gptq",
'lora_repo': 'jashing/tinyllama-colorist-lora'
}
]
for test_config in test_configs:
print(
......
......@@ -6,6 +6,7 @@ import dataclasses
import json
import re
import threading
import warnings
from dataclasses import MISSING, dataclass, fields
from itertools import permutations
from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
......@@ -394,7 +395,13 @@ class EngineArgs:
if isinstance(self.compilation_config, (int, dict)):
self.compilation_config = CompilationConfig.from_cli(
str(self.compilation_config))
if self.qlora_adapter_name_or_path is not None:
warnings.warn(
"The `qlora_adapter_name_or_path` is deprecated "
"and will be removed in v0.10.0. ",
DeprecationWarning,
stacklevel=2,
)
# Setup plugins
from vllm.plugins import load_general_plugins
load_general_plugins()
......@@ -504,10 +511,14 @@ class EngineArgs:
**load_kwargs["ignore_patterns"])
load_group.add_argument("--use-tqdm-on-load",
**load_kwargs["use_tqdm_on_load"])
load_group.add_argument('--qlora-adapter-name-or-path',
type=str,
default=None,
help='Name or path of the QLoRA adapter.')
load_group.add_argument(
"--qlora-adapter-name-or-path",
type=str,
default=None,
help="The `--qlora-adapter-name-or-path` has no effect, do not set"
" it, and it will be removed in v0.10.0.",
deprecated=True,
)
load_group.add_argument('--pt-load-map-location',
**load_kwargs["pt_load_map_location"])
......@@ -534,7 +545,7 @@ class EngineArgs:
deprecated=True,
help="[DEPRECATED] The `--enable-reasoning` flag is deprecated as "
"of v0.8.6. Use `--reasoning-parser` to specify the reasoning "
"parser backend insteadThis flag (`--enable-reasoning`) will be "
"parser backend instead. This flag (`--enable-reasoning`) will be "
"removed in v0.10.0. When `--reasoning-parser` is specified, "
"reasoning mode is automatically enabled.")
guided_decoding_group.add_argument(
......@@ -896,12 +907,6 @@ class EngineArgs:
def create_load_config(self) -> LoadConfig:
if(self.qlora_adapter_name_or_path is not None) and \
self.quantization != "bitsandbytes":
raise ValueError(
"QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}")
if self.quantization == "bitsandbytes":
self.load_format = "bitsandbytes"
......@@ -1098,11 +1103,6 @@ class EngineArgs:
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
if self.qlora_adapter_name_or_path is not None and \
self.qlora_adapter_name_or_path != "":
self.model_loader_extra_config[
"qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
# bitsandbytes pre-quantized model need a specific model loader
if model_config.quantization == "bitsandbytes":
self.quantization = self.load_format = "bitsandbytes"
......
......@@ -162,23 +162,15 @@ def get_quant_config(model_config: ModelConfig,
None)
if hf_quant_config is not None:
return quant_cls.from_config(hf_quant_config)
# In case of bitsandbytes/QLoRA, get quant config from the adapter model.
# Inflight BNB quantization
if model_config.quantization == "bitsandbytes":
if (not load_config.model_loader_extra_config
or "qlora_adapter_name_or_path"
not in load_config.model_loader_extra_config):
return quant_cls.from_config({"adapter_name_or_path": ""})
model_name_or_path = load_config.model_loader_extra_config[
"qlora_adapter_name_or_path"]
else:
model_name_or_path = model_config.model
is_local = os.path.isdir(model_name_or_path)
return quant_cls.from_config({})
is_local = os.path.isdir(model_config.model)
if not is_local:
# Download the config files.
with get_lock(model_name_or_path, load_config.download_dir):
with get_lock(model_config.model, load_config.download_dir):
hf_folder = snapshot_download(
model_name_or_path,
model_config.model,
revision=model_config.revision,
allow_patterns="*.json",
cache_dir=load_config.download_dir,
......@@ -186,7 +178,7 @@ def get_quant_config(model_config: ModelConfig,
tqdm_class=DisabledTqdm,
)
else:
hf_folder = model_name_or_path
hf_folder = model_config.model
possible_config_filenames = quant_cls.get_config_filenames()
......@@ -213,7 +205,7 @@ def get_quant_config(model_config: ModelConfig,
config = json.load(f)
if model_config.quantization == "bitsandbytes":
config["adapter_name_or_path"] = model_name_or_path
config["adapter_name_or_path"] = model_config.model
elif model_config.quantization == "modelopt":
if config["producer"]["name"] == "modelopt":
return quant_cls.from_config(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment