Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8fe7fc86
Unverified
Commit
8fe7fc86
authored
Jun 30, 2025
by
Jee Jee Li
Committed by
GitHub
Jun 30, 2025
Browse files
[Quantization] Improve BitsAndBytesModelLoader (#20242)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
e936e401
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
72 additions
and
51 deletions
+72
-51
vllm/model_executor/model_loader/bitsandbytes_loader.py
vllm/model_executor/model_loader/bitsandbytes_loader.py
+72
-51
No files found.
vllm/model_executor/model_loader/bitsandbytes_loader.py
View file @
8fe7fc86
...
@@ -20,8 +20,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -20,8 +20,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
# yapf: enable
# yapf: enable
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
...
@@ -39,6 +37,8 @@ from vllm.model_executor.utils import (get_packed_modules_mapping,
...
@@ -39,6 +37,8 @@ from vllm.model_executor.utils import (get_packed_modules_mapping,
set_weight_attrs
)
set_weight_attrs
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
# yapf conflicts with isort for this block
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -54,11 +54,17 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -54,11 +54,17 @@ class BitsAndBytesModelLoader(BaseModelLoader):
self
.
unsharded_weights_modules
:
list
[
str
]
=
[]
self
.
unsharded_weights_modules
:
list
[
str
]
=
[]
# Save the module names that are sharded by column.
# Save the module names that are sharded by column.
self
.
column_sharded_weights_modules
:
list
[
str
]
=
[]
self
.
column_sharded_weights_modules
:
list
[
str
]
=
[]
# Modules whose weights might have fused on disk
# we need their output_sizes to make shard in flight correctly with TP
self
.
maybe_fused_weights_modules
:
dict
[
str
,
list
[
int
]]
=
{}
# Store all module names (from transformers) that support
# Store all module names (from transformers) that support
# BNB quantization.
# BNB quantization.
self
.
target_modules
:
list
[
str
]
=
[]
self
.
target_modules
:
list
[
str
]
=
[]
# mapping weight names from transformers to vllm.
# mapping weight names from transformers to vllm.
self
.
weight_mapper
:
Callable
=
lambda
name
:
name
self
.
weight_mapper
:
Callable
=
lambda
name
:
name
self
.
pre_quant
:
bool
=
False
self
.
load_8bit
:
bool
=
False
self
.
is_pool_model
:
bool
=
False
def
_get_weight_files
(
def
_get_weight_files
(
self
,
self
,
...
@@ -134,13 +140,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -134,13 +140,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return
hf_weights_files
,
use_safetensors
return
hf_weights_files
,
use_safetensors
def
_hf_weight_iter
(
self
,
hf_weights_files
,
use_safetensors
:
bool
):
def
_hf_weight_iter
(
self
,
hf_weights_files
,
use_safetensors
:
bool
):
def
_maybe_pool_model
(
module_name
:
str
):
def
_maybe_pool_model
(
module_name
:
str
):
# For pool model, we need to add the prefix `model.`
# For pool model, we need to add the prefix `model.`
# for the weight name if possible.
# for the weight name if possible.
if
self
.
is_pool_model
and
self
.
target_modules
[
0
].
\
if
self
.
is_pool_model
and
self
.
target_modules
[
0
].
\
startswith
(
"model."
)
and
not
module_name
.
startswith
(
startswith
(
"model."
)
and
not
module_name
.
startswith
(
"model."
):
"model."
):
return
"model."
+
module_name
return
"model."
+
module_name
return
module_name
return
module_name
...
@@ -159,8 +166,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -159,8 +166,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# mapping weight names from transformers to vllm while preserving
# mapping weight names from transformers to vllm while preserving
# original names.
# original names.
mapped_name
=
self
.
weight_mapper
(
org_name
)
mapped_name
=
self
.
weight_mapper
(
org_name
)
mapped_name
=
_maybe_pool_model
(
mapped_name
)
mapped_name
=
_maybe_pool_model
(
mapped_name
)
yield
org_name
,
mapped_name
,
param
yield
org_name
,
mapped_name
,
param
...
@@ -168,8 +174,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -168,8 +174,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
self
,
self
,
model_name_or_path
:
str
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
],
revision
:
Optional
[
str
],
pre_quant
:
bool
,
load_8bit
:
bool
,
)
->
tuple
[
Generator
[
tuple
[
str
,
torch
.
Tensor
],
None
,
None
],
dict
[
str
,
)
->
tuple
[
Generator
[
tuple
[
str
,
torch
.
Tensor
],
None
,
None
],
dict
[
str
,
Any
]]:
Any
]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
"""Get an iterator to the model weights with bitsandbytes quantization,
...
@@ -192,8 +196,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -192,8 +196,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
quant_state_dict
:
dict
[
str
,
Any
]
=
{}
quant_state_dict
:
dict
[
str
,
Any
]
=
{}
if
pre_quant
:
if
self
.
pre_quant
:
if
load_8bit
:
if
self
.
load_8bit
:
return
self
.
_quantized_8bit_generator
(
return
self
.
_quantized_8bit_generator
(
hf_weights_files
,
use_safetensors
,
hf_weights_files
,
use_safetensors
,
quant_state_dict
),
quant_state_dict
quant_state_dict
),
quant_state_dict
...
@@ -390,10 +394,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -390,10 +394,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
yield
org_weight_name
,
processed_weight
yield
org_weight_name
,
processed_weight
def
_get_bnb_target_modules
(
self
,
model
:
nn
.
Module
)
->
None
:
def
_get_bnb_target_modules
(
self
,
model
:
nn
.
Module
)
->
None
:
"""
Identify and collect all modules that support BitsAndBytes
quantization.
"""
for
name
,
module
in
model
.
named_modules
():
for
name
,
module
in
model
.
named_modules
():
if
(
isinstance
(
module
,
LinearBase
)
and
if
(
isinstance
(
module
,
LinearBase
)
hasattr
(
module
.
quant_method
,
"quant_config"
)):
and
hasattr
(
module
.
quant_method
,
"quant_config"
)):
if
modules_info
:
=
self
.
modules_mapping
.
get_sub_modules
(
name
):
if
modules_info
:
=
self
.
modules_mapping
.
get_sub_modules
(
name
):
# Map vllm's names to transformers's names.
# Map vllm's names to transformers's names.
rep_name
,
sub_modules
=
modules_info
rep_name
,
sub_modules
=
modules_info
...
@@ -409,29 +416,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -409,29 +416,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
),
"vllm currently does not support BNB quantization for"
),
"vllm currently does not support BNB quantization for"
f
"
{
type
(
model
).
__name__
}
"
f
"
{
type
(
model
).
__name__
}
"
def
load_weights
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
def
_classify_module_sharding
(
self
,
model
:
nn
.
Module
):
if
not
hasattr
(
model
,
"load_weights"
):
"""
raise
AttributeError
(
Categorize modules based on their weight sharding requirements
"The required method 'load_weights' is not defined in class"
for tensor parallelism.
f
"
{
type
(
model
).
__name__
}
."
)
"""
if
not
hasattr
(
model
,
"packed_modules_mapping"
):
raise
AttributeError
(
f
"Model
{
type
(
model
).
__name__
}
does not support BitsAndBytes "
"quantization yet. No 'packed_modules_mapping' found."
)
self
.
is_pool_model
=
is_pooling_model
(
model
)
self
.
modules_mapping
=
ParamMapping
(
get_packed_modules_mapping
(
model
))
# For some models like Molmo, we need to use hf_to_vllm_mapper
# to ensure correct loading of weights.
if
hf_to_vllm_mapper
:
=
getattr
(
model
,
"hf_to_vllm_mapper"
,
None
):
self
.
weight_mapper
=
lambda
name
:
hf_to_vllm_mapper
.
_map_name
(
name
)
# Modules whose weights might have fused on disk
# we need their output_sizes to make shard in flight correctly with TP
self
.
maybe_fused_weights_modules
:
dict
[
str
,
list
[
int
]]
=
{}
self
.
_get_bnb_target_modules
(
model
)
for
name
,
module
in
model
.
named_modules
():
for
name
,
module
in
model
.
named_modules
():
# Some modules like `ReplicatedLinear` should not have their weights
# Some modules like `ReplicatedLinear` should not have their weights
# sharded. The reason for implementing it this way is to avoid new
# sharded. The reason for implementing it this way is to avoid new
...
@@ -449,19 +438,27 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -449,19 +438,27 @@ class BitsAndBytesModelLoader(BaseModelLoader):
elif
isinstance
(
module
,
(
RowParallelLinear
,
)):
elif
isinstance
(
module
,
(
RowParallelLinear
,
)):
self
.
column_sharded_weights_modules
.
append
(
name
)
self
.
column_sharded_weights_modules
.
append
(
name
)
self
.
model_type
=
type
(
model
).
__name__
def
_verify_model_compatibility
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
"""
Verify that the model is compatible with BitsAndBytes quantization.
"""
if
not
hasattr
(
model
,
"load_weights"
):
raise
AttributeError
(
"The required method 'load_weights' is not defined in class"
f
"
{
type
(
model
).
__name__
}
."
)
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
if
not
hasattr
(
model
,
"packed_modules_mapping"
):
"May take a while ..."
)
raise
AttributeError
(
f
"Model
{
type
(
model
).
__name__
}
does not support BitsAndBytes "
"quantization yet. No 'packed_modules_mapping' found."
)
quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
None
)
None
)
pre_quant
=
False
if
quant_config
is
not
None
:
if
quant_config
is
not
None
:
quant_method
=
quant_config
.
get
(
"quant_method"
)
quant_method
=
quant_config
.
get
(
"quant_method"
)
if
quant_method
==
"bitsandbytes"
:
if
quant_method
==
"bitsandbytes"
:
pre_quant
=
True
self
.
pre_quant
=
True
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"BitsAndBytes loader does not support
{
quant_method
}
"
f
"BitsAndBytes loader does not support
{
quant_method
}
"
...
@@ -469,20 +466,43 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -469,20 +466,43 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# The quant_states in pre_quantized models cannot work with a split
# The quant_states in pre_quantized models cannot work with a split
# weight tensor. So TP does not work with pre_quantized bnb models.
# weight tensor. So TP does not work with pre_quantized bnb models.
if
pre_quant
and
get_tensor_model_parallel_world_size
()
>
1
:
if
self
.
pre_quant
and
get_tensor_model_parallel_world_size
()
>
1
:
raise
ValueError
(
raise
ValueError
(
"Prequant BitsAndBytes models with tensor parallelism is not "
"Prequant BitsAndBytes models with tensor parallelism is not "
"supported. Please try with pipeline parallelism."
)
"supported. Please try with pipeline parallelism."
)
if
self
.
pre_quant
:
self
.
load_8bit
=
quant_config
.
get
(
"load_in_8bit"
,
False
)
def
_initialize_loader_state
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
"""
Initialize the loader's internal state based on the model and
configuration.
"""
self
.
is_pool_model
=
is_pooling_model
(
model
)
self
.
modules_mapping
=
ParamMapping
(
get_packed_modules_mapping
(
model
))
load_8bit
=
False
# For some models like Molmo, we need to use hf_to_vllm_mapper
if
pre_quant
:
# to ensure correct loading of weights.
load_8bit
=
quant_config
.
get
(
"load_in_8bit"
,
False
)
if
hf_to_vllm_mapper
:
=
getattr
(
model
,
"hf_to_vllm_mapper"
,
None
):
self
.
weight_mapper
=
lambda
name
:
hf_to_vllm_mapper
.
_map_name
(
name
)
qweight_iterator
,
quant_state_dict
=
(
self
.
_get_bnb_target_modules
(
model
)
self
.
_get_quantized_weights_iterator
(
model_config
.
model
,
self
.
_classify_module_sharding
(
model
)
model_config
.
revision
,
pre_quant
,
load_8bit
))
def
load_weights
(
self
,
model
:
nn
.
Module
,
model_config
:
ModelConfig
)
->
None
:
self
.
_verify_model_compatibility
(
model
,
model_config
)
self
.
_initialize_loader_state
(
model
,
model_config
)
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
"May take a while ..."
)
qweight_iterator
,
quant_state_dict
=
(
self
.
_get_quantized_weights_iterator
(
model_config
.
model
,
model_config
.
revision
,
))
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
loaded_weights
=
model
.
load_weights
(
qweight_iterator
)
loaded_weights
=
model
.
load_weights
(
qweight_iterator
)
# Some models may have weights loading tracker unimplemented.
# Some models may have weights loading tracker unimplemented.
...
@@ -562,10 +582,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -562,10 +582,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
offsets
=
torch
.
tensor
(
offsets
).
cpu
()
offsets
=
torch
.
tensor
(
offsets
).
cpu
()
set_weight_attrs
(
param
,
{
"bnb_shard_offsets"
:
offsets
})
set_weight_attrs
(
param
,
{
"bnb_shard_offsets"
:
offsets
})
if
load_8bit
:
if
self
.
load_8bit
:
set_weight_attrs
(
set_weight_attrs
(
param
,
{
"matmul_state"
:
[
None
]
*
len
(
quant_states
)})
param
,
{
"matmul_state"
:
[
None
]
*
len
(
quant_states
)})
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment