Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96b23621
Unverified
Commit
96b23621
authored
Feb 04, 2025
by
Jee Jee Li
Committed by
GitHub
Feb 04, 2025
Browse files
[Misc] Add BNB quantization for Whisper (#12381)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
c36ac98d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
82 additions
and
44 deletions
+82
-44
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+60
-42
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+7
-0
vllm/model_executor/models/whisper.py
vllm/model_executor/models/whisper.py
+15
-2
No files found.
vllm/model_executor/model_loader/loader.py
View file @
96b23621
...
@@ -803,9 +803,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -803,9 +803,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
iterator
=
safetensors_weights_iterator
(
hf_weights_files
)
iterator
=
safetensors_weights_iterator
(
hf_weights_files
)
else
:
else
:
iterator
=
pt_weights_iterator
(
hf_weights_files
)
iterator
=
pt_weights_iterator
(
hf_weights_files
)
for
name
,
param
in
iterator
:
for
org_name
,
param
in
iterator
:
# mapping weight names from transformers to vllm.
# mapping weight names from transformers to vllm while preserving
yield
self
.
weight_mapper
(
name
),
param
# original names.
mapped_name
=
self
.
weight_mapper
(
org_name
)
yield
org_name
,
mapped_name
,
param
def
_get_quantized_weights_iterator
(
def
_get_quantized_weights_iterator
(
self
,
self
,
...
@@ -866,24 +868,30 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -866,24 +868,30 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def
_quantized_8bit_generator
(
self
,
hf_weights_files
,
use_safetensors
,
def
_quantized_8bit_generator
(
self
,
hf_weights_files
,
use_safetensors
,
quant_state_dict
)
->
Generator
:
quant_state_dict
)
->
Generator
:
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
for
(
hf_weights_files
,
use_safetensors
):
org_weight_name
,
if
not
weight_name
.
lower
().
endswith
(
".scb"
):
mapped_weight_name
,
weight_tensor
,
)
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
if
not
mapped_weight_name
.
lower
().
endswith
(
".scb"
):
continue
continue
weight_key
=
weight_name
.
lower
().
replace
(
".scb"
,
".weight"
)
weight_key
=
mapped_
weight_name
.
lower
().
replace
(
".scb"
,
".weight"
)
quant_state_dict
[
weight_key
]
=
weight_tensor
quant_state_dict
[
weight_key
]
=
weight_tensor
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
for
(
hf_weights_files
,
use_safetensors
):
org_weight_name
,
if
self
.
_is_8bit_weight_name
(
weight_name
):
mapped_weight_name
,
weight_tensor
,
)
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
if
self
.
_is_8bit_weight_name
(
mapped_weight_name
):
continue
continue
if
weight_name
in
quant_state_dict
:
if
mapped_
weight_name
in
quant_state_dict
:
set_weight_attrs
(
weight_tensor
,
{
"load_in_8bit"
:
True
})
set_weight_attrs
(
weight_tensor
,
{
"load_in_8bit"
:
True
})
yield
weight_name
,
weight_tensor
yield
org_
weight_name
,
weight_tensor
else
:
else
:
yield
weight_name
,
weight_tensor
yield
org_
weight_name
,
weight_tensor
def
_quantized_4bit_generator
(
self
,
hf_weights_files
,
use_safetensors
,
def
_quantized_4bit_generator
(
self
,
hf_weights_files
,
use_safetensors
,
quant_state_dict
)
->
Generator
:
quant_state_dict
)
->
Generator
:
...
@@ -893,15 +901,19 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -893,15 +901,19 @@ class BitsAndBytesModelLoader(BaseModelLoader):
weight_iterator
=
self
.
_hf_weight_iter
(
hf_weights_files
,
weight_iterator
=
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
)
use_safetensors
)
temp_state_dict
=
{}
temp_state_dict
=
{}
for
weight_name
,
weight_tensor
in
weight_iterator
:
for
(
if
not
self
.
_is_4bit_weight_name
(
weight_name
):
org_weight_name
,
mapped_weight_name
,
weight_tensor
,
)
in
weight_iterator
:
if
not
self
.
_is_4bit_weight_name
(
mapped_weight_name
):
continue
continue
# bitsandbytes library requires
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__* in CPU
# weight.quant_state.bitsandbytes__* in CPU
if
"quant_state.bitsandbytes"
in
weight_name
:
if
"quant_state.bitsandbytes"
in
mapped_
weight_name
:
temp_state_dict
[
weight_name
]
=
weight_tensor
.
cpu
().
data
temp_state_dict
[
mapped_
weight_name
]
=
weight_tensor
.
cpu
().
data
else
:
else
:
temp_state_dict
[
weight_name
]
=
weight_tensor
temp_state_dict
[
mapped_
weight_name
]
=
weight_tensor
# Closure to parse quant_state for each prequant weight
# Closure to parse quant_state for each prequant weight
def
_parse_quant_state
(
param_name
:
str
,
def
_parse_quant_state
(
param_name
:
str
,
...
@@ -915,20 +927,24 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -915,20 +927,24 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# Second iterate over all prequant and normal weights
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
# pre quantized weights would have a quant_state
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
for
(
hf_weights_files
,
use_safetensors
):
org_weight_name
,
if
self
.
_is_4bit_weight_name
(
weight_name
):
mapped_weight_name
,
weight_tensor
,
)
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
if
self
.
_is_4bit_weight_name
(
mapped_weight_name
):
continue
continue
if
(
f
"
{
weight_name
}
.quant_state.bitsandbytes__nf4"
if
(
f
"
{
mapped_
weight_name
}
.quant_state.bitsandbytes__nf4"
in
temp_state_dict
)
or
(
in
temp_state_dict
)
or
(
f
"
{
weight_name
}
.quant_state.bitsandbytes__fp4"
f
"
{
mapped_
weight_name
}
.quant_state.bitsandbytes__fp4"
in
temp_state_dict
):
in
temp_state_dict
):
quant_state
=
_parse_quant_state
(
weight_name
,
temp_state_dict
)
quant_state
=
_parse_quant_state
(
mapped_weight_name
,
quant_state_dict
[
weight_name
]
=
quant_state
temp_state_dict
)
yield
weight_name
,
weight_tensor
quant_state_dict
[
mapped_weight_name
]
=
quant_state
yield
org_weight_name
,
weight_tensor
else
:
else
:
yield
weight_name
,
weight_tensor
yield
org_
weight_name
,
weight_tensor
def
_unquantized_generator
(
self
,
hf_weights_files
,
use_safetensors
,
def
_unquantized_generator
(
self
,
hf_weights_files
,
use_safetensors
,
quant_state_dict
)
->
Generator
:
quant_state_dict
)
->
Generator
:
...
@@ -937,18 +953,22 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -937,18 +953,22 @@ class BitsAndBytesModelLoader(BaseModelLoader):
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
for
(
hf_weights_files
,
use_safetensors
):
org_weight_name
,
if
any
(
target_module
in
weight_name
for
target_module
in
mapped_weight_name
,
self
.
target_modules
)
and
weight_name
.
endswith
(
".weight"
):
weight_tensor
,
)
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
if
any
(
target_module
in
mapped_weight_name
for
target_module
in
self
.
target_modules
)
and
mapped_weight_name
.
endswith
(
".weight"
):
# Without sharding
# Without sharding
if
any
(
if
any
(
weight_name
.
startswith
(
module
)
mapped_
weight_name
.
startswith
(
module
)
for
module
in
self
.
unsharded_weights_modules
):
for
module
in
self
.
unsharded_weights_modules
):
weight_sub_tensor
=
weight_tensor
weight_sub_tensor
=
weight_tensor
# Shard by column
# Shard by column
elif
any
(
elif
any
(
weight_name
.
startswith
(
module
)
mapped_
weight_name
.
startswith
(
module
)
for
module
in
self
.
column_sharded_weights_modules
):
for
module
in
self
.
column_sharded_weights_modules
):
total_size
=
weight_tensor
.
size
(
-
1
)
total_size
=
weight_tensor
.
size
(
-
1
)
start_index
=
total_size
//
tp_size
*
tp_rank
start_index
=
total_size
//
tp_size
*
tp_rank
...
@@ -958,14 +978,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -958,14 +978,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# Weights are fused on disk. In this case, we assume that the
# Weights are fused on disk. In this case, we assume that the
# weight and module use the same name.
# weight and module use the same name.
elif
any
(
elif
any
(
weight_name
.
startswith
(
module
)
mapped_
weight_name
.
startswith
(
module
)
for
module
in
self
.
maybe_fused_weights_modules
):
for
module
in
self
.
maybe_fused_weights_modules
):
# special case for fused weights
# special case for fused weights
# get the size of each shard weight tensor
# get the size of each shard weight tensor
total_shard_sizes
=
next
(
total_shard_sizes
=
next
(
(
sizes
for
module
,
sizes
in
(
sizes
for
module
,
sizes
in
self
.
maybe_fused_weights_modules
.
items
()
self
.
maybe_fused_weights_modules
.
items
()
if
weight_name
.
startswith
(
module
)))
if
mapped_
weight_name
.
startswith
(
module
)))
total_size
=
weight_tensor
.
size
(
0
)
total_size
=
weight_tensor
.
size
(
0
)
assert
total_size
==
sum
(
total_shard_sizes
)
assert
total_size
==
sum
(
total_shard_sizes
)
# get the start/end index of each shard weight tensor
# get the start/end index of each shard weight tensor
...
@@ -1008,23 +1028,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -1008,23 +1028,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
quant_type
=
"nf4"
,
quant_type
=
"nf4"
,
)
)
quant_state_dict
[
weight_name
]
=
quant_state
quant_state_dict
[
mapped_
weight_name
]
=
quant_state
else
:
else
:
processed_weight
=
weight_tensor
processed_weight
=
weight_tensor
yield
org_weight_name
,
processed_weight
yield
weight_name
,
processed_weight
def
_get_bnb_target_modules
(
self
,
model
:
nn
.
Module
)
->
None
:
def
_get_bnb_target_modules
(
self
,
model
:
nn
.
Module
)
->
None
:
for
name
,
module
in
model
.
named_modules
():
for
name
,
module
in
model
.
named_modules
():
if
isinstance
(
module
,
(
LinearBase
,
)):
if
isinstance
(
module
,
(
LinearBase
,
)):
last_name
=
name
.
split
(
"."
)[
-
1
]
if
modules_info
:
=
self
.
modules_mapping
.
get_sub_modules
(
name
):
if
sub_modules
:
=
self
.
modules_mapping
.
packed_mapping
.
get
(
last_name
,
[]):
# Map vllm's names to transformers' names.
# Map vllm's names to transformers' names.
rep_name
,
sub_modules
=
modules_info
for
sub_name
in
sub_modules
:
for
sub_name
in
sub_modules
:
self
.
target_modules
.
append
(
self
.
target_modules
.
append
(
name
.
replace
(
last
_name
,
sub_name
))
name
.
replace
(
rep
_name
,
sub_name
))
# Add original module name even if the module has stacked map,
# Add original module name even if the module has stacked map,
# in case model has a mixture of disk-merged and disk-split
# in case model has a mixture of disk-merged and disk-split
# weights with same last name.
# weights with same last name.
...
...
vllm/model_executor/model_loader/utils.py
View file @
96b23621
...
@@ -131,3 +131,10 @@ class ParamMapping:
...
@@ -131,3 +131,10 @@ class ParamMapping:
packed_name
,
packed_name
,
index
,
index
,
)
)
def get_sub_modules(self,
                    module_name: str) -> Optional[Tuple[str, List[str]]]:
    """Find the packed-mapping entry that ``module_name`` ends with.

    Iterates ``self.packed_mapping`` in insertion order and returns the
    first entry whose key matches the tail of ``module_name``.

    Args:
        module_name: Dotted module name to look up.

    Returns:
        A ``(key, value)`` pair taken from ``self.packed_mapping`` for the
        first matching key, or ``None`` when no key matches.
    """
    for key, value in self.packed_mapping.items():
        # Anchor the suffix on a "." boundary (or require the whole name):
        # a bare ``endswith(key)`` would let "kv_proj" match a module whose
        # last component is "qkv_proj", making the result depend on the
        # order of the mapping.
        if module_name == key or module_name.endswith("." + key):
            return key, value
    return None
vllm/model_executor/models/whisper.py
View file @
96b23621
...
@@ -638,6 +638,19 @@ def input_mapper_for_whisper(
...
@@ -638,6 +638,19 @@ def input_mapper_for_whisper(
@
MULTIMODAL_REGISTRY
.
register_max_multimodal_tokens
(
@
MULTIMODAL_REGISTRY
.
register_max_multimodal_tokens
(
"audio"
,
get_max_whisper_audio_tokens
)
"audio"
,
get_max_whisper_audio_tokens
)
class
WhisperForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
class
WhisperForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
packed_modules_mapping
=
{
"self_attn.qkv_proj"
:
[
"self_attn.q_proj"
,
"self_attn.k_proj"
,
"self_attn.v_proj"
,
],
"encoder_attn.kv_proj"
:
[
"encoder_attn.k_proj"
,
"encoder_attn.v_proj"
],
}
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_substr
=
{
".fc1."
:
".mlp.fc1."
,
".fc2."
:
".mlp.fc2."
})
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
...
@@ -731,10 +744,10 @@ class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -731,10 +744,10 @@ class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
[
"proj_out."
])
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
[
"proj_out."
])
mapper
=
WeightsMapper
({
".fc1."
:
".mlp.fc1."
,
".fc2."
:
".mlp.fc2."
})
# add fake zeros bias for k_proj to state_dict
# add fake zeros bias for k_proj to state_dict
weights
=
_create_fake_bias_for_k_proj
(
weights
)
weights
=
_create_fake_bias_for_k_proj
(
weights
)
return
loader
.
load_weights
(
weights
,
mapper
=
mapper
)
return
loader
.
load_weights
(
weights
,
mapper
=
self
.
hf_to_vllm_
mapper
)
def
_create_fake_bias_for_k_proj
(
def
_create_fake_bias_for_k_proj
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment