Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
87525fab
Unverified
Commit
87525fab
authored
Jul 23, 2024
by
dongmao zhang
Committed by
GitHub
Jul 23, 2024
Browse files
[bitsandbytes]: support read bnb pre-quantized model (#5753)
Co-authored-by:
Michael Goin
<
michael@neuralmagic.com
>
parent
2f808e69
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
143 additions
and
39 deletions
+143
-39
docs/source/index.rst
docs/source/index.rst
+1
-0
docs/source/quantization/bnb.rst
docs/source/quantization/bnb.rst
+43
-0
tests/quantization/test_bitsandbytes.py
tests/quantization/test_bitsandbytes.py
+14
-4
vllm/config.py
vllm/config.py
+2
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+2
-2
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+4
-21
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+76
-12
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+1
-0
No files found.
docs/source/index.rst
View file @
87525fab
...
@@ -105,6 +105,7 @@ Documentation
...
@@ -105,6 +105,7 @@ Documentation
quantization/supported_hardware
quantization/supported_hardware
quantization/auto_awq
quantization/auto_awq
quantization/bnb
quantization/fp8
quantization/fp8
quantization/fp8_e5m2_kvcache
quantization/fp8_e5m2_kvcache
quantization/fp8_e4m3_kvcache
quantization/fp8_e4m3_kvcache
...
...
docs/source/quantization/bnb.rst
0 → 100644
View file @
87525fab
.. _bits_and_bytes:
BitsAndBytes
==================
vLLM now supports `BitsAndBytes <https://github.com/TimDettmers/bitsandbytes>`_ for more efficient model inference.
BitsAndBytes quantizes models to reduce memory usage and enhance performance without significantly sacrificing accuracy.
Compared to other quantization methods, BitsAndBytes eliminates the need for calibrating the quantized model with input data.
Below are the steps to utilize BitsAndBytes with vLLM.
.. code-block:: console
$ pip install "bitsandbytes>=0.42.0"
vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoints.
You can find bitsandbytes quantized models on https://huggingface.co/models?other=bitsandbytes.
These repositories usually have a config.json file that includes a quantization_config section.
Read quantized checkpoint.
--------------------------
.. code-block:: python
from vllm import LLM
import torch
# unsloth/tinyllama-bnb-4bit is a pre-quantized checkpoint.
model_id = "unsloth/tinyllama-bnb-4bit"
llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, \
quantization="bitsandbytes", load_format="bitsandbytes")
In-flight quantization: load as 4-bit quantization
--------------------------------------------------
.. code-block:: python
from vllm import LLM
import torch
model_id = "huggyllama/llama-7b"
llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, \
quantization="bitsandbytes", load_format="bitsandbytes")
tests/quantization/test_bitsandbytes.py
View file @
87525fab
...
@@ -8,15 +8,20 @@ import torch
...
@@ -8,15 +8,20 @@ import torch
from
tests.quantization.utils
import
is_quant_method_supported
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
models_to_test
=
[
(
'huggyllama/llama-7b'
,
'quantize model inflight'
),
(
'lllyasviel/omost-llama-3-8b-4bits'
,
'read pre-quantized model'
),
]
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"bitsandbytes"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"bitsandbytes"
),
reason
=
'bitsandbytes is not supported on this GPU type.'
)
reason
=
'bitsandbytes is not supported on this GPU type.'
)
def
test_load_bnb_model
(
vllm_runner
)
->
None
:
@
pytest
.
mark
.
parametrize
(
"model_name, description"
,
models_to_test
)
with
vllm_runner
(
'huggyllama/llama-7b'
,
def
test_load_bnb_model
(
vllm_runner
,
model_name
,
description
)
->
None
:
with
vllm_runner
(
model_name
,
quantization
=
'bitsandbytes'
,
quantization
=
'bitsandbytes'
,
load_format
=
'bitsandbytes'
,
load_format
=
'bitsandbytes'
,
enforce_eager
=
True
)
as
llm
:
enforce_eager
=
True
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
# check the weights in MLP & SelfAttention are quantized to torch.uint8
# check the weights in MLP & SelfAttention are quantized to torch.uint8
...
@@ -65,12 +70,17 @@ def test_load_bnb_model(vllm_runner) -> None:
...
@@ -65,12 +70,17 @@ def test_load_bnb_model(vllm_runner) -> None:
'To be or not to be, that is the question.'
'To be or not to be, that is the question.'
]
]
outputs
=
llm
.
generate
(
prompts
,
sampling_params
=
sampling_params
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
=
sampling_params
)
assert
len
(
outputs
)
==
len
(
prompts
)
assert
len
(
outputs
)
==
len
(
prompts
)
for
index
in
range
(
len
(
outputs
)):
for
index
in
range
(
len
(
outputs
)):
# compare the first line of the output
# compare the first line of the output
actual_output
=
outputs
[
index
][
1
][
0
].
split
(
'
\n
'
,
1
)[
0
]
actual_output
=
outputs
[
index
][
1
][
0
].
split
(
'
\n
'
,
1
)[
0
]
expected_output
=
expected_outputs
[
index
].
split
(
'
\n
'
,
1
)[
0
]
expected_output
=
expected_outputs
[
index
].
split
(
'
\n
'
,
1
)[
0
]
assert
len
(
actual_output
)
>=
len
(
expected_output
),
(
f
'Actual
{
actual_output
}
should be larger than or equal to '
f
'expected
{
expected_output
}
'
)
actual_output
=
actual_output
[:
len
(
expected_output
)]
assert
actual_output
==
expected_output
,
(
assert
actual_output
==
expected_output
,
(
f
'Expected:
{
expected_output
}
, but got:
{
actual_output
}
'
)
f
'Expected:
{
expected_output
}
, but got:
{
actual_output
}
'
)
vllm/config.py
View file @
87525fab
...
@@ -591,9 +591,11 @@ class LoadConfig:
...
@@ -591,9 +591,11 @@ class LoadConfig:
mainly for profiling.
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
fast weight loading.
"bitsandbytes" will load nf4 type weights.
ignore_patterns: The list of patterns to ignore when loading the model.
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
checkpoints.
"""
"""
load_format
:
Union
[
str
,
LoadFormat
,
"BaseModelLoader"
]
=
LoadFormat
.
AUTO
load_format
:
Union
[
str
,
LoadFormat
,
"BaseModelLoader"
]
=
LoadFormat
.
AUTO
...
...
vllm/engine/arg_utils.py
View file @
87525fab
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
87525fab
...
@@ -15,19 +15,11 @@ class BitsAndBytesConfig(QuantizationConfig):
...
@@ -15,19 +15,11 @@ class BitsAndBytesConfig(QuantizationConfig):
Reference: https://arxiv.org/abs/2305.14314
Reference: https://arxiv.org/abs/2305.14314
"""
"""
def
__init__
(
def
__init__
(
self
,
)
->
None
:
self
,
pass
adapter_name_or_path
:
str
,
target_modules
:
List
[
str
],
)
->
None
:
self
.
adapter_name_or_path
=
adapter_name_or_path
self
.
target_modules
=
target_modules
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
return
"BitsAndBytesConfig"
f
"BitsAndBytesConfig(adapter_name_or_path=
{
self
.
adapter_name_or_path
}
"
)
@
classmethod
@
classmethod
def
get_name
(
self
)
->
str
:
def
get_name
(
self
)
->
str
:
...
@@ -49,16 +41,7 @@ class BitsAndBytesConfig(QuantizationConfig):
...
@@ -49,16 +41,7 @@ class BitsAndBytesConfig(QuantizationConfig):
@
classmethod
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"BitsAndBytesConfig"
:
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"BitsAndBytesConfig"
:
adapter_name
=
cls
.
get_from_keys
(
config
,
[
"adapter_name_or_path"
])
return
cls
()
default_target_modules
=
[
"gate_proj"
,
"down_proj"
,
"up_proj"
,
"q_proj"
,
"k_proj"
,
"v_proj"
,
"o_proj"
]
if
adapter_name
==
""
:
target_modules
=
default_target_modules
else
:
target_modules
=
cls
.
get_from_keys
(
config
,
[
"target_modules"
])
return
cls
(
adapter_name
,
target_modules
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"BitsAndBytesLinearMethod"
]:
prefix
:
str
)
->
Optional
[
"BitsAndBytesLinearMethod"
]:
...
...
vllm/model_executor/model_loader/loader.py
View file @
87525fab
...
@@ -702,8 +702,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -702,8 +702,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return
hf_weights_files
,
matched_pattern
==
"*.safetensors"
return
hf_weights_files
,
matched_pattern
==
"*.safetensors"
def
_hf_weight_iter
(
self
,
hf_weights_files
,
use_safetensors
:
bool
):
if
use_safetensors
:
return
safetensors_weights_iterator
(
hf_weights_files
)
else
:
return
pt_weights_iterator
(
hf_weights_files
)
def
_get_quantized_weights_iterator
(
def
_get_quantized_weights_iterator
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
]
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
]
,
pre_quant
:
bool
)
->
Tuple
[
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
],
Dict
[
str
,
)
->
Tuple
[
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
],
Dict
[
str
,
Any
]]:
Any
]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
"""Get an iterator to the model weights with bitsandbytes quantization,
...
@@ -712,6 +718,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -712,6 +718,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# only load the bitsandbytes module when needed
# only load the bitsandbytes module when needed
try
:
try
:
import
bitsandbytes
import
bitsandbytes
from
bitsandbytes.functional
import
QuantState
if
bitsandbytes
.
__version__
<
"0.42.0"
:
if
bitsandbytes
.
__version__
<
"0.42.0"
:
raise
ImportError
(
"bitsandbytes version is wrong. Please "
raise
ImportError
(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.42.0."
)
"install bitsandbytes>=0.42.0."
)
...
@@ -725,13 +732,59 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -725,13 +732,59 @@ class BitsAndBytesModelLoader(BaseModelLoader):
model_name_or_path
,
revision
)
model_name_or_path
,
revision
)
quant_state_dict
=
{}
quant_state_dict
=
{}
if
use_safetensors
:
weight_iterator
=
safetensors_weights_iterator
(
hf_weights_files
)
else
:
weight_iterator
=
pt_weights_iterator
(
hf_weights_files
)
def
generator
():
def
quantized_checkpoint
()
->
Generator
:
# First iterate over all quant state weights
weight_iterator
=
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
)
temp_state_dict
=
{}
for
weight_name
,
weight_tensor
in
weight_iterator
:
for
weight_name
,
weight_tensor
in
weight_iterator
:
if
weight_name
.
endswith
(
".weight"
):
continue
# TODO: only nf4 quantization is supported for now
if
weight_name
.
endswith
(
".quant_state.bitsandbytes__fp4"
):
raise
NotImplementedError
(
"Only bitsandbytes_nf4 quantization"
f
"is supported for now.
{
weight_name
}
is fp4 quantized"
)
temp_state_dict
[
weight_name
]
=
weight_tensor
# Closure to parse quant_state for each prequant weight
def
_parse_quant_state
(
param_name
:
str
,
temp_state_dict
:
Dict
)
->
QuantState
:
quant_state
=
{}
for
k
in
temp_state_dict
:
if
param_name
+
"."
in
k
:
quant_state
[
k
]
=
temp_state_dict
[
k
]
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__nf4 in CPU
quant_state
[
param_name
+
".quant_state.bitsandbytes__nf4"
]
=
quant_state
[
param_name
+
".quant_state.bitsandbytes__nf4"
].
cpu
().
data
return
QuantState
.
from_dict
(
quant_state
,
device
=
"cuda"
)
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
# Filter out all weights whose suffix is not ".weight"
if
not
weight_name
.
endswith
(
".weight"
):
continue
if
weight_name
+
".quant_state.bitsandbytes__nf4"
\
in
temp_state_dict
:
quant_state
=
_parse_quant_state
(
weight_name
,
temp_state_dict
)
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
quant_state_dict
[
weight_name
]
=
quant_state
yield
weight_name
.
replace
(
".weight"
,
".qweight"
),
weight_tensor
else
:
yield
weight_name
,
weight_tensor
def
generator
()
->
Generator
:
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
if
any
(
target_module
in
weight_name
if
any
(
target_module
in
weight_name
for
target_module
in
self
.
target_modules
):
for
target_module
in
self
.
target_modules
):
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
...
@@ -749,6 +802,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -749,6 +802,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
yield
weight_name
,
processed_weight
yield
weight_name
,
processed_weight
if
pre_quant
:
return
quantized_checkpoint
(),
quant_state_dict
return
generator
(),
quant_state_dict
return
generator
(),
quant_state_dict
def
_load_weights
(
self
,
model_config
:
ModelConfig
,
def
_load_weights
(
self
,
model_config
:
ModelConfig
,
...
@@ -766,12 +821,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -766,12 +821,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
" May take a while ..."
)
" May take a while ..."
)
qweight_iterator
,
quant_state_dict
=
(
is_quantized_checkpoint
=
False
self
.
_get_quantized_weights_iterator
(
model_config
.
model
,
quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
model_config
.
revision
))
None
)
if
quant_config
is
not
None
and
quant_config
.
get
(
'quant_method'
)
==
"bitsandbytes"
:
is_quantized_checkpoint
=
True
qweight_iterator
,
quant_state_dict
=
\
self
.
_get_quantized_weights_iterator
(
model_config
.
model
,
model_config
.
revision
,
is_quantized_checkpoint
)
model
.
load_weights
(
qweight_iterator
)
model
.
load_weights
(
qweight_iterator
)
torch
.
cuda
.
empty_cache
()
param_dict
=
dict
(
model
.
named_parameters
())
param_dict
=
dict
(
model
.
named_parameters
())
stacked_quant_state_dict
:
Dict
[
str
,
Dict
[
int
,
Any
]]
=
{}
stacked_quant_state_dict
:
Dict
[
str
,
Dict
[
int
,
Any
]]
=
{}
for
quant_param_name
in
quant_state_dict
:
for
quant_param_name
in
quant_state_dict
:
...
@@ -809,9 +873,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -809,9 +873,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
f
"pack_factor not set for parameter
{
param_name
}
."
)
f
"pack_factor not set for parameter
{
param_name
}
."
)
num_elements
=
[
0
]
*
len
(
quant_states
)
num_elements
=
[
0
]
*
len
(
quant_states
)
for
seq
,
quant_state
in
enumerate
(
quant_states
.
items
()
)
:
for
seq
,
quant_state
in
quant_states
.
items
():
num_elements
[
seq
]
=
math
.
prod
(
num_elements
[
seq
]
=
math
.
prod
(
quant_state
[
1
]
.
shape
)
//
pack_ratio
quant_state
.
shape
)
//
pack_ratio
offsets
=
np
.
concatenate
(([
0
],
np
.
cumsum
(
num_elements
)))
offsets
=
np
.
concatenate
(([
0
],
np
.
cumsum
(
num_elements
)))
set_weight_attrs
(
param
,
{
"bnb_shard_offsets"
:
offsets
})
set_weight_attrs
(
param
,
{
"bnb_shard_offsets"
:
offsets
})
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
87525fab
...
@@ -118,6 +118,7 @@ def convert_bin_to_safetensor_file(
...
@@ -118,6 +118,7 @@ def convert_bin_to_safetensor_file(
# TODO(woosuk): Move this to other place.
# TODO(woosuk): Move this to other place.
def
get_quant_config
(
model_config
:
ModelConfig
,
def
get_quant_config
(
model_config
:
ModelConfig
,
load_config
:
LoadConfig
)
->
QuantizationConfig
:
load_config
:
LoadConfig
)
->
QuantizationConfig
:
quant_cls
=
get_quantization_config
(
model_config
.
quantization
)
quant_cls
=
get_quantization_config
(
model_config
.
quantization
)
# Read the quantization config from the HF model config, if available.
# Read the quantization config from the HF model config, if available.
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
hf_quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment