Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0240402c
Unverified
Commit
0240402c
authored
Dec 28, 2024
by
Jee Jee Li
Committed by
GitHub
Dec 27, 2024
Browse files
[Misc]Add BNB quantization for MolmoForCausalLM (#11551)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
55509c21
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
83 additions
and
33 deletions
+83
-33
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+18
-8
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+65
-25
No files found.
vllm/model_executor/model_loader/loader.py
View file @
0240402c
...
@@ -11,7 +11,8 @@ import os
...
@@ -11,7 +11,8 @@ import os
import
warnings
import
warnings
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Any
,
Dict
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
cast
from
typing
import
(
Any
,
Callable
,
Dict
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
cast
)
import
gguf
import
gguf
import
huggingface_hub
import
huggingface_hub
...
@@ -706,6 +707,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -706,6 +707,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# Store all module names (from transformers) that support
# Store all module names (from transformers) that support
# BNB quantization.
# BNB quantization.
self
.
target_modules
:
List
[
str
]
=
[]
self
.
target_modules
:
List
[
str
]
=
[]
# mapping weight names from transformers to vllm.
self
.
weight_mapper
:
Callable
=
lambda
name
:
name
def
_get_weight_files
(
def
_get_weight_files
(
self
,
self
,
...
@@ -763,9 +766,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -763,9 +766,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def
_hf_weight_iter
(
self
,
hf_weights_files
,
use_safetensors
:
bool
):
def
_hf_weight_iter
(
self
,
hf_weights_files
,
use_safetensors
:
bool
):
if
use_safetensors
:
if
use_safetensors
:
return
safetensors_weights_iterator
(
hf_weights_files
)
iterator
=
safetensors_weights_iterator
(
hf_weights_files
)
else
:
else
:
return
pt_weights_iterator
(
hf_weights_files
)
iterator
=
pt_weights_iterator
(
hf_weights_files
)
for
name
,
param
in
iterator
:
# mapping weight names from transformers to vllm.
yield
self
.
weight_mapper
(
name
),
param
def
_get_quantized_weights_iterator
(
def
_get_quantized_weights_iterator
(
self
,
self
,
...
@@ -782,12 +788,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -782,12 +788,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
try
:
try
:
import
bitsandbytes
import
bitsandbytes
if
bitsandbytes
.
__version__
<
"0.4
4
.0"
:
if
bitsandbytes
.
__version__
<
"0.4
5
.0"
:
raise
ImportError
(
"bitsandbytes version is wrong. Please "
raise
ImportError
(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.4
4
.0."
)
"install bitsandbytes>=0.4
5
.0."
)
except
ImportError
as
err
:
except
ImportError
as
err
:
raise
ImportError
(
"Please install bitsandbytes>=0.4
4
.0 via "
raise
ImportError
(
"Please install bitsandbytes>=0.4
5
.0 via "
"`pip install bitsandbytes>=0.4
4
.0` to use "
"`pip install bitsandbytes>=0.4
5
.0` to use "
"bitsandbytes quantizer."
)
from
err
"bitsandbytes quantizer."
)
from
err
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
...
@@ -991,7 +997,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -991,7 +997,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if
isinstance
(
module
,
(
LinearBase
,
)):
if
isinstance
(
module
,
(
LinearBase
,
)):
last_name
=
name
.
split
(
"."
)[
-
1
]
last_name
=
name
.
split
(
"."
)[
-
1
]
if
sub_modules
:
=
inverse_stacked_mapping
.
get
(
last_name
,
[]):
if
sub_modules
:
=
inverse_stacked_mapping
.
get
(
last_name
,
[]):
# Map vllm's names to transformers' names.
# Map vllm's names to transformers'
s
names.
for
sub_name
in
sub_modules
:
for
sub_name
in
sub_modules
:
self
.
target_modules
.
append
(
self
.
target_modules
.
append
(
name
.
replace
(
last_name
,
sub_name
))
name
.
replace
(
last_name
,
sub_name
))
...
@@ -1013,6 +1019,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -1013,6 +1019,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
f
"Model
{
type
(
model
).
__name__
}
does not support BitsAndBytes "
f
"Model
{
type
(
model
).
__name__
}
does not support BitsAndBytes "
"quantization yet."
)
"quantization yet."
)
# For some models like Molmo, we need to use hf_to_vllm_mapper
# to ensure correct loading of weights.
if
hf_to_vllm_mapper
:
=
getattr
(
model
,
"hf_to_vllm_mapper"
,
None
):
self
.
weight_mapper
=
lambda
name
:
hf_to_vllm_mapper
.
_map_name
(
name
)
# Modules whose weights might have fused on disk
# Modules whose weights might have fused on disk
# we need their output_sizes to make shard in flight correctly with TP
# we need their output_sizes to make shard in flight correctly with TP
self
.
maybe_fused_weights_modules
:
Dict
[
str
,
List
[
int
]]
=
{}
self
.
maybe_fused_weights_modules
:
Dict
[
str
,
List
[
int
]]
=
{}
...
...
vllm/model_executor/models/molmo.py
View file @
0240402c
...
@@ -461,30 +461,71 @@ class MolmoAttention(nn.Module):
...
@@ -461,30 +461,71 @@ class MolmoAttention(nn.Module):
return
output
return
output
class
MolmoMLP
(
nn
.
Module
):
class
SwiGLU
(
nn
.
Module
):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
,
gate
=
x
.
chunk
(
2
,
dim
=-
1
)
# Note that the order is reversed compared to
# SiluAndMul.
return
x
*
F
.
silu
(
gate
)
class
LanuageModelMLP
(
nn
.
Module
):
"""Molmo's LLM mlp."""
"""Molmo's LLM mlp."""
def
__init__
(
self
,
def
__init__
(
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
input_dim
:
Optional
[
int
]
=
None
,
input_dim
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
proj_name
:
str
=
"gate_up_proj"
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
self
.
intermediate_size
=
config
.
intermediate_size
//
2
self
.
intermediate_size
=
config
.
intermediate_size
//
2
# Molmo's LLM proj weights are already merged on disk, while
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
# image_projector proj is separate. If the same proj_name were used, it
input_dim
or
self
.
hidden_size
,
# would create ambiguity and make it difficult to support BNB and LoRA.
[
self
.
intermediate_size
]
*
2
,
self
.
proj_name
=
proj_name
bias
=
False
,
setattr
(
quant_config
=
quant_config
,
self
,
proj_name
,
)
MergedColumnParallelLinear
(
# Activation function.
input_dim
or
self
.
hidden_size
,
self
.
act_fn
=
SwiGLU
()
[
self
.
intermediate_size
]
*
2
,
# Feed-forward output projection.
bias
=
False
,
self
.
down_proj
=
RowParallelLinear
(
quant_config
=
quant_config
,
self
.
intermediate_size
,
))
self
.
hidden_size
,
bias
=
False
,
quant_config
=
quant_config
,
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
ImageProjectorMLP
(
nn
.
Module
):
"""Molmo's image_projector mlp."""
def
__init__
(
self
,
config
:
PretrainedConfig
,
input_dim
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
intermediate_size
=
config
.
intermediate_size
//
2
self
.
merged_linear
=
MergedColumnParallelLinear
(
input_dim
or
self
.
hidden_size
,
[
self
.
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
quant_config
,
)
# Activation function.
# Activation function.
self
.
act_fn
=
SiluAndMul
()
self
.
act_fn
=
SiluAndMul
()
...
@@ -500,7 +541,7 @@ class MolmoMLP(nn.Module):
...
@@ -500,7 +541,7 @@ class MolmoMLP(nn.Module):
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
gate_up
,
_
=
getattr
(
self
,
self
.
proj_name
)
(
x
)
gate_up
,
_
=
self
.
merged_linear
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
return
x
...
@@ -523,9 +564,7 @@ class MolmoDecoderLayer(nn.Module):
...
@@ -523,9 +564,7 @@ class MolmoDecoderLayer(nn.Module):
prefix
=
f
"
{
prefix
}
.self_attn"
)
prefix
=
f
"
{
prefix
}
.self_attn"
)
# MLP block.
# MLP block.
self
.
mlp
=
MolmoMLP
(
config
,
self
.
mlp
=
LanuageModelMLP
(
config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
proj_name
=
"gate_up_proj"
)
# LayerNorm
# LayerNorm
assert
config
.
layer_norm_type
==
"rms"
assert
config
.
layer_norm_type
==
"rms"
...
@@ -617,11 +656,10 @@ class MolmoVisionBackbone(nn.Module):
...
@@ -617,11 +656,10 @@ class MolmoVisionBackbone(nn.Module):
vision_config
,
vision_config
,
nlayers
=
len
(
self
.
vit_layers
),
nlayers
=
len
(
self
.
vit_layers
),
quant_config
=
quant_config
)
quant_config
=
quant_config
)
self
.
image_projector
=
Molmo
MLP
(
self
.
image_projector
=
ImageProjector
MLP
(
config
,
config
,
input_dim
=
vision_config
.
image_emb_dim
,
input_dim
=
vision_config
.
image_emb_dim
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
proj_name
=
"merged_linear"
,
)
)
image_dim
=
vision_config
.
image_emb_dim
*
len
(
self
.
vit_layers
)
image_dim
=
vision_config
.
image_emb_dim
*
len
(
self
.
vit_layers
)
...
@@ -842,10 +880,6 @@ class MolmoModel(nn.Module):
...
@@ -842,10 +880,6 @@ class MolmoModel(nn.Module):
loaded_params
:
Set
[
str
]
=
set
()
loaded_params
:
Set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"gate_up_proj"
in
name
:
up_proj
,
gate_proj
=
loaded_weight
.
chunk
(
2
,
dim
=
0
)
loaded_weight
=
torch
.
cat
([
gate_proj
,
up_proj
],
dim
=
0
)
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
...
@@ -1157,6 +1191,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -1157,6 +1191,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
},
},
)
)
# BitsAndBytes-specific attributes
bitsandbytes_stacked_params_mapping
=
{
"gate_proj"
:
(
"merged_linear"
,
0
),
"up_proj"
:
(
"merged_linear"
,
1
),
}
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment