Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4664ceaa
Unverified
Commit
4664ceaa
authored
Aug 29, 2024
by
chenqianfzh
Committed by
GitHub
Aug 29, 2024
Browse files
support bitsandbytes 8-bit and FP4 quantized models (#7445)
parent
257afc37
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
437 additions
and
191 deletions
+437
-191
tests/conftest.py
tests/conftest.py
+6
-0
tests/quantization/test_bitsandbytes.py
tests/quantization/test_bitsandbytes.py
+98
-68
vllm/config.py
vllm/config.py
+2
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+10
-8
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+195
-36
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+126
-79
No files found.
tests/conftest.py
View file @
4664ceaa
...
...
@@ -209,8 +209,14 @@ class HfRunner:
def
wrap_device
(
self
,
input
:
_T
)
->
_T
:
if
not
is_cpu
():
# Check if the input is already on the GPU
if
hasattr
(
input
,
'device'
)
and
input
.
device
.
type
==
"cuda"
:
return
input
# Already on GPU, no need to move
return
input
.
to
(
"cuda"
)
else
:
# Check if the input is already on the CPU
if
hasattr
(
input
,
'device'
)
and
input
.
device
.
type
==
"cpu"
:
return
input
# Already on CPU, no need to move
return
input
.
to
(
"cpu"
)
def
__init__
(
...
...
tests/quantization/test_bitsandbytes.py
View file @
4664ceaa
...
...
@@ -2,85 +2,115 @@
Run `pytest tests/quantization/test_bitsandbytes.py`.
'''
import
gc
import
pytest
import
torch
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
SamplingParams
models_to_test
=
[
models_
4bit_
to_test
=
[
(
'huggyllama/llama-7b'
,
'quantize model inflight'
),
(
'lllyasviel/omost-llama-3-8b-4bits'
,
'read pre-quantized model'
),
]
models_pre_qaunt_4bit_to_test
=
[
(
'lllyasviel/omost-llama-3-8b-4bits'
,
'read pre-quantized 4-bit NF4 model'
),
(
'PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed'
,
'read pre-quantized 4-bit FP4 model'
),
]
models_pre_quant_8bit_to_test
=
[
(
'meta-llama/Llama-Guard-3-8B-INT8'
,
'read pre-quantized 8-bit model'
),
]
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"bitsandbytes"
),
reason
=
'bitsandbytes is not supported on this GPU type.'
)
@
pytest
.
mark
.
parametrize
(
"model_name, description"
,
models_4bit_to_test
)
def
test_load_4bit_bnb_model
(
hf_runner
,
vllm_runner
,
example_prompts
,
model_name
,
description
)
->
None
:
hf_model_kwargs
=
{
"load_in_4bit"
:
True
}
validate_generated_texts
(
hf_runner
,
vllm_runner
,
example_prompts
[:
1
],
model_name
,
hf_model_kwargs
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"bitsandbytes"
),
reason
=
'bitsandbytes is not supported on this GPU type.'
)
@
pytest
.
mark
.
parametrize
(
"model_name, description"
,
models_pre_qaunt_4bit_to_test
)
def
test_load_pre_quant_4bit_bnb_model
(
hf_runner
,
vllm_runner
,
example_prompts
,
model_name
,
description
)
->
None
:
validate_generated_texts
(
hf_runner
,
vllm_runner
,
example_prompts
[:
1
],
model_name
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"bitsandbytes"
),
reason
=
'bitsandbytes is not supported on this GPU type.'
)
@
pytest
.
mark
.
parametrize
(
"model_name, description"
,
models_to_test
)
def
test_load_bnb_model
(
vllm_runner
,
model_name
,
description
)
->
None
:
@
pytest
.
mark
.
parametrize
(
"model_name, description"
,
models_pre_quant_8bit_to_test
)
def
test_load_8bit_bnb_model
(
hf_runner
,
vllm_runner
,
example_prompts
,
model_name
,
description
)
->
None
:
validate_generated_texts
(
hf_runner
,
vllm_runner
,
example_prompts
[:
1
],
model_name
)
def
log_generated_texts
(
prompts
,
outputs
,
runner_name
):
logged_texts
=
[]
for
i
,
(
_
,
generated_text
)
in
enumerate
(
outputs
):
log_entry
=
{
"prompt"
:
prompts
[
i
],
"runner_name"
:
runner_name
,
"generated_text"
:
generated_text
,
}
logged_texts
.
append
(
log_entry
)
return
logged_texts
def
validate_generated_texts
(
hf_runner
,
vllm_runner
,
prompts
,
model_name
,
hf_model_kwargs
=
None
):
if
hf_model_kwargs
is
None
:
hf_model_kwargs
=
{}
# Run with HF runner
with
hf_runner
(
model_name
,
model_kwargs
=
hf_model_kwargs
)
as
llm
:
hf_outputs
=
llm
.
generate_greedy
(
prompts
,
8
)
hf_logs
=
log_generated_texts
(
prompts
,
hf_outputs
,
"HfRunner"
)
# Clean up the GPU memory for the next test
torch
.
cuda
.
synchronize
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
#Run with vLLM runner
with
vllm_runner
(
model_name
,
quantization
=
'bitsandbytes'
,
load_format
=
'bitsandbytes'
,
enforce_eager
=
True
)
as
llm
:
model
=
llm
.
model
.
llm_engine
.
model_executor
.
driver_worker
.
model_runner
.
model
# noqa: E501
# check the weights in MLP & SelfAttention are quantized to torch.uint8
qweight
=
model
.
model
.
layers
[
0
].
mlp
.
gate_up_proj
.
qweight
assert
qweight
.
dtype
==
torch
.
uint8
,
(
f
'Expected gate_up_proj dtype torch.uint8 but got
{
qweight
.
dtype
}
'
)
qweight
=
model
.
model
.
layers
[
0
].
mlp
.
down_proj
.
qweight
assert
qweight
.
dtype
==
torch
.
uint8
,
(
f
'Expected down_proj dtype torch.uint8 but got
{
qweight
.
dtype
}
'
)
qweight
=
model
.
model
.
layers
[
0
].
self_attn
.
o_proj
.
qweight
assert
qweight
.
dtype
==
torch
.
uint8
,
(
f
'Expected o_proj dtype torch.uint8 but got
{
qweight
.
dtype
}
'
)
qweight
=
model
.
model
.
layers
[
0
].
self_attn
.
qkv_proj
.
qweight
assert
qweight
.
dtype
==
torch
.
uint8
,
(
f
'Expected qkv_proj dtype torch.uint8 but got
{
qweight
.
dtype
}
'
)
# some weights should not be quantized
weight
=
model
.
lm_head
.
weight
assert
weight
.
dtype
!=
torch
.
uint8
,
(
'lm_head weight dtype should not be torch.uint8'
)
weight
=
model
.
model
.
embed_tokens
.
weight
assert
weight
.
dtype
!=
torch
.
uint8
,
(
'embed_tokens weight dtype should not be torch.uint8'
)
weight
=
model
.
model
.
layers
[
0
].
input_layernorm
.
weight
assert
weight
.
dtype
!=
torch
.
uint8
,
(
'input_layernorm weight dtype should not be torch.uint8'
)
weight
=
model
.
model
.
layers
[
0
].
post_attention_layernorm
.
weight
assert
weight
.
dtype
!=
torch
.
uint8
,
(
'input_layernorm weight dtype should not be torch.uint8'
)
# check the output of the model is expected
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
logprobs
=
1
,
prompt_logprobs
=
1
,
max_tokens
=
8
)
prompts
=
[
'That which does not kill us'
,
'To be or not to be,'
]
expected_outputs
=
[
'That which does not kill us makes us stronger.'
,
'To be or not to be, that is the question.'
]
outputs
=
llm
.
generate
(
prompts
,
sampling_params
=
sampling_params
)
assert
len
(
outputs
)
==
len
(
prompts
)
for
index
in
range
(
len
(
outputs
)):
# compare the first line of the output
actual_output
=
outputs
[
index
][
1
][
0
].
split
(
'
\n
'
,
1
)[
0
]
expected_output
=
expected_outputs
[
index
].
split
(
'
\n
'
,
1
)[
0
]
assert
len
(
actual_output
)
>=
len
(
expected_output
),
(
f
'Actual
{
actual_output
}
should be larger than or equal to '
f
'expected
{
expected_output
}
'
)
actual_output
=
actual_output
[:
len
(
expected_output
)]
assert
actual_output
==
expected_output
,
(
f
'Expected:
{
expected_output
}
, but got:
{
actual_output
}
'
)
enforce_eager
=
True
,
gpu_memory_utilization
=
0.8
)
as
llm
:
vllm_outputs
=
llm
.
generate_greedy
(
prompts
,
8
)
vllm_logs
=
log_generated_texts
(
prompts
,
vllm_outputs
,
"VllmRunner"
)
# Clean up the GPU memory for the next test
torch
.
cuda
.
synchronize
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
# Compare the generated strings
for
hf_log
,
vllm_log
in
zip
(
hf_logs
,
vllm_logs
):
hf_str
=
hf_log
[
"generated_text"
]
vllm_str
=
vllm_log
[
"generated_text"
]
prompt
=
hf_log
[
"prompt"
]
assert
hf_str
==
vllm_str
,
(
f
"Model:
{
model_name
}
"
f
"Mismatch between HF and vLLM outputs:
\n
"
f
"Prompt:
{
prompt
}
\n
"
f
"HF Output: '
{
hf_str
}
'
\n
"
f
"vLLM Output: '
{
vllm_str
}
'"
)
vllm/config.py
View file @
4664ceaa
...
...
@@ -405,6 +405,8 @@ class ModelConfig:
raise
ValueError
(
"BitAndBytes quantization with TP or PP is not supported yet."
)
# Remove the constraint after the bitsandbytes issue is fixed:
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308
if
self
.
quantization
==
"bitsandbytes"
and
self
.
enforce_eager
is
False
:
logger
.
warning
(
"CUDA graph is not supported on BitAndBytes yet, "
"fallback to the eager mode."
)
...
...
vllm/model_executor/layers/linear.py
View file @
4664ceaa
...
...
@@ -36,9 +36,9 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return
shard_size
*
marlin_tile_size
,
shard_offset
*
marlin_tile_size
def
adjust_bitsandbytes_shard
(
param
:
Parameter
,
qkv_offsets
:
Dict
[
str
,
Tuple
[
int
,
int
]],
loaded_shard_id
:
str
)
->
Tuple
[
int
,
int
]:
def
adjust_bitsandbytes_
4bit_
shard
(
param
:
Parameter
,
qkv_offsets
:
Dict
[
str
,
Tuple
[
int
,
int
]],
loaded_shard_id
:
str
)
->
Tuple
[
int
,
int
]:
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
total
,
_
=
qkv_offsets
[
"total"
]
...
...
@@ -505,8 +505,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
use_bitsandbytes
=
getattr
(
param
,
"use_bitsandbytes"
,
False
)
if
use_bitsandbytes
:
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
if
use_bitsandbytes_4bit
:
shard_size
=
loaded_weight
.
shape
[
output_dim
]
shard_offset
=
loaded_weight
.
shape
[
output_dim
]
*
\
loaded_shard_id
...
...
@@ -858,8 +859,9 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
use_bitsandbytes
=
getattr
(
param
,
"use_bitsandbytes"
,
False
)
if
use_bitsandbytes
:
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
if
use_bitsandbytes_4bit
:
orig_qkv_offsets
=
{
"q"
:
(
0
,
self
.
num_heads
*
self
.
head_size
),
"k"
:
(
self
.
num_heads
*
self
.
head_size
,
...
...
@@ -871,7 +873,7 @@ class QKVParallelLinear(ColumnParallelLinear):
((
self
.
num_heads
+
2
*
self
.
num_kv_heads
)
*
self
.
head_size
,
0
)
}
shard_size
,
shard_offset
=
adjust_bitsandbytes_shard
(
shard_size
,
shard_offset
=
adjust_bitsandbytes_
4bit_
shard
(
param
,
orig_qkv_offsets
,
loaded_shard_id
)
if
is_gguf_weight
:
...
...
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
4664ceaa
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
...
...
@@ -15,8 +14,28 @@ class BitsAndBytesConfig(QuantizationConfig):
Reference: https://arxiv.org/abs/2305.14314
"""
def
__init__
(
self
,
)
->
None
:
pass
def
__init__
(
self
,
load_in_8bit
:
bool
=
False
,
load_in_4bit
:
bool
=
True
,
bnb_4bit_compute_dtype
:
str
=
"float32"
,
bnb_4bit_quant_type
:
str
=
"fp4"
,
bnb_4bit_use_double_quant
:
bool
=
False
,
llm_int8_enable_fp32_cpu_offload
:
bool
=
False
,
llm_int8_has_fp16_weight
:
bool
=
False
,
llm_int8_skip_modules
:
Optional
[
Any
]
=
None
,
llm_int8_threshold
:
float
=
0.0
,
)
->
None
:
self
.
load_in_8bit
=
load_in_8bit
self
.
load_in_4bit
=
load_in_4bit
self
.
bnb_4bit_compute_dtype
=
bnb_4bit_compute_dtype
self
.
bnb_4bit_quant_type
=
bnb_4bit_quant_type
self
.
bnb_4bit_use_double_quant
=
bnb_4bit_use_double_quant
self
.
llm_int8_enable_fp32_cpu_offload
=
llm_int8_enable_fp32_cpu_offload
self
.
llm_int8_has_fp16_weight
=
llm_int8_has_fp16_weight
self
.
llm_int8_skip_modules
=
llm_int8_skip_modules
self
.
llm_int8_threshold
=
llm_int8_threshold
def
__repr__
(
self
)
->
str
:
return
"BitsAndBytesConfig"
...
...
@@ -41,7 +60,46 @@ class BitsAndBytesConfig(QuantizationConfig):
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"BitsAndBytesConfig"
:
return
cls
()
def
get_safe_value
(
config
,
keys
,
default_value
=
None
):
try
:
value
=
cls
.
get_from_keys
(
config
,
keys
)
return
value
if
value
is
not
None
else
default_value
except
ValueError
:
return
default_value
load_in_8bit
=
get_safe_value
(
config
,
[
"load_in_8bit"
],
default_value
=
False
)
load_in_4bit
=
get_safe_value
(
config
,
[
"load_in_4bit"
],
default_value
=
True
)
bnb_4bit_compute_dtype
=
get_safe_value
(
config
,
[
"bnb_4bit_compute_dtype"
],
default_value
=
"float32"
)
bnb_4bit_quant_type
=
get_safe_value
(
config
,
[
"bnb_4bit_quant_type"
],
default_value
=
"fp4"
)
bnb_4bit_use_double_quant
=
get_safe_value
(
config
,
[
"bnb_4bit_use_double_quant"
],
default_value
=
False
)
llm_int8_enable_fp32_cpu_offload
=
get_safe_value
(
config
,
[
"llm_int8_enable_fp32_cpu_offload"
],
default_value
=
False
)
llm_int8_has_fp16_weight
=
get_safe_value
(
config
,
[
"llm_int8_has_fp16_weight"
],
default_value
=
False
)
llm_int8_skip_modules
=
get_safe_value
(
config
,
[
"llm_int8_skip_modules"
],
default_value
=
[])
llm_int8_threshold
=
get_safe_value
(
config
,
[
"llm_int8_threshold"
],
default_value
=
0.0
)
return
cls
(
load_in_8bit
=
load_in_8bit
,
load_in_4bit
=
load_in_4bit
,
bnb_4bit_compute_dtype
=
bnb_4bit_compute_dtype
,
bnb_4bit_quant_type
=
bnb_4bit_quant_type
,
bnb_4bit_use_double_quant
=
bnb_4bit_use_double_quant
,
llm_int8_enable_fp32_cpu_offload
=
llm_int8_enable_fp32_cpu_offload
,
llm_int8_has_fp16_weight
=
llm_int8_has_fp16_weight
,
llm_int8_skip_modules
=
llm_int8_skip_modules
,
llm_int8_threshold
=
llm_int8_threshold
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"BitsAndBytesLinearMethod"
]:
...
...
@@ -78,39 +136,58 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
quant_ratio
=
0
if
params_dtype
.
is_floating_point
:
quant_ratio
=
torch
.
finfo
(
params_dtype
).
bits
//
torch
.
iinfo
(
torch
.
uint8
).
bits
from
bitsandbytes.nn
import
Int8Params
def
calculate_quant_ratio
(
dtype
):
if
dtype
.
is_floating_point
:
return
torch
.
finfo
(
dtype
).
bits
//
torch
.
iinfo
(
torch
.
uint8
).
bits
else
:
return
torch
.
iinfo
(
dtype
).
bits
//
torch
.
iinfo
(
torch
.
uint8
).
bits
def
create_qweight_for_8bit
():
qweight
=
Int8Params
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
torch
.
int8
),
has_fp16_weights
=
self
.
quant_config
.
llm_int8_has_fp16_weight
,
requires_grad
=
False
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
0
,
"pack_factor"
:
1
,
"use_bitsandbytes_8bit"
:
True
,
"generation"
:
0
})
return
qweight
def
create_qweight_for_4bit
():
quant_ratio
=
calculate_quant_ratio
(
params_dtype
)
total_size
=
input_size_per_partition
*
sum
(
output_partition_sizes
)
if
total_size
%
quant_ratio
!=
0
:
raise
ValueError
(
"The input size is not aligned with the quantized "
"weight shape."
)
qweight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
total_size
//
quant_ratio
,
1
,
dtype
=
torch
.
uint8
),
requires_grad
=
False
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
0
,
"pack_factor"
:
quant_ratio
,
"use_bitsandbytes_4bit"
:
True
})
return
qweight
if
self
.
quant_config
.
load_in_8bit
:
qweight
=
create_qweight_for_8bit
()
else
:
quant_ratio
=
torch
.
iinfo
(
params_dtype
).
bits
//
torch
.
iinfo
(
torch
.
uint8
).
bits
if
input_size_per_partition
*
sum
(
output_partition_sizes
)
%
quant_ratio
!=
0
:
raise
ValueError
(
"The input size is not aligned with the quantized "
"weight shape. "
)
qweight
=
Parameter
(
torch
.
empty
(
input_size_per_partition
*
sum
(
output_partition_sizes
)
//
quant_ratio
,
1
,
dtype
=
torch
.
uint8
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
# In bitsandbytes, a tensor of shape [n,m] is quantized to
#[n*m/pack_ratio, 1],so the output_dim is 0
"output_dim"
:
0
,
"pack_factor"
:
quant_ratio
,
"use_bitsandbytes"
:
True
,
})
qweight
=
create_qweight_for_4bit
()
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
...
...
@@ -119,6 +196,88 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
self
.
quant_config
.
load_in_8bit
:
return
self
.
_apply_8bit_weight
(
layer
,
x
,
bias
)
else
:
return
self
.
_apply_4bit_weight
(
layer
,
x
,
bias
)
def
_apply_8bit_weight
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
# only load the bitsandbytes module when needed
from
bitsandbytes
import
MatmulLtState
,
matmul
original_type
=
x
.
dtype
bf_x
=
x
.
to
(
torch
.
bfloat16
)
qweight
=
layer
.
qweight
offsets
=
qweight
.
bnb_shard_offsets
quant_states
=
qweight
.
bnb_quant_state
matmul_states
=
qweight
.
matmul_state
generation
=
qweight
.
generation
out_dim_0
=
x
.
shape
[
0
]
out_dim_1
=
sum
(
[
quant_state
[
1
].
shape
[
0
]
for
quant_state
in
quant_states
.
items
()])
out
=
torch
.
empty
(
out_dim_0
,
out_dim_1
,
dtype
=
torch
.
float16
,
device
=
x
.
device
)
current_index
=
0
for
i
in
range
(
len
(
quant_states
)):
output_size
=
quant_states
[
i
].
shape
[
0
]
# in profile_run or the first generation of inference,
# create new matmul_states
if
generation
==
0
or
generation
==
1
:
matmul_states
[
i
]
=
MatmulLtState
()
matmul_states
[
i
].
CB
=
qweight
[
offsets
[
i
]:
offsets
[
i
+
1
]]
matmul_states
[
i
].
SCB
=
quant_states
[
i
]
matmul_states
[
i
].
threshold
=
(
self
.
quant_config
.
llm_int8_threshold
)
matmul_states
[
i
].
has_fp16_weights
=
(
self
.
quant_config
.
llm_int8_has_fp16_weight
)
matmul_states
[
i
].
is_training
=
False
if
matmul_states
[
i
].
threshold
>
0.0
and
not
matmul_states
[
i
].
has_fp16_weights
:
matmul_states
[
i
].
use_pool
=
True
new_x
=
bf_x
.
unsqueeze
(
0
)
out
[:,
current_index
:
current_index
+
output_size
]
=
matmul
(
new_x
,
qweight
[
offsets
[
i
]:
offsets
[
i
+
1
]],
state
=
matmul_states
[
i
])
current_index
+=
output_size
# only update the matmul_states if it is not profile_run
if
(
generation
>
0
and
not
self
.
quant_config
.
llm_int8_has_fp16_weight
and
matmul_states
[
i
].
CB
is
not
None
and
matmul_states
[
i
].
CxB
is
not
None
):
del
matmul_states
[
i
].
CB
qweight
[
offsets
[
i
]:
offsets
[
i
+
1
]]
=
matmul_states
[
i
].
CxB
out
=
out
.
to
(
original_type
)
if
bias
is
not
None
:
out
+=
bias
qweight
.
generation
+=
1
return
out
def
_apply_4bit_weight
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
# only load the bitsandbytes module when needed
from
bitsandbytes
import
matmul_4bit
...
...
vllm/model_executor/model_loader/loader.py
View file @
4664ceaa
...
...
@@ -771,7 +771,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return
pt_weights_iterator
(
hf_weights_files
)
def
_get_quantized_weights_iterator
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
],
pre_quant
:
bool
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
],
pre_quant
:
bool
,
load_8bit
:
bool
,
)
->
Tuple
[
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
],
Dict
[
str
,
Any
]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
...
...
@@ -780,11 +784,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# only load the bitsandbytes module when needed
try
:
import
bitsandbytes
from
bitsandbytes.functional
import
QuantState
if
bitsandbytes
.
__version__
<
"0.42.0"
:
raise
ImportError
(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.42.0."
)
from
bitsandbytes.functional
import
quantize_4bit
except
ImportError
as
err
:
raise
ImportError
(
"Please install bitsandbytes>=0.42.0 via "
"`pip install bitsandbytes>=0.42.0` to use "
...
...
@@ -793,80 +795,111 @@ class BitsAndBytesModelLoader(BaseModelLoader):
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
model_name_or_path
,
revision
)
quant_state_dict
=
{}
def
quantized_checkpoint
()
->
Generator
:
# First iterate over all quant state weights
weight_iterator
=
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
)
temp_state_dict
=
{}
for
weight_name
,
weight_tensor
in
weight_iterator
:
if
weight_name
.
endswith
(
".weight"
):
continue
# TODO: only nf4 quantization is supported for now
if
weight_name
.
endswith
(
".quant_state.bitsandbytes__fp4"
):
raise
NotImplementedError
(
"Only bitsandbytes_nf4 quantization"
f
"is supported for now.
{
weight_name
}
is fp4 quantized"
)
temp_state_dict
[
weight_name
]
=
weight_tensor
quant_state_dict
:
Dict
[
str
,
Any
]
=
{}
# Closure to parse quant_state for each prequant weight
def
_parse_quant_state
(
param_name
:
str
,
temp_state_dict
:
Dict
)
->
QuantState
:
quant_state
=
{}
for
k
in
temp_state_dict
:
if
param_name
+
"."
in
k
:
quant_state
[
k
]
=
temp_state_dict
[
k
]
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__nf4 in CPU
quant_state
[
param_name
+
".quant_state.bitsandbytes__nf4"
]
=
quant_state
[
param_name
+
".quant_state.bitsandbytes__nf4"
].
cpu
().
data
return
QuantState
.
from_dict
(
quant_state
,
device
=
"cuda"
)
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
# Filter out all weights whose suffix is not ".weight"
if
not
weight_name
.
endswith
(
".weight"
):
continue
if
weight_name
+
".quant_state.bitsandbytes__nf4"
\
in
temp_state_dict
:
quant_state
=
_parse_quant_state
(
weight_name
,
temp_state_dict
)
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
quant_state_dict
[
weight_name
]
=
quant_state
yield
weight_name
.
replace
(
".weight"
,
".qweight"
),
weight_tensor
else
:
yield
weight_name
,
weight_tensor
def
generator
()
->
Generator
:
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
if
any
(
target_module
in
weight_name
for
target_module
in
self
.
target_modules
):
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
# bitsandbytes requires data in GPU
loaded_weight
=
weight_tensor
.
cuda
().
data
with
set_default_torch_dtype
(
torch
.
float32
):
processed_weight
,
quant_state
=
quantize_4bit
(
loaded_weight
,
compress_statistics
=
True
,
quant_type
=
"nf4"
)
quant_state_dict
[
weight_name
]
=
quant_state
else
:
processed_weight
=
weight_tensor
if
pre_quant
:
if
load_8bit
:
return
self
.
_quantized_8bit_generator
(
hf_weights_files
,
use_safetensors
,
quant_state_dict
),
quant_state_dict
else
:
return
self
.
_quantized_4bit_generator
(
hf_weights_files
,
use_safetensors
,
quant_state_dict
),
quant_state_dict
yield
weight_name
,
processed_weight
return
self
.
_unquantized_generator
(
hf_weights_files
,
use_safetensors
,
quant_state_dict
),
quant_state_dict
if
pre_quant
:
return
quantized_checkpoint
(),
quant_state_dict
return
generator
(),
quant_state_dict
def
_quantized_8bit_generator
(
self
,
hf_weights_files
,
use_safetensors
,
quant_state_dict
)
->
Generator
:
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
if
not
weight_name
.
lower
().
endswith
(
".scb"
):
continue
weight_key
=
weight_name
.
lower
().
replace
(
".scb"
,
".qweight"
)
quant_state_dict
[
weight_key
]
=
weight_tensor
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
if
not
weight_name
.
endswith
(
".weight"
):
continue
qweight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
if
qweight_name
in
quant_state_dict
:
set_weight_attrs
(
weight_tensor
,
{
"load_in_8bit"
:
True
})
yield
qweight_name
,
weight_tensor
else
:
yield
weight_name
,
weight_tensor
def
_quantized_4bit_generator
(
self
,
hf_weights_files
,
use_safetensors
,
quant_state_dict
)
->
Generator
:
from
bitsandbytes.functional
import
QuantState
# First iterate over all quant state weights
weight_iterator
=
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
)
temp_state_dict
=
{}
for
weight_name
,
weight_tensor
in
weight_iterator
:
if
weight_name
.
endswith
(
".weight"
):
continue
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__* in CPU
if
"quant_state.bitsandbytes"
in
weight_name
:
temp_state_dict
[
weight_name
]
=
weight_tensor
.
cpu
().
data
else
:
temp_state_dict
[
weight_name
]
=
weight_tensor
# Closure to parse quant_state for each prequant weight
def
_parse_quant_state
(
param_name
:
str
,
temp_state_dict
:
Dict
)
->
QuantState
:
quant_state
=
{}
for
k
in
temp_state_dict
:
if
param_name
+
"."
in
k
:
quant_state
[
k
]
=
temp_state_dict
[
k
]
return
QuantState
.
from_dict
(
quant_state
,
device
=
"cuda"
)
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
# Filter out all weights whose suffix is not ".weight"
if
not
weight_name
.
endswith
(
".weight"
):
continue
if
(
f
"
{
weight_name
}
.quant_state.bitsandbytes__nf4"
\
in
temp_state_dict
)
or
\
(
f
"
{
weight_name
}
.quant_state.bitsandbytes__fp4"
\
in
temp_state_dict
):
quant_state
=
_parse_quant_state
(
weight_name
,
temp_state_dict
)
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
quant_state_dict
[
weight_name
]
=
quant_state
yield
weight_name
.
replace
(
".weight"
,
".qweight"
),
weight_tensor
else
:
yield
weight_name
,
weight_tensor
def
_unquantized_generator
(
self
,
hf_weights_files
,
use_safetensors
,
quant_state_dict
)
->
Generator
:
from
bitsandbytes.functional
import
quantize_4bit
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
if
any
(
target_module
in
weight_name
for
target_module
in
self
.
target_modules
):
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
# bitsandbytes requires data in GPU
loaded_weight
=
weight_tensor
.
cuda
().
data
with
set_default_torch_dtype
(
torch
.
float32
):
processed_weight
,
quant_state
=
quantize_4bit
(
loaded_weight
,
compress_statistics
=
True
,
quant_type
=
"nf4"
)
quant_state_dict
[
weight_name
]
=
quant_state
else
:
processed_weight
=
weight_tensor
yield
weight_name
,
processed_weight
def
_load_weights
(
self
,
model_config
:
ModelConfig
,
model
:
nn
.
Module
)
->
None
:
...
...
@@ -883,16 +916,26 @@ class BitsAndBytesModelLoader(BaseModelLoader):
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
" May take a while ..."
)
is_quantized_checkpoint
=
False
quant_config
=
getattr
(
model_config
.
hf_config
,
"quantization_config"
,
None
)
if
quant_config
is
not
None
and
quant_config
.
get
(
'quant_method'
)
==
"bitsandbytes"
:
is_quantized_checkpoint
=
True
pre_quant
=
False
if
quant_config
is
not
None
:
quant_method
=
quant_config
.
get
(
'quant_method'
)
if
quant_method
==
"bitsandbytes"
:
pre_quant
=
True
else
:
raise
ValueError
(
f
"BitsAndBytes loader does not support
{
quant_method
}
"
"quantization"
)
load_8bit
=
False
if
pre_quant
:
load_8bit
=
quant_config
.
get
(
'load_in_8bit'
,
False
)
qweight_iterator
,
quant_state_dict
=
\
self
.
_get_quantized_weights_iterator
(
model_config
.
model
,
model_config
.
revision
,
is
_quant
ized_checkpoin
t
)
model_config
.
model
,
model_config
.
revision
,
pre
_quant
,
load_8bi
t
)
model
.
load_weights
(
qweight_iterator
)
...
...
@@ -942,6 +985,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
offsets
=
np
.
concatenate
(([
0
],
np
.
cumsum
(
num_elements
)))
set_weight_attrs
(
param
,
{
"bnb_shard_offsets"
:
offsets
})
if
load_8bit
:
set_weight_attrs
(
param
,
{
"matmul_state"
:
[
None
]
*
len
(
quant_states
)})
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
lora_config
:
Optional
[
LoRAConfig
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment