Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b9c64c0c
Unverified
Commit
b9c64c0c
authored
Nov 06, 2024
by
Jee Jee Li
Committed by
GitHub
Nov 05, 2024
Browse files
[Misc] Modify BNB parameter name (#9997)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
d2e80332
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
14 deletions
+11
-14
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+5
-4
vllm/model_executor/layers/resampler.py
vllm/model_executor/layers/resampler.py
+1
-1
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+5
-9
No files found.
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
b9c64c0c
...
...
@@ -203,8 +203,9 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
qweight
=
create_qweight_for_8bit
()
else
:
qweight
=
create_qweight_for_4bit
()
layer
.
register_parameter
(
"qweight"
,
qweight
)
# Enable parameters to have the same name as in the BNB
# checkpoint format.
layer
.
register_parameter
(
"weight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
def
apply
(
self
,
...
...
@@ -234,7 +235,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
reshape_after_matmul
=
True
bf_x
=
x
.
to
(
torch
.
bfloat16
)
qweight
=
layer
.
q
weight
qweight
=
layer
.
weight
offsets
=
qweight
.
bnb_shard_offsets
quant_states
=
qweight
.
bnb_quant_state
matmul_states
=
qweight
.
matmul_state
...
...
@@ -313,7 +314,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
reshape_after_matmul
=
True
bf_x
=
x
.
to
(
torch
.
bfloat16
)
qweight
=
layer
.
q
weight
qweight
=
layer
.
weight
quant_states
=
qweight
.
bnb_quant_state
offsets
=
qweight
.
bnb_shard_offsets
...
...
vllm/model_executor/layers/resampler.py
View file @
b9c64c0c
...
...
@@ -177,7 +177,7 @@ class BaseResampler(nn.Module):
embed_dim
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
prefix
)
prefix
=
f
"
{
prefix
}
.kv_proj"
)
else
:
# Maintain the same return value with ReplicatedLinear.forward
self
.
kv_proj
=
lambda
*
args
,
**
kwargs
:
(
# type: ignore # noqa
...
...
vllm/model_executor/model_loader/loader.py
View file @
b9c64c0c
...
...
@@ -892,7 +892,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if
not
weight_name
.
lower
().
endswith
(
".scb"
):
continue
weight_key
=
weight_name
.
lower
().
replace
(
".scb"
,
".
q
weight"
)
weight_key
=
weight_name
.
lower
().
replace
(
".scb"
,
".weight"
)
quant_state_dict
[
weight_key
]
=
weight_tensor
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
...
...
@@ -901,11 +901,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if
self
.
_is_8bit_weight_name
(
weight_name
):
continue
qweight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
if
qweight_name
in
quant_state_dict
:
if
weight_name
in
quant_state_dict
:
set_weight_attrs
(
weight_tensor
,
{
"load_in_8bit"
:
True
})
yield
q
weight_name
,
weight_tensor
yield
weight_name
,
weight_tensor
else
:
yield
weight_name
,
weight_tensor
...
...
@@ -950,9 +948,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
(
f
"
{
weight_name
}
.quant_state.bitsandbytes__fp4"
\
in
temp_state_dict
):
quant_state
=
_parse_quant_state
(
weight_name
,
temp_state_dict
)
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
quant_state_dict
[
weight_name
]
=
quant_state
yield
weight_name
.
replace
(
".weight"
,
".qweight"
)
,
weight_tensor
yield
weight_name
,
weight_tensor
else
:
yield
weight_name
,
weight_tensor
...
...
@@ -967,7 +964,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if
any
(
target_module
in
weight_name
for
target_module
in
self
.
target_modules
)
and
weight_name
.
endswith
(
".weight"
):
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
# Without sharding
if
any
(
weight_name
.
startswith
(
module
)
...
...
@@ -1093,7 +1089,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# Some models, such as MiniCPM V2.5/2.6, contain both
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
# from being incorrectly identified as being present in
# 'vpm.encoder.layers.0.self_attn.qkv_proj.
q
weight
# 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
if
shard_pos
>
0
and
quant_param_name
[
shard_pos
-
1
]
==
"."
:
shard_index
=
index
quant_param_name
=
quant_param_name
.
replace
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment