Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
726efc6a
Unverified
Commit
726efc6a
authored
Mar 28, 2025
by
Jee Jee Li
Committed by
GitHub
Mar 28, 2025
Browse files
[Quantization][V1] BitsAndBytes support V1 (#15611)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
bd45912b
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
52 additions
and
24 deletions
+52
-24
tests/models/encoder_decoder/vision_language/test_mllama.py
tests/models/encoder_decoder/vision_language/test_mllama.py
+0
-1
tests/models/test_transformers.py
tests/models/test_transformers.py
+0
-1
tests/quantization/test_bitsandbytes.py
tests/quantization/test_bitsandbytes.py
+0
-3
vllm/config.py
vllm/config.py
+4
-2
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-1
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+45
-16
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+2
-0
No files found.
tests/models/encoder_decoder/vision_language/test_mllama.py
View file @
726efc6a
...
...
@@ -425,7 +425,6 @@ def test_bnb_regression(
max_model_len
=
4096
,
max_num_seqs
=
2
,
quantization
=
"bitsandbytes"
,
load_format
=
"bitsandbytes"
,
)
sampling_params
=
SamplingParams
(
temperature
=
0
,
...
...
tests/models/test_transformers.py
View file @
726efc6a
...
...
@@ -72,7 +72,6 @@ def test_distributed(
"meta-llama/Llama-3.2-1B-Instruct"
,
{
"quantization"
:
"bitsandbytes"
,
"load_format"
:
"bitsandbytes"
,
},
),
])
...
...
tests/quantization/test_bitsandbytes.py
View file @
726efc6a
...
...
@@ -101,8 +101,6 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
"--enable-prefix-caching"
,
"--quantization"
,
"bitsandbytes"
,
"--load-format"
,
"bitsandbytes"
,
"--gpu-memory-utilization"
,
"0.7"
,
]
...
...
@@ -137,7 +135,6 @@ def validate_generated_texts(hf_runner,
# when using distributed inference
with
vllm_runner
(
model_name
,
quantization
=
'bitsandbytes'
,
load_format
=
'bitsandbytes'
,
tensor_parallel_size
=
vllm_tp_size
,
enforce_eager
=
False
)
as
llm
:
vllm_outputs
=
llm
.
generate_greedy
(
prompts
,
8
)
...
...
vllm/config.py
View file @
726efc6a
...
...
@@ -682,8 +682,9 @@ class ModelConfig:
def
_verify_bnb_config
(
self
)
->
None
:
"""
The current version of bitsandbytes (0.4
4.0
) with 8-bit models does not
The current version of bitsandbytes (0.4
5.3
) with 8-bit models does not
yet support CUDA graph.
# TODO Remove this when bitsandbytes supports.
"""
is_bitsandbytes
=
self
.
quantization
==
"bitsandbytes"
has_quantization_config
=
(
getattr
(
self
.
hf_config
,
...
...
@@ -698,8 +699,9 @@ class ModelConfig:
not
self
.
enforce_eager
,
]):
logger
.
warning
(
"CUDA graph is not supported on BitAndBytes 8bit yet, "
"CUDA graph is not supported on Bit
s
AndBytes 8bit yet, "
"fallback to the eager mode."
)
self
.
enforce_eager
=
True
def
_verify_with_expert_parallelism
(
self
)
->
None
:
...
...
vllm/engine/arg_utils.py
View file @
726efc6a
...
...
@@ -1616,7 +1616,7 @@ class EngineArgs:
return
False
# Some quantization is not compatible with torch.compile.
V1_UNSUPPORTED_QUANT
=
[
"bitsandbytes"
,
"gguf"
]
V1_UNSUPPORTED_QUANT
=
[
"gguf"
]
if
model_config
.
quantization
in
V1_UNSUPPORTED_QUANT
:
_raise_or_fallback
(
feature_name
=
f
"--quantization
{
model_config
.
quantization
}
"
,
...
...
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
726efc6a
...
...
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.utils
import
direct_register_custom_op
class
BitsAndBytesConfig
(
QuantizationConfig
):
...
...
@@ -321,9 +322,6 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
# only load the bitsandbytes module when needed
from
bitsandbytes
import
matmul_4bit
original_type
=
x
.
dtype
original_shape
=
x
.
shape
reshape_after_matmul
=
False
...
...
@@ -343,7 +341,27 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
out_dim_1
,
dtype
=
torch
.
bfloat16
,
device
=
x
.
device
)
apply_bnb_4bit
(
bf_x
,
qweight
,
offsets
,
out
)
out
=
out
.
to
(
original_type
)
if
reshape_after_matmul
:
out
=
out
.
view
(
*
original_shape
[:
-
1
],
out
.
size
(
-
1
))
if
bias
is
not
None
:
out
+=
bias
return
out
def
_apply_bnb_4bit
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
offsets
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
)
->
None
:
# only load the bitsandbytes module when needed
from
bitsandbytes
import
matmul_4bit
quant_states
=
weight
.
bnb_quant_state
current_index
=
0
for
i
in
range
(
len
(
quant_states
)):
output_size
=
quant_states
[
i
].
shape
[
0
]
...
...
@@ -352,16 +370,27 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
# Need to change after the bug is fixed.
out
[:,
current_index
:
current_index
+
output_size
]
=
matmul_4bit
(
bf_x
,
qweight
[
offsets
[
i
]:
offsets
[
i
+
1
]].
t
(),
quant_states
[
i
])
x
,
weight
[
offsets
[
i
]:
offsets
[
i
+
1
]].
t
(),
quant_states
[
i
])
current_index
+=
output_size
out
=
out
.
to
(
original_type
)
if
reshape_after_matmul
:
out
=
out
.
view
(
*
original_shape
[:
-
1
],
out
.
size
(
-
1
))
if
bias
is
not
None
:
out
+=
bias
return
out
def
_apply_bnb_4bit_fake
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
offsets
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
)
->
None
:
return
try
:
direct_register_custom_op
(
op_name
=
"apply_bnb_4bit"
,
op_func
=
_apply_bnb_4bit
,
mutates_args
=
[
"out"
],
fake_impl
=
_apply_bnb_4bit_fake
,
)
apply_bnb_4bit
=
torch
.
ops
.
vllm
.
apply_bnb_4bit
except
AttributeError
as
error
:
raise
error
vllm/model_executor/model_loader/loader.py
View file @
726efc6a
...
...
@@ -1259,6 +1259,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
pack_ratio
)
offsets
=
np
.
concatenate
(([
0
],
np
.
cumsum
(
num_elements
)))
# Make torch infer_schema happy
offsets
=
torch
.
tensor
(
offsets
).
cpu
()
set_weight_attrs
(
param
,
{
"bnb_shard_offsets"
:
offsets
})
if
load_8bit
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment