Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b9c0605a
Unverified
Commit
b9c0605a
authored
Jun 01, 2024
by
chenqianfzh
Committed by
GitHub
Jun 01, 2024
Browse files
[Feature][Kernel] Support bitsandbytes quantization and QLoRA (#4776)
parent
37464a0f
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
752 additions
and
8 deletions
+752
-8
examples/lora_with_quantization_inference.py
examples/lora_with_quantization_inference.py
+140
-0
requirements-dev.txt
requirements-dev.txt
+3
-0
tests/quantization/test_bitsandbytes.py
tests/quantization/test_bitsandbytes.py
+80
-0
vllm/config.py
vllm/config.py
+8
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+35
-3
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+39
-2
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+3
-0
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+175
-0
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+246
-1
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+15
-1
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+8
-0
No files found.
examples/lora_with_quantization_inference.py
0 → 100644
View file @
b9c0605a
"""
This example shows how to use LoRA with different quantization techniques
for offline inference.
Requires HuggingFace credentials for access.
"""
import
gc
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
huggingface_hub
import
snapshot_download
from
vllm
import
EngineArgs
,
LLMEngine
,
RequestOutput
,
SamplingParams
from
vllm.lora.request
import
LoRARequest
def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Build the (prompt, sampling params, LoRA request) triples to run.

    The first triple carries no LoRA request; the remaining three each
    attach a ``LoRARequest`` that points at ``lora_path``.

    Args:
        lora_path: Location of the LoRA adapter handed to every
            ``LoRARequest`` created here.

    Returns:
        A new list of four ``(prompt, SamplingParams, LoRARequest | None)``
        tuples, in submission order.
    """

    def make_sampling_params() -> SamplingParams:
        # Every prompt gets its own SamplingParams object; the settings
        # are the same for all of them.
        return SamplingParams(temperature=0.0,
                              logprobs=1,
                              prompt_logprobs=1,
                              max_tokens=128)

    # this is an example of using quantization without LoRA
    prompts: List[Tuple[str, SamplingParams, Optional[LoRARequest]]] = [
        ("My name is", make_sampling_params(), None),
    ]

    # the next three examples use quantization with LoRA
    lora_cases = (
        ("my name is", "lora-test-1"),
        ("The capital of USA is", "lora-test-2"),
        ("The capital of France is", "lora-test-3"),
    )
    for prompt_text, adapter_name in lora_cases:
        params = make_sampling_params()
        prompts.append(
            (prompt_text, params, LoRARequest(adapter_name, 1, lora_path)))
    return prompts
def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Continuously process a list of prompts and handle the outputs.

    One prompt is submitted per loop iteration (while any remain), then the
    engine is stepped once and every finished request is printed. The loop
    ends when no prompts are left and the engine reports no unfinished
    requests.

    Args:
        engine: Object providing ``add_request``, ``step`` and
            ``has_unfinished_requests``.
        test_prompts: Queue of ``(prompt, sampling_params, lora_request)``
            triples. It is consumed in place from the front, so it is empty
            when this function returns.
    """
    next_request_id = 0

    while True:
        # Same short-circuit order as a `while a or b` loop: the engine is
        # only asked once the prompt queue has been drained.
        if not test_prompts and not engine.has_unfinished_requests():
            break

        if test_prompts:
            prompt, sampling_params, lora_request = test_prompts.pop(0)
            engine.add_request(str(next_request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            next_request_id += 1

        step_outputs: List[RequestOutput] = engine.step()
        for result in step_outputs:
            if not result.finished:
                continue
            print("----------------------------------------------------")
            print(f"Prompt: {result.prompt}")
            print(f"Output: {result.outputs[0].text}")
def initialize_engine(model: str, quantization: str,
                      lora_repo: Optional[str]) -> LLMEngine:
    """Initialize the LLMEngine.

    Args:
        model: Model name or path passed to ``EngineArgs``.
        quantization: Quantization method name. ``"bitsandbytes"`` selects
            the QLoRA-specific arguments; any other value uses the generic
            LoRA arguments.
        lora_repo: Adapter repo forwarded as ``qlora_adapter_name_or_path``
            in the bitsandbytes case; unused otherwise.

    Returns:
        The engine built by ``LLMEngine.from_engine_args``.
    """
    # Arguments common to both configurations.
    engine_kwargs = {
        "model": model,
        "quantization": quantization,
        "enable_lora": True,
        # set it only in GPUs of limited memory
        "enforce_eager": True,
    }

    if quantization == "bitsandbytes":
        # QLoRA (https://arxiv.org/abs/2305.14314) is a quantization
        # technique. It quantizes the model when loading, with some config
        # info from the LoRA adapter repo. So need to set the parameter of
        # load_format and qlora_adapter_name_or_path as below.
        engine_kwargs["qlora_adapter_name_or_path"] = lora_repo
        engine_kwargs["load_format"] = "bitsandbytes"
        engine_kwargs["max_lora_rank"] = 64
    else:
        engine_kwargs["max_loras"] = 4

    return LLMEngine.from_engine_args(EngineArgs(**engine_kwargs))
def main():
    """Main function that sets up and runs the prompt processing.

    Runs each configuration in turn: builds an engine, downloads the LoRA
    adapter snapshot, runs the test prompts, then drops the engine and
    empties the CUDA cache before the next configuration.
    """
    test_configs = [
        {
            "name": "qlora_inference_example",
            'model': "huggyllama/llama-7b",
            'quantization': "bitsandbytes",
            'lora_repo': 'timdettmers/qlora-flan-7b',
        },
        {
            "name": "AWQ_inference_with_lora_example",
            'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ',
            'quantization': "awq",
            'lora_repo': 'jashing/tinyllama-colorist-lora',
        },
        {
            "name": "GPTQ_inference_with_lora_example",
            'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ',
            'quantization': "gptq",
            'lora_repo': 'jashing/tinyllama-colorist-lora',
        },
    ]

    for config in test_configs:
        run_name = config['name']
        adapter_repo = config['lora_repo']

        print(f"~~~~~~~~~~~~~~~~ Running: {run_name} ~~~~~~~~~~~~~~~~")

        # The engine is created before the adapter snapshot is fetched.
        engine = initialize_engine(config['model'], config['quantization'],
                                   adapter_repo)
        lora_path = snapshot_download(repo_id=adapter_repo)
        process_requests(engine, create_test_prompts(lora_path))

        # Clean up the GPU memory for the next test
        del engine
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()
requirements-dev.txt
View file @
b9c0605a
...
...
@@ -35,3 +35,6 @@ aiohttp
# Multimodal
pillow
# quantization
bitsandbytes==0.42.0
tests/quantization/test_bitsandbytes.py
0 → 100644
View file @
b9c0605a
'''Tests whether bitsandbytes computation is enabled correctly.
Run `pytest tests/quantization/test_bitsandbytes.py`.
'''
import
pytest
import
torch
from
vllm
import
SamplingParams
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
# Compute capability of the current CUDA device encoded as major * 10 +
# minor, the same encoding returned by get_min_capability().
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]


@pytest.mark.skipif(
    capability < QUANTIZATION_METHODS['bitsandbytes'].get_min_capability(),
    reason='bitsandbytes is not supported on this GPU type.')
def test_load_bnb_model(vllm_runner) -> None:
    """Load llama-7b with bitsandbytes and check dtypes and greedy output.

    Checks that the MLP and self-attention projection weights of the first
    decoder layer are stored as ``torch.uint8``, that the embedding, LM head
    and layer-norm weights are not, and that the first generated line for
    two fixed prompts matches the expected text.
    """
    llm = vllm_runner('huggyllama/llama-7b',
                      quantization='bitsandbytes',
                      load_format='bitsandbytes',
                      enforce_eager=True)

    model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
    first_layer = model.model.layers[0]

    # check the weights in MLP & SelfAttention are quantized to torch.uint8
    quantized_modules = {
        'gate_up_proj': first_layer.mlp.gate_up_proj,
        'down_proj': first_layer.mlp.down_proj,
        'o_proj': first_layer.self_attn.o_proj,
        'qkv_proj': first_layer.self_attn.qkv_proj,
    }
    for name, module in quantized_modules.items():
        qweight = module.qweight
        assert qweight.dtype == torch.uint8, (
            f'Expected {name} dtype torch.uint8 but got {qweight.dtype}')

    # some weights should not be quantized
    unquantized_modules = {
        'lm_head': model.lm_head,
        'embed_tokens': model.model.embed_tokens,
        'input_layernorm': first_layer.input_layernorm,
        # The failure message now names the module actually being checked.
        'post_attention_layernorm': first_layer.post_attention_layernorm,
    }
    for name, module in unquantized_modules.items():
        weight = module.weight
        assert weight.dtype != torch.uint8, (
            f'{name} weight dtype should not be torch.uint8')

    # check the output of the model is expected
    sampling_params = SamplingParams(temperature=0.0,
                                     logprobs=1,
                                     prompt_logprobs=1,
                                     max_tokens=8)

    prompts = ['That which does not kill us', 'To be or not to be,']
    expected_outputs = [
        'That which does not kill us makes us stronger.',
        'To be or not to be, that is the question.'
    ]
    outputs = llm.generate(prompts, sampling_params=sampling_params)

    assert len(outputs) == len(prompts)

    for output, expected in zip(outputs, expected_outputs):
        # compare the first line of the output
        actual_output = output[1][0].split('\n', 1)[0]
        expected_output = expected.split('\n', 1)[0]
        assert actual_output == expected_output, (
            f'Expected: {expected_output}, but got: {actual_output}')
vllm/config.py
View file @
b9c0605a
...
...
@@ -241,6 +241,12 @@ class ModelConfig:
"must be divisible by pipeline parallel size "
f
"(
{
pipeline_parallel_size
}
)."
)
if
self
.
quantization
==
"bitsandbytes"
and
(
parallel_config
.
tensor_parallel_size
>
1
or
parallel_config
.
pipeline_parallel_size
>
1
):
raise
ValueError
(
"BitAndBytes quantization with TP or PP is not supported yet."
)
def
get_hf_config_sliding_window
(
self
)
->
Optional
[
int
]:
"""Get the sliding window size, or None if disabled.
"""
...
...
@@ -487,6 +493,7 @@ class LoadFormat(str, enum.Enum):
DUMMY
=
"dummy"
TENSORIZER
=
"tensorizer"
SHARDED_STATE
=
"sharded_state"
BITSANDBYTES
=
"bitsandbytes"
@
dataclass
...
...
vllm/engine/arg_utils.py
View file @
b9c0605a
...
...
@@ -92,6 +92,8 @@ class EngineArgs:
ngram_prompt_lookup_max
:
Optional
[
int
]
=
None
ngram_prompt_lookup_min
:
Optional
[
int
]
=
None
qlora_adapter_name_or_path
:
Optional
[
str
]
=
None
def
__post_init__
(
self
):
if
self
.
tokenizer
is
None
:
self
.
tokenizer
=
self
.
model
...
...
@@ -159,7 +161,8 @@ class EngineArgs:
type
=
str
,
default
=
EngineArgs
.
load_format
,
choices
=
[
'auto'
,
'pt'
,
'safetensors'
,
'npcache'
,
'dummy'
,
'tensorizer'
'auto'
,
'pt'
,
'safetensors'
,
'npcache'
,
'dummy'
,
'tensorizer'
,
'bitsandbytes'
],
help
=
'The format of the model weights to load.
\n\n
'
'* "auto" will try to load the weights in the safetensors format '
...
...
@@ -173,7 +176,9 @@ class EngineArgs:
'which is mainly for profiling.
\n
'
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave. See the Tensorize vLLM Model script in the Examples'
'section for more information.
\n
'
)
'section for more information.
\n
'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.
\n
'
)
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
...
...
@@ -543,7 +548,10 @@ class EngineArgs:
"will also be used in `model_name` tag content of "
"prometheus metrics, if multiple names provided, metrics"
"tag will take the first one."
)
parser
.
add_argument
(
'--qlora-adapter-name-or-path'
,
type
=
str
,
default
=
None
,
help
=
'Name or path of the QLoRA adapter.'
)
return
parser
@
classmethod
...
...
@@ -555,6 +563,23 @@ class EngineArgs:
return
engine_args
def
create_engine_config
(
self
,
)
->
EngineConfig
:
# bitsandbytes quantization needs a specific model loader
# so we make sure the quant method and the load format are consistent
if
(
self
.
quantization
==
"bitsandbytes"
or
self
.
qlora_adapter_name_or_path
is
not
None
)
and
\
self
.
load_format
!=
"bitsandbytes"
:
raise
ValueError
(
"BitsAndBytes quantization and QLoRA adapter only support "
f
"'bitsandbytes' load format, but got
{
self
.
load_format
}
"
)
if
(
self
.
load_format
==
"bitsandbytes"
or
self
.
qlora_adapter_name_or_path
is
not
None
)
and
\
self
.
quantization
!=
"bitsandbytes"
:
raise
ValueError
(
"BitsAndBytes load format and QLoRA adapter only support "
f
"'bitsandbytes' quantization, but got
{
self
.
quantization
}
"
)
device_config
=
DeviceConfig
(
self
.
device
)
model_config
=
ModelConfig
(
self
.
model
,
self
.
tokenizer
,
self
.
tokenizer_mode
,
...
...
@@ -622,6 +647,13 @@ class EngineArgs:
max_cpu_loras
=
self
.
max_cpu_loras
if
self
.
max_cpu_loras
and
self
.
max_cpu_loras
>
0
else
None
)
if
self
.
enable_lora
else
None
if
self
.
qlora_adapter_name_or_path
is
not
None
and
\
self
.
qlora_adapter_name_or_path
!=
""
:
if
self
.
model_loader_extra_config
is
None
:
self
.
model_loader_extra_config
=
{}
self
.
model_loader_extra_config
[
"qlora_adapter_name_or_path"
]
=
self
.
qlora_adapter_name_or_path
load_config
=
LoadConfig
(
load_format
=
self
.
load_format
,
download_dir
=
self
.
download_dir
,
...
...
vllm/model_executor/layers/linear.py
View file @
b9c0605a
from
abc
import
abstractmethod
from
typing
import
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch.nn.functional
as
F
...
...
@@ -26,6 +26,21 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return
shard_size
*
marlin_tile_size
,
shard_offset
*
marlin_tile_size
def adjust_bitsandbytes_shard(param: Parameter,
                              qkv_offsets: Dict[str, Tuple[int, int]],
                              loaded_shard_id: str) -> Tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    The unquantized ``(offset, size)`` of the requested shard is scaled by
    the ratio between the first dimension of ``param.data`` and the
    unquantized total, using integer floor division.

    Args:
        param: Parameter whose ``data.shape[0]`` is the quantized total.
        qkv_offsets: Maps each shard id to its unquantized
            ``(offset, size)``; the ``"total"`` entry holds the unquantized
            total as its first element.
        loaded_shard_id: Key of the shard being loaded.

    Returns:
        ``(quantized_size, quantized_offset)`` -- note the size comes first.
    """
    unquantized_total, _ = qkv_offsets["total"]
    shard_offset, shard_size = qkv_offsets[loaded_shard_id]
    packed_total = param.data.shape[0]

    def to_packed(value: int) -> int:
        # Multiply before dividing so no precision is lost to the floor.
        return value * packed_total // unquantized_total

    return to_packed(shard_size), to_packed(shard_offset)
class
LinearMethodBase
(
QuantizeMethodBase
):
"""Base class for different (maybe quantized) linear methods."""
...
...
@@ -416,6 +431,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
use_bitsandbytes
=
getattr
(
param
,
"use_bitsandbytes"
,
False
)
if
use_bitsandbytes
:
shard_size
=
loaded_weight
.
shape
[
output_dim
]
shard_offset
=
loaded_weight
.
shape
[
output_dim
]
*
\
loaded_shard_id
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
start_idx
=
tp_rank
*
shard_size
...
...
@@ -615,6 +636,22 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
use_bitsandbytes
=
getattr
(
param
,
"use_bitsandbytes"
,
False
)
if
use_bitsandbytes
:
orig_qkv_offsets
=
{
"q"
:
(
0
,
self
.
num_heads
*
self
.
head_size
),
"k"
:
(
self
.
num_heads
*
self
.
head_size
,
self
.
num_kv_heads
*
self
.
head_size
),
"v"
:
((
self
.
num_heads
+
self
.
num_kv_heads
)
*
self
.
head_size
,
self
.
num_kv_heads
*
self
.
head_size
),
"total"
:
((
self
.
num_heads
+
2
*
self
.
num_kv_heads
)
*
self
.
head_size
,
0
)
}
shard_size
,
shard_offset
=
adjust_bitsandbytes_shard
(
param
,
orig_qkv_offsets
,
loaded_shard_id
)
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
if
loaded_shard_id
==
"q"
:
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
b9c0605a
...
...
@@ -4,6 +4,8 @@ from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.bitsandbytes
import
(
BitsAndBytesConfig
)
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensorsConfig
)
from
vllm.model_executor.layers.quantization.deepspeedfp
import
(
...
...
@@ -30,6 +32,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"gptq"
:
GPTQConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"sparseml"
:
CompressedTensorsConfig
,
"bitsandbytes"
:
BitsAndBytesConfig
,
}
...
...
vllm/model_executor/layers/quantization/bitsandbytes.py
0 → 100644
View file @
b9c0605a
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
class BitsAndBytesConfig(QuantizationConfig):
    """Config class for BitsAndBytes Quantization.

    Reference: https://arxiv.org/abs/2305.14314
    """

    def __init__(
        self,
        adapter_name_or_path: str,
        target_modules: List[str],
    ) -> None:
        """Store the adapter location and the module names to quantize.

        Args:
            adapter_name_or_path: Name or path of the QLoRA adapter; an
                empty string when no adapter is used (see ``from_config``).
            target_modules: Names of the modules to quantize.
        """
        self.adapter_name_or_path = adapter_name_or_path
        self.target_modules = target_modules

    def __repr__(self) -> str:
        # The closing parenthesis keeps the representation balanced.
        return (
            "BitsAndBytesConfig("
            f"adapter_name_or_path={self.adapter_name_or_path})")

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "bitsandbytes"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.float32, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum GPU capability, encoded as major*10 + minor."""
        return 70

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return the file names searched for the quantization config."""
        return [
            "adapter_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
        """Build a config from a parsed adapter config dictionary.

        Args:
            config: Must contain ``adapter_name_or_path``. When that value
                is a non-empty string, ``target_modules`` is read from the
                dictionary as well.

        Returns:
            A new ``BitsAndBytesConfig``. With an empty adapter name the
            built-in default target modules are used.
        """
        adapter_name = cls.get_from_keys(config, ["adapter_name_or_path"])
        default_target_modules = [
            "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
            "o_proj"
        ]
        if adapter_name == "":
            target_modules = default_target_modules
        else:
            target_modules = cls.get_from_keys(config, ["target_modules"])
        return cls(adapter_name, target_modules)

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["BitsAndBytesLinearMethod"]:
        """Return the linear method for ``LinearBase`` layers, else None."""
        if isinstance(layer, LinearBase):
            return BitsAndBytesLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the activation function names reported as scaled."""
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
class BitsAndBytesLinearMethod(LinearMethodBase):
    """Linear method for BitsAndBytes.

    Args:
        quant_config: The BitsAndBytes quantization config.
    """

    # Oldest accepted bitsandbytes release, as (major, minor, patch).
    _MIN_BNB_VERSION = (0, 42, 0)

    @staticmethod
    def _version_tuple(version: str) -> tuple:
        """Parse the leading ``major.minor.patch`` numbers of a version.

        Parsing stops at the first component that does not start with a
        digit (e.g. ``dev0``); trailing non-digit suffixes such as ``rc1``
        are ignored. Missing components count as 0, so ``"0.42"`` equals
        ``"0.42.0"``.
        """
        numbers = []
        for piece in version.split(".")[:3]:
            digits = ""
            for char in piece:
                if not char.isdigit():
                    break
                digits += char
            if not digits:
                break
            numbers.append(int(digits))
        numbers.extend([0] * (3 - len(numbers)))
        return tuple(numbers)

    def __init__(self, quant_config: BitsAndBytesConfig):
        """Check that a recent enough bitsandbytes is installed.

        Raises:
            ImportError: If bitsandbytes is missing or older than 0.42.0.
        """
        try:
            import bitsandbytes

            # Compare numerically: as plain strings "0.100.0" would sort
            # before "0.42.0" and "0.9.0" after it.
            installed = self._version_tuple(bitsandbytes.__version__)
            if installed < self._MIN_BNB_VERSION:
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.42.0.")
        except ImportError as err:
            raise ImportError("Please install bitsandbytes>=0.42.0 via "
                              "`pip install bitsandbytes>=0.42.0` to use "
                              "bitsandbytes quantizer.") from err

        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the packed ``qweight`` parameter on ``layer``.

        The parameter is an uninitialised ``torch.uint8`` tensor of shape
        ``[input_size_per_partition * sum(output_partition_sizes)
        // quant_ratio, 1]``, where ``quant_ratio`` is the bit width of
        ``params_dtype`` divided by 8.

        Raises:
            ValueError: If the element count is not divisible by
                ``quant_ratio``.
        """
        if params_dtype.is_floating_point:
            param_bits = torch.finfo(params_dtype).bits
        else:
            param_bits = torch.iinfo(params_dtype).bits
        quant_ratio = param_bits // torch.iinfo(torch.uint8).bits

        total_elements = input_size_per_partition * sum(
            output_partition_sizes)
        if total_elements % quant_ratio != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. ")

        qweight = Parameter(
            torch.empty(
                total_elements // quant_ratio,
                1,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                # In bitsandbytes, a tensor of shape [n, m] is quantized to
                # [n * m / pack_ratio, 1], so the output_dim is 0
                "output_dim": 0,
                "pack_factor": quant_ratio,
                "use_bitsandbytes": True,
            })
        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Multiply ``x`` by the 4-bit weight of ``layer``, shard by shard.

        ``x`` is cast to bfloat16, each shard of ``layer.qweight`` is
        multiplied with ``matmul_4bit`` using its own quant state, and the
        results are written side by side into one output tensor, which is
        cast back to the dtype of ``x``. ``bias`` is added afterwards when
        given.

        Reads ``bnb_quant_state`` (a mapping from shard index to quant
        state) and ``bnb_shard_offsets`` from ``layer.qweight``.
        """
        # only load the bitsandbytes module when needed
        from bitsandbytes import matmul_4bit

        original_type = x.dtype
        bf_x = x.to(torch.bfloat16)

        qweight = layer.qweight
        quant_states = qweight.bnb_quant_state
        offsets = qweight.bnb_shard_offsets

        out_dim_0 = x.shape[0]
        out_dim_1 = sum(quant_state.shape[0]
                        for quant_state in quant_states.values())
        out = torch.empty(out_dim_0,
                          out_dim_1,
                          dtype=torch.bfloat16,
                          device=x.device)

        current_index = 0
        # quant_states is keyed by shard index, so walk the indices in
        # order rather than relying on dict iteration order.
        for i in range(len(quant_states)):
            output_size = quant_states[i].shape[0]
            # It is more efficient to use out kwarg like
            # matmul_4bit(..., out = ...). Infeasible now due to the bug
            # https://github.com/TimDettmers/bitsandbytes/issues/1235.
            # Need to change after the bug is fixed.
            out[:, current_index:current_index + output_size] = matmul_4bit(
                bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i])

            current_index += output_size

        out = out.to(original_type)

        if bias is not None:
            out += bias

        return out
vllm/model_executor/model_loader/loader.py
View file @
b9c0605a
# ruff: noqa: SIM117
import
collections
import
copy
import
fnmatch
import
glob
import
json
import
math
import
os
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Type
import
huggingface_hub
import
numpy
as
np
import
torch
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
torch
import
nn
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoadFormat
,
...
...
@@ -28,6 +33,7 @@ from vllm.model_executor.model_loader.weight_utils import (
get_quant_config
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.models.vlm_base
import
VisionLanguageModelBase
from
vllm.model_executor.utils
import
set_weight_attrs
logger
=
init_logger
(
__name__
)
...
...
@@ -247,6 +253,7 @@ class DefaultModelLoader(BaseModelLoader):
model
,
"fall_back_to_pt_during_load"
,
True
)),
)
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
...
...
@@ -539,6 +546,241 @@ class ShardedStateLoader(BaseModelLoader):
)
class BitsAndBytesModelLoader(BaseModelLoader):
    """Model loader to load model weights with BitAndBytes quantization."""

    # Modules quantized when no adapter config is supplied.
    default_target_modules = [
        "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
        "o_proj"
    ]

    # File names tried, in order, when looking for the adapter config.
    possible_config_file_names = ["adapter_config.json"]

    # Oldest accepted bitsandbytes release, as (major, minor, patch).
    _MIN_BNB_VERSION = (0, 42, 0)

    def __init__(self, load_config: LoadConfig):
        """Resolve the list of modules to quantize.

        Without a ``qlora_adapter_name_or_path`` entry in
        ``load_config.model_loader_extra_config`` the default target
        modules are used; otherwise ``target_modules`` is read from the
        adapter config file.
        """
        super().__init__(load_config)

        # we don't need to quantize the whole model, only the target modules
        # that are specified in the adapter config file. If the adapter config
        # file is not provided, we will quantize the default modules.
        if (not load_config.model_loader_extra_config
                or "qlora_adapter_name_or_path"
                not in load_config.model_loader_extra_config):
            self.target_modules = self.default_target_modules
            return

        qlora_adapter = load_config.model_loader_extra_config[
            "qlora_adapter_name_or_path"]

        config_file_path = self._get_config_file(qlora_adapter)

        with open(config_file_path, "r") as f:
            config = json.load(f)
            self.target_modules = config["target_modules"]

    @staticmethod
    def _version_tuple(version: str) -> tuple:
        """Parse the leading ``major.minor.patch`` numbers of a version.

        Parsing stops at the first component that does not start with a
        digit; trailing non-digit suffixes are ignored. Missing components
        count as 0.
        """
        numbers = []
        for piece in version.split(".")[:3]:
            digits = ""
            for char in piece:
                if not char.isdigit():
                    break
                digits += char
            if not digits:
                break
            numbers.append(int(digits))
        numbers.extend([0] * (3 - len(numbers)))
        return tuple(numbers)

    def _get_config_file(self, qlora_adapter: str) -> str:
        """Return the path of the adapter config file.

        A local directory is searched directly; anything else is treated
        as a Hugging Face repo id and the file is downloaded.

        Raises:
            ValueError: If none of the known config file names is found.
        """
        is_local = os.path.isdir(qlora_adapter)
        config_file_path = None
        if is_local:
            for file_name in self.possible_config_file_names:
                candidate = os.path.join(qlora_adapter, file_name)
                # Only keep a path that really exists, so that a missing
                # file reaches the ValueError below.
                if os.path.exists(candidate):
                    config_file_path = candidate
                    break
        else:
            hf_api = HfApi()
            repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
            for file_name in self.possible_config_file_names:
                if file_name in repo_files:
                    config_file_path = hf_hub_download(repo_id=qlora_adapter,
                                                       filename=file_name)
                    break

        if not config_file_path:
            raise ValueError(
                f"Cannot find adapter config file in {qlora_adapter}")

        return config_file_path

    def _get_weight_files(
            self,
            model_name_or_path: str,
            allowed_patterns: List[str],
            revision: Optional[str] = None) -> Tuple[List[str], str]:
        """Retrieve weight files. Download the files if necessary.

        Patterns are tried in order and the first one with matches wins.

        Return the weight files and the file pattern.

        Raises:
            RuntimeError: If no pattern matches any file.
        """
        is_local = os.path.isdir(model_name_or_path)

        if is_local:
            for pattern in allowed_patterns:
                weight_files = glob.glob(
                    os.path.join(model_name_or_path, pattern))
                if weight_files:
                    return weight_files, pattern
        else:
            hf_api = HfApi()
            # List the same revision that is downloaded below.
            repo_files = hf_api.list_repo_files(repo_id=model_name_or_path,
                                                revision=revision)
            for pattern in allowed_patterns:
                matching_files = fnmatch.filter(repo_files, pattern)
                if matching_files:
                    hf_folder = download_weights_from_hf(
                        model_name_or_path, self.load_config.download_dir,
                        [pattern], revision)
                    return glob.glob(os.path.join(hf_folder,
                                                  pattern)), pattern

        raise RuntimeError(
            f"No model weights found in: `{model_name_or_path}`")

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]) -> Tuple[List[str], bool]:
        """Prepare weight files for the model.

        Returns:
            The weight file paths and whether they are safetensors files.

        Raises:
            RuntimeError: If no weight file remains after filtering.
        """
        allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]

        hf_weights_files, matched_pattern = self._get_weight_files(
            model_name_or_path, allowed_patterns, revision)

        if matched_pattern != "*.safetensors":
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_weights_files, matched_pattern == "*.safetensors"

    def _get_quantized_weights_iterator(
        self, model_name_or_path: str, revision: Optional[str]
    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                     Any]]:
        """Get an iterator to the model weights with bitsandbytes quantization,
        as well as the quantization state dictionary.

        The dictionary is filled lazily: an entry appears only once the
        generator has yielded the corresponding weight.

        Raises:
            ImportError: If bitsandbytes is missing or older than 0.42.0.
        """
        # only load the bitsandbytes module when needed
        try:
            import bitsandbytes

            # Compare numerically: as plain strings "0.100.0" would sort
            # before "0.42.0" and "0.9.0" after it.
            installed = self._version_tuple(bitsandbytes.__version__)
            if installed < self._MIN_BNB_VERSION:
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.42.0.")
            from bitsandbytes.functional import quantize_4bit
        except ImportError as err:
            raise ImportError("Please install bitsandbytes>=0.42.0 via "
                              "`pip install bitsandbytes>=0.42.0` to use "
                              "bitsandbytes quantizer.") from err

        hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision)

        quant_state_dict = {}
        if use_safetensors:
            weight_iterator = safetensors_weights_iterator(hf_weights_files)
        else:
            weight_iterator = pt_weights_iterator(hf_weights_files)

        def generator():
            for weight_name, weight_tensor in weight_iterator:
                if any(target_module in weight_name
                       for target_module in self.target_modules):
                    weight_name = weight_name.replace(".weight", ".qweight")
                    # bitsandbytes requires data in GPU
                    loaded_weight = weight_tensor.cuda().data
                    with set_default_torch_dtype(torch.float32):
                        processed_weight, quant_state = quantize_4bit(
                            loaded_weight,
                            compress_statistics=True,
                            quant_type="nf4")

                    quant_state_dict[weight_name] = quant_state
                else:
                    processed_weight = weight_tensor

                yield weight_name, processed_weight

        return generator(), quant_state_dict

    def _load_weights(self, model_config: ModelConfig,
                      model: nn.Module) -> None:
        """Load quantized weights and attach quant states to parameters.

        After ``model.load_weights`` has consumed the quantized iterator,
        every quantized parameter receives two attributes:
        ``bnb_quant_state`` (shard index -> quant state) and
        ``bnb_shard_offsets`` (cumulative packed sizes, starting at 0).

        Raises:
            AttributeError: If the model lacks ``load_weights`` or
                ``bitsandbytes_stacked_params_mapping``.
            ValueError: If a quantized weight has no matching parameter or
                a parameter has no ``pack_factor``.
        """
        if not hasattr(model, 'load_weights'):
            # Report the model class; the loader class is not what lacks
            # the method.
            raise AttributeError(
                "The required method 'load_weights' is not defined in class"
                f" {type(model).__name__}.")

        if not hasattr(model, 'bitsandbytes_stacked_params_mapping'):
            raise AttributeError(
                f"Model {type(model).__name__} does not support BitsAndBytes "
                "quantization yet.")

        logger.info("Loading weights with BitsAndBytes quantization. "
                    " May take a while ...")

        qweight_iterator, quant_state_dict = (
            self._get_quantized_weights_iterator(model_config.model,
                                                 model_config.revision))

        model.load_weights(qweight_iterator)

        param_dict = dict(model.named_parameters())
        stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
        for checkpoint_name in quant_state_dict:
            stacked_name = checkpoint_name
            shard_index = 0
            # Map a per-shard checkpoint name (e.g. q_proj) onto the stacked
            # parameter that holds it (e.g. qkv_proj) and its shard index.
            for shard_name, (
                    weight_name, index
            ) in model.bitsandbytes_stacked_params_mapping.items():
                if shard_name in checkpoint_name:
                    shard_index = index
                    stacked_name = checkpoint_name.replace(
                        shard_name, weight_name)
                    break

            if stacked_name not in param_dict:
                raise ValueError(
                    f"Parameter {stacked_name} not found in the model.")

            if stacked_name not in stacked_quant_state_dict:
                stacked_quant_state_dict[stacked_name] = {}
            stacked_quant_state_dict[stacked_name][shard_index] = (
                quant_state_dict[checkpoint_name])

        # save quant_states and offsets as the attributes of the parameters
        for param_name, param in param_dict.items():
            if param_name in stacked_quant_state_dict:
                quant_states = stacked_quant_state_dict[param_name]
                set_weight_attrs(param, {"bnb_quant_state": quant_states})

                pack_ratio = getattr(param, "pack_factor", -1)
                if pack_ratio == -1:
                    raise ValueError(
                        f"pack_factor not set for parameter {param_name}.")

                num_elements = [0] * len(quant_states)
                # Place each size at its shard index: the dict is filled in
                # checkpoint order, which need not be shard order, while the
                # offsets are later looked up by shard index.
                for shard_idx, quant_state in quant_states.items():
                    num_elements[shard_idx] = math.prod(
                        quant_state.shape) // pack_ratio

                offsets = np.concatenate(([0], np.cumsum(num_elements)))
                set_weight_attrs(param, {"bnb_shard_offsets": offsets})

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Build the model on the target device and load quantized weights.

        Returns:
            The model, switched to evaluation mode.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config,
                                          cache_config)

                self._load_weights(model_config, model)

        return model.eval()
def
get_model_loader
(
load_config
:
LoadConfig
)
->
BaseModelLoader
:
"""Get a model loader based on the load format."""
...
...
@@ -554,4 +796,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if
load_config
.
load_format
==
LoadFormat
.
SHARDED_STATE
:
return
ShardedStateLoader
(
load_config
)
if
load_config
.
load_format
==
LoadFormat
.
BITSANDBYTES
:
return
BitsAndBytesModelLoader
(
load_config
)
return
DefaultModelLoader
(
load_config
)
vllm/model_executor/model_loader/weight_utils.py
View file @
b9c0605a
...
...
@@ -130,6 +130,16 @@ def get_quant_config(model_config: ModelConfig,
if
hf_quant_config
is
not
None
:
return
quant_cls
.
from_config
(
hf_quant_config
)
# In case of bitsandbytes/QLoRA, get quant config from the adapter model.
if
model_config
.
quantization
==
"bitsandbytes"
:
if
(
not
load_config
.
model_loader_extra_config
or
"qlora_adapter_name_or_path"
not
in
load_config
.
model_loader_extra_config
):
return
quant_cls
.
from_config
({
"adapter_name_or_path"
:
""
})
model_name_or_path
=
load_config
.
model_loader_extra_config
[
"qlora_adapter_name_or_path"
]
else
:
model_name_or_path
=
model_config
.
model
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
if
not
is_local
:
...
...
@@ -169,6 +179,10 @@ def get_quant_config(model_config: ModelConfig,
quant_config_file
=
quant_config_files
[
0
]
with
open
(
quant_config_file
,
"r"
)
as
f
:
config
=
json
.
load
(
f
)
if
model_config
.
quantization
==
"bitsandbytes"
:
config
[
"adapter_name_or_path"
]
=
model_name_or_path
return
quant_cls
.
from_config
(
config
)
...
...
vllm/model_executor/models/llama.py
View file @
b9c0605a
...
...
@@ -319,6 +319,14 @@ class LlamaForCausalLM(nn.Module):
"lm_head"
:
"output_embeddings"
,
}
embedding_padding_modules
=
[
"lm_head"
]
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"k_proj"
:
(
"qkv_proj"
,
1
),
"v_proj"
:
(
"qkv_proj"
,
2
),
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
def
__init__
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment