Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6781af56
Unverified
Commit
6781af56
authored
May 20, 2025
by
Jee Jee Li
Committed by
GitHub
May 19, 2025
Browse files
[Quantization] Pool model support bitsandbytes (#18087)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
1b15df25
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
79 additions
and
3 deletions
+79
-3
tests/quantization/test_bitsandbytes.py
tests/quantization/test_bitsandbytes.py
+64
-2
vllm/model_executor/model_loader/bitsandbytes_loader.py
vllm/model_executor/model_loader/bitsandbytes_loader.py
+15
-1
No files found.
tests/quantization/test_bitsandbytes.py
View file @
6781af56
...
...
@@ -8,9 +8,11 @@ import gc
import
pytest
import
torch
from
transformers
import
BitsAndBytesConfig
from
tests.quantization.utils
import
is_quant_method_supported
from
..models.utils
import
check_embeddings_close
from
..utils
import
compare_two_settings
,
create_new_process_for_each_test
models_4bit_to_test
=
[
...
...
@@ -19,6 +21,10 @@ models_4bit_to_test = [
"quantize inflight model with both HF and Mistral format weights"
)
]
models_4bit_to_embedding_test
=
[
(
"intfloat/e5-mistral-7b-instruct"
,
"quantize embedding model inflight"
),
]
models_pre_qaunt_4bit_to_test
=
[
(
'PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed'
,
'read pre-quantized 4-bit FP4 model'
),
...
...
@@ -31,6 +37,12 @@ models_pre_quant_8bit_to_test = [
(
"yec019/fbopt-350m-8bit"
,
"read pre-quantized 8-bit opt model"
),
]
models_pre_quant_8bit_to_test
=
[
(
'meta-llama/Llama-Guard-3-8B-INT8'
,
'read pre-quantized llama 8-bit model'
),
(
"yec019/fbopt-350m-8bit"
,
"read pre-quantized 8-bit opt model"
),
]
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"bitsandbytes"
),
reason
=
'bitsandbytes is not supported on this GPU type.'
)
...
...
@@ -39,7 +51,8 @@ models_pre_quant_8bit_to_test = [
def
test_load_4bit_bnb_model
(
hf_runner
,
vllm_runner
,
example_prompts
,
model_name
,
description
)
->
None
:
hf_model_kwargs
=
{
"load_in_4bit"
:
True
}
hf_model_kwargs
=
dict
(
quantization_config
=
BitsAndBytesConfig
(
load_in_4bit
=
True
))
validate_generated_texts
(
hf_runner
,
vllm_runner
,
example_prompts
[:
1
],
model_name
,
False
,
hf_model_kwargs
)
...
...
@@ -77,7 +90,8 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
def
test_load_tp_4bit_bnb_model
(
hf_runner
,
vllm_runner
,
example_prompts
,
model_name
,
description
)
->
None
:
hf_model_kwargs
=
{
"load_in_4bit"
:
True
}
hf_model_kwargs
=
dict
(
quantization_config
=
BitsAndBytesConfig
(
load_in_4bit
=
True
))
validate_generated_texts
(
hf_runner
,
vllm_runner
,
example_prompts
[:
1
],
...
...
@@ -113,6 +127,54 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
compare_two_settings
(
model_name
,
common_args
,
pp_args
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"bitsandbytes"
),
reason
=
'bitsandbytes is not supported on this GPU type.'
)
@
pytest
.
mark
.
parametrize
(
"model_name, description"
,
models_4bit_to_embedding_test
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
create_new_process_for_each_test
()
def
test_4bit_bnb_embedding_model
(
model_name
,
description
,
hf_runner
,
vllm_runner
,
example_prompts
,
dtype
:
str
,
)
->
None
:
# The example_prompts has ending "\n", for example:
# "Write a short story about a robot that dreams for the first time.\n"
# sentence_transformers will strip the input texts, see:
# https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
# This makes the input_ids different between hf_model and vllm_model.
# So we need to strip the input texts to avoid test failing.
example_prompts
=
[
str
(
s
).
strip
()
for
s
in
example_prompts
]
# Inflight 4bit quantization
hf_model_kwargs
=
dict
(
quantization_config
=
BitsAndBytesConfig
(
load_in_4bit
=
True
))
with
hf_runner
(
model_name
,
dtype
=
dtype
,
model_kwargs
=
hf_model_kwargs
,
is_sentence_transformer
=
True
,
)
as
hf_model
:
hf_outputs
=
hf_model
.
encode
(
example_prompts
)
with
vllm_runner
(
model_name
,
task
=
"embed"
,
dtype
=
dtype
,
quantization
=
"bitsandbytes"
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
encode
(
example_prompts
)
check_embeddings_close
(
embeddings_0_lst
=
hf_outputs
,
embeddings_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
tol
=
5e-2
,
)
def
log_generated_texts
(
prompts
,
outputs
,
runner_name
):
logged_texts
=
[]
for
i
,
(
_
,
generated_text
)
in
enumerate
(
outputs
):
...
...
vllm/model_executor/model_loader/bitsandbytes_loader.py
View file @
6781af56
...
...
@@ -35,6 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf
,
download_weights_from_hf
,
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
pt_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.models
import
is_pooling_model
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
...
...
@@ -133,6 +134,16 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return
hf_weights_files
,
use_safetensors
def
_hf_weight_iter
(
self
,
hf_weights_files
,
use_safetensors
:
bool
):
def
_maybe_pool_model
(
module_name
:
str
):
# For pool model, we need to add the prefix `model.`
# for the weight name if possible.
if
self
.
is_pool_model
and
self
.
target_modules
[
0
].
\
startswith
(
"model."
)
and
not
module_name
.
startswith
(
"model."
):
return
"model."
+
module_name
return
module_name
if
use_safetensors
:
iterator
=
safetensors_weights_iterator
(
hf_weights_files
,
...
...
@@ -148,6 +159,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# mapping weight names from transformers to vllm while preserving
# original names.
mapped_name
=
self
.
weight_mapper
(
org_name
)
mapped_name
=
_maybe_pool_model
(
mapped_name
)
yield
org_name
,
mapped_name
,
param
def
_get_quantized_weights_iterator
(
...
...
@@ -405,7 +419,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
raise
AttributeError
(
f
"Model
{
type
(
model
).
__name__
}
does not support BitsAndBytes "
"quantization yet. No 'packed_modules_mapping' found."
)
self
.
is_pool_model
=
is_pooling_model
(
model
)
self
.
modules_mapping
=
ParamMapping
(
copy
.
deepcopy
(
model
.
packed_modules_mapping
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment