Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
09500f7d
Unverified
Commit
09500f7d
authored
Oct 29, 2024
by
Isotr0py
Committed by
GitHub
Oct 29, 2024
Browse files
[Model] Add BNB quantization support for Mllama (#9720)
parent
ef7865b4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
84 additions
and
12 deletions
+84
-12
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+31
-4
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+16
-3
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+37
-5
No files found.
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
09500f7d
...
...
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
import
torch
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
...
...
@@ -23,7 +24,7 @@ class BitsAndBytesConfig(QuantizationConfig):
bnb_4bit_use_double_quant
:
bool
=
False
,
llm_int8_enable_fp32_cpu_offload
:
bool
=
False
,
llm_int8_has_fp16_weight
:
bool
=
False
,
llm_int8_skip_modules
:
Optional
[
Any
]
=
None
,
llm_int8_skip_modules
:
Optional
[
List
[
str
]
]
=
None
,
llm_int8_threshold
:
float
=
0.0
,
)
->
None
:
...
...
@@ -34,11 +35,15 @@ class BitsAndBytesConfig(QuantizationConfig):
self
.
bnb_4bit_use_double_quant
=
bnb_4bit_use_double_quant
self
.
llm_int8_enable_fp32_cpu_offload
=
llm_int8_enable_fp32_cpu_offload
self
.
llm_int8_has_fp16_weight
=
llm_int8_has_fp16_weight
self
.
llm_int8_skip_modules
=
llm_int8_skip_modules
self
.
llm_int8_skip_modules
=
llm_int8_skip_modules
or
[]
self
.
llm_int8_threshold
=
llm_int8_threshold
def
__repr__
(
self
)
->
str
:
return
"BitsAndBytesConfig"
return
(
f
"BitsAndBytesConfig(load_in_8bit=
{
self
.
load_in_8bit
}
, "
f
"load_in_4bit=
{
self
.
load_in_4bit
}
, "
f
"bnb_4bit_compute_dtype=
{
self
.
bnb_4bit_compute_dtype
}
, "
f
"bnb_4bit_quant_type=
{
self
.
bnb_4bit_quant_type
}
, "
f
"llm_int8_skip_modules=
{
self
.
llm_int8_skip_modules
}
)"
)
@
classmethod
def
get_name
(
self
)
->
str
:
...
...
@@ -102,8 +107,10 @@ class BitsAndBytesConfig(QuantizationConfig):
llm_int8_threshold
=
llm_int8_threshold
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"
BitsAndBytes
LinearMethod"
]:
prefix
:
str
)
->
Optional
[
"LinearMethod
Base
"
]:
if
isinstance
(
layer
,
LinearBase
):
if
is_layer_skipped_bnb
(
prefix
,
self
.
llm_int8_skip_modules
):
return
UnquantizedLinearMethod
()
return
BitsAndBytesLinearMethod
(
self
)
return
None
...
...
@@ -111,6 +118,10 @@ class BitsAndBytesConfig(QuantizationConfig):
return
[]
def
is_layer_skipped_bnb
(
prefix
:
str
,
llm_int8_skip_modules
:
List
[
str
]):
return
any
(
module_name
in
prefix
for
module_name
in
llm_int8_skip_modules
)
class
BitsAndBytesLinearMethod
(
LinearMethodBase
):
"""Linear method for BitsAndBytes.
...
...
@@ -211,6 +222,11 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
from
bitsandbytes
import
MatmulLtState
,
matmul
original_type
=
x
.
dtype
original_shape
=
x
.
shape
reshape_after_matmul
=
False
if
x
.
ndim
>
2
:
x
=
x
.
reshape
(
-
1
,
x
.
size
(
-
1
))
reshape_after_matmul
=
True
bf_x
=
x
.
to
(
torch
.
bfloat16
)
qweight
=
layer
.
qweight
...
...
@@ -265,6 +281,9 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
out
=
out
.
to
(
original_type
)
if
reshape_after_matmul
:
out
=
out
.
view
(
*
original_shape
[:
-
1
],
out
.
size
(
-
1
))
if
bias
is
not
None
:
out
+=
bias
...
...
@@ -282,6 +301,11 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
from
bitsandbytes
import
matmul_4bit
original_type
=
x
.
dtype
original_shape
=
x
.
shape
reshape_after_matmul
=
False
if
x
.
ndim
>
2
:
x
=
x
.
reshape
(
-
1
,
x
.
size
(
-
1
))
reshape_after_matmul
=
True
bf_x
=
x
.
to
(
torch
.
bfloat16
)
qweight
=
layer
.
qweight
...
...
@@ -310,6 +334,9 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
out
=
out
.
to
(
original_type
)
if
reshape_after_matmul
:
out
=
out
.
view
(
*
original_shape
[:
-
1
],
out
.
size
(
-
1
))
if
bias
is
not
None
:
out
+=
bias
...
...
vllm/model_executor/model_loader/loader.py
View file @
09500f7d
...
...
@@ -899,6 +899,19 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return
self
.
_unquantized_generator
(
hf_weights_files
,
use_safetensors
,
quant_state_dict
),
quant_state_dict
def
_is_8bit_weight_name
(
self
,
weight_name
:
str
):
quantized_suffix
=
{
".scb"
,
".weight_format"
}
return
any
(
weight_name
.
lower
().
endswith
(
suffix
)
for
suffix
in
quantized_suffix
)
def
_is_4bit_weight_name
(
self
,
weight_name
:
str
):
quantized_suffix
=
{
"absmax"
,
"quant_map"
,
"nested_absmax"
,
"nested_quant_map"
,
"bitsandbytes"
}
suffix
=
weight_name
.
split
(
"."
)[
-
1
]
return
any
(
q_suffix
in
suffix
for
q_suffix
in
quantized_suffix
)
def
_quantized_8bit_generator
(
self
,
hf_weights_files
,
use_safetensors
,
quant_state_dict
)
->
Generator
:
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
...
...
@@ -912,7 +925,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
if
not
weight_name
.
endswith
((
".weight"
,
".bias"
)
):
if
self
.
_is_8bit_weight_name
(
weight_name
):
continue
qweight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
...
...
@@ -932,7 +945,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
use_safetensors
)
temp_state_dict
=
{}
for
weight_name
,
weight_tensor
in
weight_iterator
:
if
weight_name
.
endswith
((
".weight"
,
".bias"
)
):
if
not
self
.
_is_4bit_weight_name
(
weight_name
):
continue
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__* in CPU
...
...
@@ -956,7 +969,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
if
not
weight_name
.
endswith
((
".weight"
,
".bias"
)
):
if
self
.
_is_4bit_weight_name
(
weight_name
):
continue
if
(
f
"
{
weight_name
}
.quant_state.bitsandbytes__nf4"
\
...
...
vllm/model_executor/models/mllama.py
View file @
09500f7d
...
...
@@ -325,7 +325,10 @@ class MllamaPrecomputedPositionEmbedding(nn.Module):
# TODO: support other attention backends for attention in vision model
class
MllamaVisionSdpaAttention
(
nn
.
Module
):
def
__init__
(
self
,
config
:
config_mllama
.
MllamaVisionConfig
):
def
__init__
(
self
,
config
:
config_mllama
.
MllamaVisionConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
model_parallel_size
=
get_tensor_model_parallel_world_size
()
...
...
@@ -341,12 +344,16 @@ class MllamaVisionSdpaAttention(nn.Module):
self
.
head_dim
,
self
.
num_heads
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
num_heads
*
self
.
head_dim
,
self
.
embed_dim
,
bias
=
False
,
input_is_parallel
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
def
forward
(
...
...
@@ -393,7 +400,8 @@ class MllamaVisionEncoderLayer(nn.Module):
self
.
is_gated
=
is_gated
self
.
intermediate_size
=
config
.
intermediate_size
self
.
self_attn
=
MllamaVisionSdpaAttention
(
config
)
self
.
self_attn
=
MllamaVisionSdpaAttention
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
)
self
.
mlp
=
CLIPMLP
(
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
...
...
@@ -1002,6 +1010,7 @@ class MllamaForCausalLM(nn.Module):
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.lm_head"
,
)
def
forward
(
...
...
@@ -1037,6 +1046,26 @@ class MllamaForCausalLM(nn.Module):
@
INPUT_REGISTRY
.
register_dummy_encoder_data
(
dummy_encoder_data_for_mllama
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_mllama
)
class
MllamaForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
# BitandBytes specific attributes
default_bitsandbytes_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules
=
[
".down_proj."
,
".o_proj."
]
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"k_proj"
:
(
"qkv_proj"
,
1
),
"v_proj"
:
(
"qkv_proj"
,
2
),
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
def
__init__
(
self
,
config
:
config_mllama
.
MllamaConfig
,
...
...
@@ -1061,10 +1090,13 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
quant_config
=
quant_config
,
prefix
=
"language_model"
,
)
self
.
multi_modal_projector
=
nn
.
Linear
(
self
.
multi_modal_projector
=
ColumnParallel
Linear
(
config
.
vision_config
.
vision_output_dim
,
config
.
text_config
.
hidden_size
,
bias
=
True
,
quant_config
=
quant_config
,
gather_output
=
True
,
prefix
=
"multi_modal_projector"
,
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
output_hidden_states
,
config
.
text_config
.
vocab_size
)
...
...
@@ -1128,7 +1160,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
raise
ValueError
(
"No images provided."
)
max_num_tiles
=
max
(
max
([
len
(
x
)
for
x
in
y
[
0
]])
for
y
in
pixel_values
)
device
=
self
.
multi_modal_projector
.
weight
.
device
device
=
next
(
self
.
multi_modal_projector
.
parameters
())
.
device
bsz
=
len
(
pixel_values
)
out_num_tiles
=
[]
out_images
=
torch
.
zeros
(
...
...
@@ -1204,7 +1236,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
cross_attention_states
=
self
.
vision_model
(
pixel_values
,
aspect_ratio_ids
,
aspect_ratio_mask
)
cross_attention_states
=
self
.
multi_modal_projector
(
cross_attention_states
,
_
=
self
.
multi_modal_projector
(
cross_attention_states
)
bsz
,
_
,
_
,
_
,
image_token_dim
=
tuple
(
cross_attention_states
.
shape
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment