Unverified Commit 1f1542af authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Misc]Add BNB quantization for PaliGemmaForConditionalGeneration (#12237)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 96912550
......@@ -136,6 +136,17 @@ class PaliGemmaMultiModalProjector(nn.Module):
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
......@@ -344,10 +344,16 @@ class SiglipMLP(nn.Module):
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
# For quantization, we require the hidden size to be a multiple of 64
quantizable = (config.hidden_size % 64 == 0
and config.intermediate_size % 64 == 0)
# Special handling for BNB quantization
if quant_config and quant_config.get_name() == "bitsandbytes":
quantizable = True
else:
# For other quantization, we require the hidden size to be a
# multiple of 64
quantizable = (
config.hidden_size % 64 == 0
and config.intermediate_size % 64 == 0
)
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment