Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
49599150
Unverified
Commit
49599150
authored
Jun 19, 2025
by
Jee Jee Li
Committed by
GitHub
Jun 19, 2025
Browse files
[Quantization] Modify the logic of BNB double quantization (#19742)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
8d1e89d9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
3 deletions
+25
-3
vllm/model_executor/model_loader/bitsandbytes_loader.py
vllm/model_executor/model_loader/bitsandbytes_loader.py
+25
-3
No files found.
vllm/model_executor/model_loader/bitsandbytes_loader.py
View file @
49599150
...
@@ -492,8 +492,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -492,8 +492,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
raise
ValueError
(
"Following weights were not initialized from "
raise
ValueError
(
"Following weights were not initialized from "
f
"checkpoint:
{
weights_not_loaded
}
"
)
f
"checkpoint:
{
weights_not_loaded
}
"
)
torch
.
cuda
.
empty_cache
()
param_dict
=
dict
(
model
.
named_parameters
())
param_dict
=
dict
(
model
.
named_parameters
())
stacked_quant_state_dict
:
dict
[
str
,
dict
[
int
,
Any
]]
=
{}
stacked_quant_state_dict
:
dict
[
str
,
dict
[
int
,
Any
]]
=
{}
# TODO: Change this lazy import to normal import
# TODO: Change this lazy import to normal import
...
@@ -545,6 +543,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -545,6 +543,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for
param_name
,
param
in
param_dict
.
items
():
for
param_name
,
param
in
param_dict
.
items
():
if
param_name
in
stacked_quant_state_dict
:
if
param_name
in
stacked_quant_state_dict
:
quant_states
=
stacked_quant_state_dict
[
param_name
]
quant_states
=
stacked_quant_state_dict
[
param_name
]
# Dequantize double quantized values during weight loading.
dequantize_dq
(
quant_states
)
set_weight_attrs
(
param
,
{
"bnb_quant_state"
:
quant_states
})
set_weight_attrs
(
param
,
{
"bnb_quant_state"
:
quant_states
})
pack_ratio
=
getattr
(
param
,
"pack_factor"
,
-
1
)
pack_ratio
=
getattr
(
param
,
"pack_factor"
,
-
1
)
...
@@ -565,6 +565,28 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -565,6 +565,28 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if
load_8bit
:
if
load_8bit
:
set_weight_attrs
(
set_weight_attrs
(
param
,
{
"matmul_state"
:
[
None
]
*
len
(
quant_states
)})
param
,
{
"matmul_state"
:
[
None
]
*
len
(
quant_states
)})
torch
.
cuda
.
empty_cache
()
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
dequantize_dq
(
quant_states
:
dict
)
->
None
:
"""
When BNB employs Double Quantization, we perform the dequantization of
these constants during weight loading rather than at inference time,
thereby avoiding this computational overhead during inference. This comes
at the cost of increased memory usage.
"""
from
bitsandbytes.functional
import
dequantize_blockwise
for
_
,
quant_state
in
quant_states
.
items
():
# Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
if
quant_state
.
nested
:
absmax
=
dequantize_blockwise
(
quant_state
.
absmax
,
quant_state
.
state2
)
absmax
+=
quant_state
.
offset
if
absmax
.
dtype
!=
torch
.
float32
:
absmax
=
absmax
.
float
()
quant_state
.
absmax
=
absmax
quant_state
.
nested
=
False
quant_state
.
offset
=
None
quant_state
.
state2
=
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment