Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
87525fab
Unverified
Commit
87525fab
authored
Jul 23, 2024
by
dongmao zhang
Committed by
GitHub
Jul 23, 2024
Browse files
[bitsandbytes]: support read bnb pre-quantized model (#5753)
Co-authored-by:
Michael Goin
<
michael@neuralmagic.com
>
parent
2f808e69
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
143 additions
and
39 deletions
+143
-39
docs/source/index.rst
docs/source/index.rst
+1
-0
docs/source/quantization/bnb.rst
docs/source/quantization/bnb.rst
+43
-0
tests/quantization/test_bitsandbytes.py
tests/quantization/test_bitsandbytes.py
+14
-4
vllm/config.py
vllm/config.py
+2
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+2
-2
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+4
-21
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+76
-12
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+1
-0
No files found.
docs/source/index.rst
View file @
87525fab
...
...
@@ -105,6 +105,7 @@ Documentation
quantization/supported_hardware
quantization/auto_awq
quantization/bnb
quantization/fp8
quantization/fp8_e5m2_kvcache
quantization/fp8_e4m3_kvcache
...
...
docs/source/quantization/bnb.rst
0 → 100644
View file @
87525fab
.. _bits_and_bytes:
BitsAndBytes
==================
vLLM now supports `BitsAndBytes <https://github.com/TimDettmers/bitsandbytes>`_ for more efficient model inference.
BitsAndBytes quantizes models to reduce memory usage and enhance performance without significantly sacrificing accuracy.
Compared to other quantization methods, BitsAndBytes eliminates the need for calibrating the quantized model with input data.
Below are the steps to utilize BitsAndBytes with vLLM.
.. code-block:: console
$ pip install bitsandbytes>=0.42.0
vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoint.
You can find bitsandbytes quantized models on https://huggingface.co/models?other=bitsandbytes.
And usually, these repositories have a config.json file that includes a quantization_config section.
Read quantized checkpoint.
--------------------------
.. code-block:: python
from vllm import LLM
import torch
# unsloth/tinyllama-bnb-4bit is a pre-quantized checkpoint.
model_id = "unsloth/tinyllama-bnb-4bit"
llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, \
quantization="bitsandbytes", load_format="bitsandbytes")
Inflight quantization: load as 4bit quantization
------------------------------------------------
.. code-block:: python
from vllm import LLM
import torch
model_id = "huggyllama/llama-7b"
llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, \
quantization="bitsandbytes", load_format="bitsandbytes")
tests/quantization/test_bitsandbytes.py
View file @
87525fab
...
...
@@ -8,15 +8,20 @@ import torch
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
SamplingParams
models_to_test
=
[
(
'huggyllama/llama-7b'
,
'quantize model inflight'
),
(
'lllyasviel/omost-llama-3-8b-4bits'
,
'read pre-quantized model'
),
]
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"bitsandbytes"
),
reason
=
'bitsandbytes is not supported on this GPU type.'
)
def
test_load_bnb_model
(
vllm_runner
)
->
None
:
with
vllm_runner
(
'huggyllama/llama-7b'
,
@
pytest
.
mark
.
parametrize
(
"model_name, description"
,
models_to_test
)
def
test_load_bnb_model
(
vllm_runner
,
model_name
,
description
)
->
None
:
with
vllm_runner
(
model_name
,
quantization
=
'bitsandbytes'
,
load_format
=
'bitsandbytes'
,
enforce_eager
=
True
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
# check the weights in MLP & SelfAttention are quantized to torch.uint8
...
...
@@ -65,12 +70,17 @@ def test_load_bnb_model(vllm_runner) -> None:
'To be or not to be, that is the question.'
]
outputs
=
llm
.
generate
(
prompts
,
sampling_params
=
sampling_params
)
assert
len
(
outputs
)
==
len
(
prompts
)
for
index
in
range
(
len
(
outputs
)):
# compare the first line of the output
actual_output
=
outputs
[
index
][
1
][
0
].
split
(
'
\n
'
,
1
)[
0
]
expected_output
=
expected_outputs
[
index
].
split
(
'
\n
'
,
1
)[
0
]
assert
len
(
actual_output
)
>=
len
(
expected_output
),
(
f
'Actual
{
actual_output
}
should be larger than or equal to '
f
'expected
{
expected_output
}
'
)
actual_output
=
actual_output
[:
len
(
expected_output
)]
assert
actual_output
==
expected_output
,
(
f
'Expected:
{
expected_output
}
, but got:
{
actual_output
}
'
)
vllm/config.py
View file @
87525fab
...
...
@@ -591,9 +591,11 @@ class LoadConfig:
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
"bitsandbytes" will load nf4 type weights.
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""
load_format
:
Union
[
str
,
LoadFormat
,
"BaseModelLoader"
]
=
LoadFormat
.
AUTO
...
...
vllm/engine/arg_utils.py
View file @
87525fab
...
...
@@ -676,8 +676,8 @@ class EngineArgs:
# bitsandbytes quantization needs a specific model loader
# so we make sure the quant method and the load format are consistent
if
(
self
.
quantization
==
"bitsandbytes"
or
self
.
qlora_adapter_name_or_path
is
not
None
)
and
\
self
.
load_format
!=
"bitsandbytes"
:
self
.
qlora_adapter_name_or_path
is
not
None
)
and
\
self
.
load_format
!=
"bitsandbytes"
:
raise
ValueError
(
"BitsAndBytes quantization and QLoRA adapter only support "
f
"'bitsandbytes' load format, but got
{
self
.
load_format
}
"
)
...
...
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
87525fab
...
...
@@ -15,19 +15,11 @@ class BitsAndBytesConfig(QuantizationConfig):
Reference: https://arxiv.org/abs/2305.14314
"""
def
__init__
(
self
,
adapter_name_or_path
:
str
,
target_modules
:
List
[
str
],
)
->
None
:
self
.
adapter_name_or_path
=
adapter_name_or_path
self
.
target_modules
=
target_modules
def
__init__
(
self
,
)
->
None
:
pass
def
__repr__
(
self
)
->
str
:
return
(
f
"BitsAndBytesConfig(adapter_name_or_path=
{
self
.
adapter_name_or_path
}
"
)
return
"BitsAndBytesConfig"
@
classmethod
def
get_name
(
self
)
->
str
:
...
...
@@ -49,16 +41,7 @@ class BitsAndBytesConfig(QuantizationConfig):
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"BitsAndBytesConfig"
:
adapter_name
=
cls
.
get_from_keys
(
config
,
[
"adapter_name_or_path"
])
default_target_modules
=
[
"gate_proj"
,
"down_proj"
,
"up_proj"
,
"q_proj"
,
"k_proj"
,
"v_proj"
,
"o_proj"
]
if
adapter_name
==
""
:
target_modules
=
default_target_modules
else
:
target_modules
=
cls
.
get_from_keys
(
config
,
[
"target_modules"
])
return
cls
(
adapter_name
,
target_modules
)
return
cls
()
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"BitsAndBytesLinearMethod"
]:
...
...
vllm/model_executor/model_loader/loader.py
View file @
87525fab
...
...
@@ -702,8 +702,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return
hf_weights_files
,
matched_pattern
==
"*.safetensors"
def
_hf_weight_iter
(
self
,
hf_weights_files
,
use_safetensors
:
bool
):
if
use_safetensors
:
return
safetensors_weights_iterator
(
hf_weights_files
)
else
:
return
pt_weights_iterator
(
hf_weights_files
)
def
_get_quantized_weights_iterator
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
]
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
]
,
pre_quant
:
bool
)
->
Tuple
[
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
],
Dict
[
str
,
Any
]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
...
...
@@ -712,6 +718,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# only load the bitsandbytes module when needed
try
:
import
bitsandbytes
from
bitsandbytes.functional
import
QuantState
if
bitsandbytes
.
__version__
<
"0.42.0"
:
raise
ImportError
(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.42.0."
)
...
...
@@ -725,17 +732,63 @@ class BitsAndBytesModelLoader(BaseModelLoader):
model_name_or_path
,
revision
)
quant_state_dict
=
{}
if
use_safetensors
:
weight_iterator
=
safetensors_weights_iterator
(
hf_weights_files
)
else
:
weight_iterator
=
pt_weights_iterator
(
hf_weights_files
)
def
generator
():
def
quantized_checkpoint
()
->
Generator
:
# First iterate over all quant state weights
weight_iterator
=
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
)
temp_state_dict
=
{}
for
weight_name
,
weight_tensor
in
weight_iterator
:
if
weight_name
.
endswith
(
".weight"
):
continue
# TODO: only nf4 quantization is supported for now
if
weight_name
.
endswith
(
".quant_state.bitsandbytes__fp4"
):
raise
NotImplementedError
(
"Only bitsandbytes_nf4 quantization"
f
"is supported for now.
{
weight_name
}
is fp4 quantized"
)
temp_state_dict
[
weight_name
]
=
weight_tensor
# Closure to parse quant_state for each prequant weight
def
_parse_quant_state
(
param_name
:
str
,
temp_state_dict
:
Dict
)
->
QuantState
:
quant_state
=
{}
for
k
in
temp_state_dict
:
if
param_name
+
"."
in
k
:
quant_state
[
k
]
=
temp_state_dict
[
k
]
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__nf4 in CPU
quant_state
[
param_name
+
".quant_state.bitsandbytes__nf4"
]
=
quant_state
[
param_name
+
".quant_state.bitsandbytes__nf4"
].
cpu
().
data
return
QuantState
.
from_dict
(
quant_state
,
device
=
"cuda"
)
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
# Filter out all weights whose suffix is not ".weight"
if
not
weight_name
.
endswith
(
".weight"
):
continue
if
weight_name
+
".quant_state.bitsandbytes__nf4"
\
in
temp_state_dict
:
quant_state
=
_parse_quant_state
(
weight_name
,
temp_state_dict
)
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
quant_state_dict
[
weight_name
]
=
quant_state
yield
weight_name
.
replace
(
".weight"
,
".qweight"
),
weight_tensor
else
:
yield
weight_name
,
weight_tensor
def
generator
()
->
Generator
:
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
if
any
(
target_module
in
weight_name
for
target_module
in
self
.
target_modules
):
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
#
bitsandbytes requires data in GPU
# bitsandbytes requires data in GPU
loaded_weight
=
weight_tensor
.
cuda
().
data
with
set_default_torch_dtype
(
torch
.
float32
):
processed_weight
,
quant_state
=
quantize_4bit
(
...
...
@@ -749,6 +802,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
yield
weight_name
,
processed_weight
if
pre_quant
:
return
quantized_checkpoint
(),
quant_state_dict
return
generator
(),
quant_state_dict
def
_load_weights
(
self
,
model_config
:
ModelConfig
,
...
...
@@ -766,12 +821,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
" May take a while ..."
)
qweight_iterator
,
quant_state_dict
=
(
self
.
_get_quantized_weights_iterator
(
model_config
.
model
,
model_config
.
revision
))
is_quantized_checkpoint
=
False
quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
None
)
if
quant_config
is
not
None
and
quant_config
.
get
(
'quant_method'
)
==
"bitsandbytes"
:
is_quantized_checkpoint
=
True
qweight_iterator
,
quant_state_dict
=
\
self
.
_get_quantized_weights_iterator
(
model_config
.
model
,
model_config
.
revision
,
is_quantized_checkpoint
)
model
.
load_weights
(
qweight_iterator
)
torch
.
cuda
.
empty_cache
()
param_dict
=
dict
(
model
.
named_parameters
())
stacked_quant_state_dict
:
Dict
[
str
,
Dict
[
int
,
Any
]]
=
{}
for
quant_param_name
in
quant_state_dict
:
...
...
@@ -809,9 +873,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
f
"pack_factor not set for parameter
{
param_name
}
."
)
num_elements
=
[
0
]
*
len
(
quant_states
)
for
seq
,
quant_state
in
enumerate
(
quant_states
.
items
()
)
:
for
seq
,
quant_state
in
quant_states
.
items
():
num_elements
[
seq
]
=
math
.
prod
(
quant_state
[
1
]
.
shape
)
//
pack_ratio
quant_state
.
shape
)
//
pack_ratio
offsets
=
np
.
concatenate
(([
0
],
np
.
cumsum
(
num_elements
)))
set_weight_attrs
(
param
,
{
"bnb_shard_offsets"
:
offsets
})
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
87525fab
...
...
@@ -118,6 +118,7 @@ def convert_bin_to_safetensor_file(
# TODO(woosuk): Move this to other place.
def
get_quant_config
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
)
->
QuantizationConfig
:
quant_cls
=
get_quantization_config
(
model_config
.
quantization
)
# Read the quantization config from the HF model config, if available.
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment