Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1f1542af
Unverified
Commit
1f1542af
authored
Jan 21, 2025
by
Jee Jee Li
Committed by
GitHub
Jan 21, 2025
Browse files
[Misc]Add BNB quantization for PaliGemmaForConditionalGeneration (#12237)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
96912550
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
5 deletions
+22
-5
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+12
-1
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+10
-4
No files found.
vllm/model_executor/models/paligemma.py
View file @
1f1542af
...
...
@@ -136,7 +136,18 @@ class PaliGemmaMultiModalProjector(nn.Module):
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_paligemma
)
class
PaliGemmaForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
...
...
vllm/model_executor/models/siglip.py
View file @
1f1542af
...
...
@@ -344,10 +344,16 @@ class SiglipMLP(nn.Module):
self
.
config
=
config
self
.
activation_fn
=
get_act_fn
(
config
.
hidden_act
)
# For quantization, we require the hidden size to be a multiple of 64
quantizable
=
(
config
.
hidden_size
%
64
==
0
and
config
.
intermediate_size
%
64
==
0
)
# Special handling for BNB quantization
if
quant_config
and
quant_config
.
get_name
()
==
"bitsandbytes"
:
quantizable
=
True
else
:
# For other quantization, we require the hidden size to be a
# multiple of 64
quantizable
=
(
config
.
hidden_size
%
64
==
0
and
config
.
intermediate_size
%
64
==
0
)
self
.
fc1
=
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
intermediate_size
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment