Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
15cc2a9f
Unverified
Commit
15cc2a9f
authored
Nov 27, 2024
by
Jee Jee Li
Committed by
GitHub
Nov 26, 2024
Browse files
[Misc]Further reduce BNB static variable (#10597)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
e85250b1
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
131 additions
and
219 deletions
+131
-219
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+130
-88
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+0
-8
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+0
-6
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+0
-9
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+0
-9
vllm/model_executor/models/idefics3.py
vllm/model_executor/models/idefics3.py
+0
-15
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+0
-9
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+0
-34
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+0
-14
vllm/model_executor/models/opt.py
vllm/model_executor/models/opt.py
+0
-3
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+0
-3
vllm/model_executor/models/phi3.py
vllm/model_executor/models/phi3.py
+0
-6
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+1
-6
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+0
-9
No files found.
vllm/model_executor/model_loader/loader.py
View file @
15cc2a9f
...
@@ -28,7 +28,8 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -28,7 +28,8 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
...
@@ -78,12 +79,14 @@ def device_loading_context(module: torch.nn.Module,
...
@@ -78,12 +79,14 @@ def device_loading_context(module: torch.nn.Module,
original_device
:
torch
.
device
=
original_device_states
[
name
]
original_device
:
torch
.
device
=
original_device_states
[
name
]
if
original_device
.
type
==
"cpu"
:
if
original_device
.
type
==
"cpu"
:
# `torch.empty_like` does not support `pin_memory` argument
# `torch.empty_like` does not support `pin_memory` argument
cpu_data
=
torch
.
empty_strided
(
size
=
p
.
data
.
size
(),
cpu_data
=
torch
.
empty_strided
(
stride
=
p
.
data
.
stride
(),
size
=
p
.
data
.
size
(),
dtype
=
p
.
data
.
dtype
,
stride
=
p
.
data
.
stride
(),
layout
=
p
.
data
.
layout
,
dtype
=
p
.
data
.
dtype
,
device
=
"cpu"
,
layout
=
p
.
data
.
layout
,
pin_memory
=
pin_memory
)
device
=
"cpu"
,
pin_memory
=
pin_memory
,
)
cpu_data
.
copy_
(
p
.
data
)
cpu_data
.
copy_
(
p
.
data
)
p
.
data
=
cpu_data
p
.
data
=
cpu_data
else
:
else
:
...
@@ -112,7 +115,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
...
@@ -112,7 +115,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
logger
.
warning
(
msg
)
logger
.
warning
(
msg
)
logger
.
warning
(
logger
.
warning
(
"Trying to guess the arguments for old-style model class %s"
,
"Trying to guess the arguments for old-style model class %s"
,
model_class
)
model_class
,
)
# try to be compatible with old-style model class
# try to be compatible with old-style model class
kwargs
=
{}
kwargs
=
{}
if
"prefix"
in
all_params
:
if
"prefix"
in
all_params
:
...
@@ -198,14 +202,17 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -198,14 +202,17 @@ class DefaultModelLoader(BaseModelLoader):
return
model_path
return
model_path
return
None
return
None
def
_prepare_weights
(
self
,
model_name_or_path
:
str
,
def
_prepare_weights
(
revision
:
Optional
[
str
],
self
,
fall_back_to_pt
:
bool
)
->
Tuple
[
str
,
List
[
str
],
bool
]:
model_name_or_path
:
str
,
revision
:
Optional
[
str
],
fall_back_to_pt
:
bool
,
)
->
Tuple
[
str
,
List
[
str
],
bool
]:
"""Prepare weights for the model.
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
If the model is not local, it will be downloaded."""
model_name_or_path
=
self
.
_maybe_download_from_modelscope
(
model_name_or_path
=
(
self
.
_maybe_download_from_modelscope
(
model_name_or_path
,
revision
)
or
model_name_or_path
model_name_or_path
,
revision
)
or
model_name_or_path
)
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
load_format
=
self
.
load_config
.
load_format
load_format
=
self
.
load_config
.
load_format
...
@@ -258,8 +265,11 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -258,8 +265,11 @@ class DefaultModelLoader(BaseModelLoader):
# any files not found in the index.
# any files not found in the index.
if
not
is_local
:
if
not
is_local
:
download_safetensors_index_file_from_hf
(
download_safetensors_index_file_from_hf
(
model_name_or_path
,
index_file
,
model_name_or_path
,
self
.
load_config
.
download_dir
,
revision
)
index_file
,
self
.
load_config
.
download_dir
,
revision
,
)
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
,
hf_folder
,
index_file
)
hf_weights_files
,
hf_folder
,
index_file
)
else
:
else
:
...
@@ -282,8 +292,11 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -282,8 +292,11 @@ class DefaultModelLoader(BaseModelLoader):
# Currently np_cache only support *.bin checkpoints
# Currently np_cache only support *.bin checkpoints
assert
use_safetensors
is
False
assert
use_safetensors
is
False
weights_iterator
=
np_cache_weights_iterator
(
weights_iterator
=
np_cache_weights_iterator
(
source
.
model_or_path
,
self
.
load_config
.
download_dir
,
hf_folder
,
source
.
model_or_path
,
hf_weights_files
)
self
.
load_config
.
download_dir
,
hf_folder
,
hf_weights_files
,
)
elif
use_safetensors
:
elif
use_safetensors
:
weights_iterator
=
safetensors_weights_iterator
(
hf_weights_files
)
weights_iterator
=
safetensors_weights_iterator
(
hf_weights_files
)
else
:
else
:
...
@@ -310,17 +323,19 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -310,17 +323,19 @@ class DefaultModelLoader(BaseModelLoader):
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
model
:
nn
.
Module
,
model
:
nn
.
Module
,
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
primary_weights
=
DefaultModelLoader
.
Source
(
primary_weights
=
DefaultModelLoader
.
Source
(
model_config
.
model
,
model_config
.
model
,
model_config
.
revision
,
model_config
.
revision
,
prefix
=
""
,
prefix
=
""
,
fall_back_to_pt
=
getattr
(
model
,
"fall_back_to_pt_during_load"
,
fall_back_to_pt
=
getattr
(
model
,
"fall_back_to_pt_during_load"
,
True
))
True
),
)
yield
from
self
.
_get_weights_iterator
(
primary_weights
)
yield
from
self
.
_get_weights_iterator
(
primary_weights
)
secondary_weights
=
cast
(
Iterable
[
DefaultModelLoader
.
Source
],
secondary_weights
=
cast
(
getattr
(
model
,
"secondary_weights"
,
()))
Iterable
[
DefaultModelLoader
.
Source
],
getattr
(
model
,
"secondary_weights"
,
()),
)
for
source
in
secondary_weights
:
for
source
in
secondary_weights
:
yield
from
self
.
_get_weights_iterator
(
source
)
yield
from
self
.
_get_weights_iterator
(
source
)
...
@@ -416,7 +431,7 @@ class TensorizerLoader(BaseModelLoader):
...
@@ -416,7 +431,7 @@ class TensorizerLoader(BaseModelLoader):
self
.
tensorizer_config
.
verify_with_parallel_config
(
parallel_config
)
self
.
tensorizer_config
.
verify_with_parallel_config
(
parallel_config
)
def
_get_weights_iterator
(
def
_get_weights_iterator
(
self
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
self
,
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
tensorizer_args
=
self
.
tensorizer_config
.
_construct_tensorizer_args
()
tensorizer_args
=
self
.
tensorizer_config
.
_construct_tensorizer_args
()
return
tensorizer_weights_iterator
(
tensorizer_args
)
return
tensorizer_weights_iterator
(
tensorizer_args
)
...
@@ -479,9 +494,10 @@ class TensorizerLoader(BaseModelLoader):
...
@@ -479,9 +494,10 @@ class TensorizerLoader(BaseModelLoader):
if
parallel_config
.
tensor_parallel_size
>
1
:
if
parallel_config
.
tensor_parallel_size
>
1
:
from
vllm.distributed
import
get_tensor_model_parallel_rank
from
vllm.distributed
import
get_tensor_model_parallel_rank
self
.
tensorizer_config
.
tensorizer_uri
=
\
self
.
tensorizer_config
.
tensorizer_uri
\
self
.
tensorizer_config
.
tensorizer_uri
=
(
%
get_tensor_model_parallel_rank
()
self
.
tensorizer_config
.
tensorizer_uri
%
get_tensor_model_parallel_rank
())
if
is_vllm_tensorized
(
self
.
tensorizer_config
):
if
is_vllm_tensorized
(
self
.
tensorizer_config
):
return
self
.
_load_model_serialized
(
vllm_config
=
vllm_config
)
return
self
.
_load_model_serialized
(
vllm_config
=
vllm_config
)
...
@@ -520,13 +536,13 @@ class ShardedStateLoader(BaseModelLoader):
...
@@ -520,13 +536,13 @@ class ShardedStateLoader(BaseModelLoader):
@
staticmethod
@
staticmethod
def
_filter_subtensors
(
def
_filter_subtensors
(
tensors
:
Dict
[
str
,
torch
.
Tensor
])
->
Dict
[
str
,
torch
.
Tensor
]:
tensors
:
Dict
[
str
,
torch
.
Tensor
]
,
)
->
Dict
[
str
,
torch
.
Tensor
]:
"""
"""
Filter out all tensors that share the same memory or a subset of the
Filter out all tensors that share the same memory or a subset of the
memory of another tensor.
memory of another tensor.
"""
"""
same_storage_groups
:
Dict
[
Any
,
List
[
Tuple
[
same_storage_groups
:
Dict
[
Any
,
List
[
Tuple
[
str
,
torch
.
Tensor
]]]
=
(
str
,
torch
.
Tensor
]]]
=
collections
.
defaultdict
(
list
)
collections
.
defaultdict
(
list
)
)
for
key
,
tensor
in
tensors
.
items
():
for
key
,
tensor
in
tensors
.
items
():
if
tensor
.
numel
():
if
tensor
.
numel
():
ptr
=
tensor
.
untyped_storage
().
data_ptr
()
ptr
=
tensor
.
untyped_storage
().
data_ptr
()
...
@@ -615,8 +631,11 @@ class ShardedStateLoader(BaseModelLoader):
...
@@ -615,8 +631,11 @@ class ShardedStateLoader(BaseModelLoader):
if
tensor
.
shape
!=
param_shape
:
if
tensor
.
shape
!=
param_shape
:
logger
.
warning
(
logger
.
warning
(
"loading tensor of shape %s into "
"loading tensor of shape %s into "
"parameter '%s' of shape %s"
,
tensor
.
shape
,
"parameter '%s' of shape %s"
,
key
,
param_shape
)
tensor
.
shape
,
key
,
param_shape
,
)
param_data
.
copy_
(
tensor
)
param_data
.
copy_
(
tensor
)
state_dict
.
pop
(
key
)
state_dict
.
pop
(
key
)
if
state_dict
:
if
state_dict
:
...
@@ -634,6 +653,7 @@ class ShardedStateLoader(BaseModelLoader):
...
@@ -634,6 +653,7 @@ class ShardedStateLoader(BaseModelLoader):
from
safetensors.torch
import
save_file
from
safetensors.torch
import
save_file
from
vllm.distributed
import
get_tensor_model_parallel_rank
from
vllm.distributed
import
get_tensor_model_parallel_rank
if
pattern
is
None
:
if
pattern
is
None
:
pattern
=
ShardedStateLoader
.
DEFAULT_PATTERN
pattern
=
ShardedStateLoader
.
DEFAULT_PATTERN
rank
=
get_tensor_model_parallel_rank
()
rank
=
get_tensor_model_parallel_rank
()
...
@@ -667,24 +687,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -667,24 +687,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
possible_config_file_names
=
[
"adapter_config.json"
]
possible_config_file_names
=
[
"adapter_config.json"
]
default_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
'.fc1.'
,
'.fc2.'
,
'.dense.'
,
'.query_key_value.'
,
'.qkv_proj.'
,
'.dense_h_to_4h.'
,
'.dense_4h_to_h.'
,
'.out_proj.'
,
]
def
__init__
(
self
,
load_config
:
LoadConfig
):
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
super
().
__init__
(
load_config
)
...
@@ -709,6 +711,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -709,6 +711,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
with
open
(
config_file_path
)
as
f
:
with
open
(
config_file_path
)
as
f
:
config
=
json
.
load
(
f
)
config
=
json
.
load
(
f
)
self
.
target_modules
=
config
[
"target_modules"
]
self
.
target_modules
=
config
[
"target_modules"
]
# TODO: target_modules could be either a list or a regex string.
# We need to handle both cases.
assert
isinstance
(
self
.
target_modules
,
list
),
"Unsupported target_modules: "
f
"
{
self
.
target_modules
}
"
def
_get_config_file
(
self
,
qlora_adapter
:
str
)
->
str
:
def
_get_config_file
(
self
,
qlora_adapter
:
str
)
->
str
:
is_local
=
os
.
path
.
isdir
(
qlora_adapter
)
is_local
=
os
.
path
.
isdir
(
qlora_adapter
)
...
@@ -734,12 +741,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -734,12 +741,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return
config_file_path
return
config_file_path
def
_get_weight_files
(
def
_get_weight_files
(
self
,
self
,
model_name_or_path
:
str
,
model_name_or_path
:
str
,
allowed_patterns
:
List
[
str
],
allowed_patterns
:
List
[
str
],
revision
:
Optional
[
str
]
=
None
)
->
Tuple
[
List
[
str
],
str
]:
revision
:
Optional
[
str
]
=
None
,
"""Retrieve weight files. Download the files if necessary.
)
->
Tuple
[
List
[
str
],
str
]:
"""Retrieve weight files. Download the files if necessary.
Return the weight files and the file pattern."""
Return the weight files and the file pattern."""
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
...
@@ -806,6 +814,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -806,6 +814,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# only load the bitsandbytes module when needed
# only load the bitsandbytes module when needed
try
:
try
:
import
bitsandbytes
import
bitsandbytes
if
bitsandbytes
.
__version__
<
"0.44.0"
:
if
bitsandbytes
.
__version__
<
"0.44.0"
:
raise
ImportError
(
"bitsandbytes version is wrong. Please "
raise
ImportError
(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.44.0."
)
"install bitsandbytes>=0.44.0."
)
...
@@ -839,8 +848,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -839,8 +848,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def
_is_4bit_weight_name
(
self
,
weight_name
:
str
):
def
_is_4bit_weight_name
(
self
,
weight_name
:
str
):
quantized_suffix
=
{
quantized_suffix
=
{
"absmax"
,
"quant_map"
,
"nested_absmax"
,
"nested_quant_map"
,
"absmax"
,
"bitsandbytes"
"quant_map"
,
"nested_absmax"
,
"nested_quant_map"
,
"bitsandbytes"
,
}
}
suffix
=
weight_name
.
split
(
"."
)[
-
1
]
suffix
=
weight_name
.
split
(
"."
)[
-
1
]
return
any
(
q_suffix
in
suffix
for
q_suffix
in
quantized_suffix
)
return
any
(
q_suffix
in
suffix
for
q_suffix
in
quantized_suffix
)
...
@@ -857,7 +869,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -857,7 +869,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
hf_weights_files
,
use_safetensors
):
if
self
.
_is_8bit_weight_name
(
weight_name
):
if
self
.
_is_8bit_weight_name
(
weight_name
):
continue
continue
...
@@ -899,14 +910,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -899,14 +910,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# pre quantized weights would have a quant_state
# pre quantized weights would have a quant_state
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
hf_weights_files
,
use_safetensors
):
if
self
.
_is_4bit_weight_name
(
weight_name
):
if
self
.
_is_4bit_weight_name
(
weight_name
):
continue
continue
if
(
f
"
{
weight_name
}
.quant_state.bitsandbytes__nf4"
\
if
(
f
"
{
weight_name
}
.quant_state.bitsandbytes__nf4"
in
temp_state_dict
)
or
\
in
temp_state_dict
)
or
(
(
f
"
{
weight_name
}
.quant_state.bitsandbytes__fp4"
\
f
"
{
weight_name
}
.quant_state.bitsandbytes__fp4"
in
temp_state_dict
):
in
temp_state_dict
):
quant_state
=
_parse_quant_state
(
weight_name
,
temp_state_dict
)
quant_state
=
_parse_quant_state
(
weight_name
,
temp_state_dict
)
quant_state_dict
[
weight_name
]
=
quant_state
quant_state_dict
[
weight_name
]
=
quant_state
yield
weight_name
,
weight_tensor
yield
weight_name
,
weight_tensor
...
@@ -916,12 +926,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -916,12 +926,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def
_unquantized_generator
(
self
,
hf_weights_files
,
use_safetensors
,
def
_unquantized_generator
(
self
,
hf_weights_files
,
use_safetensors
,
quant_state_dict
)
->
Generator
:
quant_state_dict
)
->
Generator
:
from
bitsandbytes.functional
import
quantize_4bit
from
bitsandbytes.functional
import
quantize_4bit
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
hf_weights_files
,
use_safetensors
):
if
any
(
target_module
in
weight_name
for
target_module
in
if
any
(
target_module
in
weight_name
for
target_module
in
self
.
target_modules
)
and
weight_name
.
endswith
(
".weight"
):
self
.
target_modules
)
and
weight_name
.
endswith
(
".weight"
):
# Without sharding
# Without sharding
...
@@ -954,12 +964,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -954,12 +964,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# get the start/end index of each shard weight tensor
# get the start/end index of each shard weight tensor
total_start_index
=
list
(
total_start_index
=
list
(
itertools
.
accumulate
([
0
]
+
total_shard_sizes
))[:
-
1
]
itertools
.
accumulate
([
0
]
+
total_shard_sizes
))[:
-
1
]
shard_weights_index
=
[
shard_weights_index
=
[(
(
idx
+
size
//
tp_size
*
tp_rank
,
idx
+
size
//
tp_size
*
tp_rank
,
idx
+
size
//
tp_size
*
(
tp_rank
+
1
))
idx
+
size
//
tp_size
*
(
tp_rank
+
1
),
for
idx
,
size
in
zip
(
total_start_index
,
)
for
idx
,
size
in
zip
(
total_start_index
,
total_shard_sizes
)
total_shard_sizes
)]
]
# slice and reorder the weight tensor
# slice and reorder the weight tensor
weight_tensor
=
[
weight_tensor
=
[
weight_tensor
[
start_index
:
end_index
,
...]
weight_tensor
[
start_index
:
end_index
,
...]
...
@@ -989,7 +998,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -989,7 +998,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
processed_weight
,
quant_state
=
quantize_4bit
(
processed_weight
,
quant_state
=
quantize_4bit
(
loaded_weight
,
loaded_weight
,
compress_statistics
=
True
,
compress_statistics
=
True
,
quant_type
=
"nf4"
)
quant_type
=
"nf4"
,
)
quant_state_dict
[
weight_name
]
=
quant_state
quant_state_dict
[
weight_name
]
=
quant_state
else
:
else
:
...
@@ -997,28 +1007,58 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -997,28 +1007,58 @@ class BitsAndBytesModelLoader(BaseModelLoader):
yield
weight_name
,
processed_weight
yield
weight_name
,
processed_weight
def _get_bnb_target_modules(self, model: nn.Module) -> None:
    """Resolve ``self.target_modules`` to fully qualified linear module names.

    Walks ``model.named_modules()`` and collects every ``LinearBase``
    module. When a module's final name component is a packed (stacked)
    name listed in ``model.bitsandbytes_stacked_params_mapping``, it is
    expanded into the original per-shard names so that the collected names
    follow the checkpoint (transformers) naming instead of vLLM's.

    If ``self.target_modules`` is already non-empty (e.g. read from an
    adapter config), it is used as a substring filter over the collected
    names; otherwise every collected linear module becomes a target.

    Args:
        model: The model to inspect. It must expose a
            ``bitsandbytes_stacked_params_mapping`` attribute mapping
            ``orig_name -> (packed_name, shard_index)``.

    Raises:
        AssertionError: If no target module could be determined.
    """
    # TODO: Maybe we can replace bitsandbytes_stacked_params_mapping with
    # packed_modules_mapping.
    # Invert the mapping: packed name -> original shard names, each placed
    # at its shard index via list.insert.
    inverse_stacked_mapping: Dict[str, List[str]] = {}
    for orig, (packed, idx) in (
            model.bitsandbytes_stacked_params_mapping.items()):
        inverse_stacked_mapping.setdefault(packed, []).insert(idx, orig)

    linear_module_lst: List[str] = []
    for name, module in model.named_modules():
        if not isinstance(module, LinearBase):
            continue
        parent, _, last_name = name.rpartition(".")
        sub_modules = inverse_stacked_mapping.get(last_name)
        if not sub_modules:
            linear_module_lst.append(name)
            continue
        # Map vllm's names to transformers' names. Only the final path
        # component is rewritten; a plain str.replace on the whole path
        # would also rewrite earlier components that happen to contain the
        # packed name.
        for sub_name in sub_modules:
            linear_module_lst.append(
                f"{parent}.{sub_name}" if parent else sub_name)

    if self.target_modules:
        # Keep only the linear modules matched (by substring) by the
        # targets that were configured before this call.
        requested = self.target_modules
        self.target_modules = [
            qual_name for qual_name in linear_module_lst
            if any(t in qual_name for t in requested)
        ]
    else:
        self.target_modules = linear_module_lst

    # The whole message is parenthesised so the model name is part of it,
    # with a separating space before the class name.
    assert self.target_modules, (
        "vllm currently does not support BNB quantization for "
        f"{type(model).__name__}")
def
_load_weights
(
self
,
model_config
:
ModelConfig
,
def
_load_weights
(
self
,
model_config
:
ModelConfig
,
model
:
nn
.
Module
)
->
None
:
model
:
nn
.
Module
)
->
None
:
if
not
hasattr
(
model
,
'
load_weights
'
):
if
not
hasattr
(
model
,
"
load_weights
"
):
raise
AttributeError
(
raise
AttributeError
(
"The required method 'load_weights' is not defined in class"
"The required method 'load_weights' is not defined in class"
f
"
{
type
(
model
).
__name__
}
."
)
f
"
{
type
(
model
).
__name__
}
."
)
if
not
hasattr
(
model
,
'
bitsandbytes_stacked_params_mapping
'
):
if
not
hasattr
(
model
,
"
bitsandbytes_stacked_params_mapping
"
):
raise
AttributeError
(
raise
AttributeError
(
f
"Model
{
type
(
model
).
__name__
}
does not support BitsAndBytes "
f
"Model
{
type
(
model
).
__name__
}
does not support BitsAndBytes "
"quantization yet."
)
"quantization yet."
)
if
len
(
self
.
target_modules
)
==
0
:
if
hasattr
(
model
,
'default_bitsandbytes_target_modules'
):
self
.
target_modules
=
model
.
default_bitsandbytes_target_modules
else
:
self
.
target_modules
=
self
.
default_target_modules
# Modules whose weights might have fused on disk
# Modules whose weights might have fused on disk
# we need their output_sizes to make shard in flight correctly with TP
# we need their output_sizes to make shard in flight correctly with TP
self
.
maybe_fused_weights_modules
:
Dict
[
str
,
List
[
int
]]
=
{}
self
.
maybe_fused_weights_modules
:
Dict
[
str
,
List
[
int
]]
=
{}
self
.
_get_bnb_target_modules
(
model
)
for
name
,
module
in
model
.
named_modules
():
for
name
,
module
in
model
.
named_modules
():
# Some modules like `ReplicatedLinear` should not have their weights
# Some modules like `ReplicatedLinear` should not have their weights
# sharded. The reason for implementing it this way is to avoid new
# sharded. The reason for implementing it this way is to avoid new
...
@@ -1046,7 +1086,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -1046,7 +1086,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
pre_quant
=
False
pre_quant
=
False
if
quant_config
is
not
None
:
if
quant_config
is
not
None
:
quant_method
=
quant_config
.
get
(
'
quant_method
'
)
quant_method
=
quant_config
.
get
(
"
quant_method
"
)
if
quant_method
==
"bitsandbytes"
:
if
quant_method
==
"bitsandbytes"
:
pre_quant
=
True
pre_quant
=
True
else
:
else
:
...
@@ -1063,11 +1103,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -1063,11 +1103,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
load_8bit
=
False
load_8bit
=
False
if
pre_quant
:
if
pre_quant
:
load_8bit
=
quant_config
.
get
(
'
load_in_8bit
'
,
False
)
load_8bit
=
quant_config
.
get
(
"
load_in_8bit
"
,
False
)
qweight_iterator
,
quant_state_dict
=
\
qweight_iterator
,
quant_state_dict
=
(
self
.
_get_quantized_weights_iterator
(
self
.
_get_quantized_weights_iterator
(
model_config
.
model
,
model_config
.
model
,
model_config
.
revision
,
pre_quant
,
load_8bit
)
model_config
.
revision
,
pre_quant
,
load_8bit
))
model
.
load_weights
(
qweight_iterator
)
model
.
load_weights
(
qweight_iterator
)
...
@@ -1078,6 +1119,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -1078,6 +1119,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# TODO: Change this lazy import to normal import
# TODO: Change this lazy import to normal import
# after the checks are updated to run on a new version
# after the checks are updated to run on a new version
from
vllm.model_executor.models.utils
import
is_pp_missing_parameter
from
vllm.model_executor.models.utils
import
is_pp_missing_parameter
for
quant_param_name
in
quant_state_dict
:
for
quant_param_name
in
quant_state_dict
:
if
is_pp_missing_parameter
(
quant_param_name
,
model
):
if
is_pp_missing_parameter
(
quant_param_name
,
model
):
continue
continue
...
@@ -1086,9 +1128,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -1086,9 +1128,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
shard_index
=
0
shard_index
=
0
for
shard_name
,
(
for
shard_name
,
(
weight_name
,
index
weight_name
,
index
,
)
in
model
.
bitsandbytes_stacked_params_mapping
.
items
():
)
in
model
.
bitsandbytes_stacked_params_mapping
.
items
():
shard_pos
=
quant_param_name
.
find
(
shard_name
)
shard_pos
=
quant_param_name
.
find
(
shard_name
)
# Some models, such as MiniCPM V2.5/2.6, contain both
# Some models, such as MiniCPM V2.5/2.6, contain both
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
...
@@ -1123,8 +1165,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -1123,8 +1165,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
num_elements
=
[
0
]
*
len
(
quant_states
)
num_elements
=
[
0
]
*
len
(
quant_states
)
for
seq
,
quant_state
in
quant_states
.
items
():
for
seq
,
quant_state
in
quant_states
.
items
():
num_elements
[
seq
]
=
math
.
prod
(
num_elements
[
seq
]
=
(
math
.
prod
(
quant_state
.
shape
)
//
quant_state
.
shape
)
//
pack_ratio
pack_ratio
)
offsets
=
np
.
concatenate
(([
0
],
np
.
cumsum
(
num_elements
)))
offsets
=
np
.
concatenate
(([
0
],
np
.
cumsum
(
num_elements
)))
set_weight_attrs
(
param
,
{
"bnb_shard_offsets"
:
offsets
})
set_weight_attrs
(
param
,
{
"bnb_shard_offsets"
:
offsets
})
...
...
vllm/model_executor/models/baichuan.py
View file @
15cc2a9f
...
@@ -351,14 +351,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -351,14 +351,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_padding_modules
=
[]
embedding_padding_modules
=
[]
# BitandBytes specific attributes
# BitandBytes specific attributes
default_bitsandbytes_target_modules
=
[
".W_pack."
,
".o_proj."
,
".down_proj."
,
".up_proj."
,
".gate_proj."
,
".up_proj."
,
]
bitsandbytes_stacked_params_mapping
=
{
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
# shard_name, weight_name, index
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"gate_proj"
:
(
"gate_up_proj"
,
0
),
...
...
vllm/model_executor/models/falcon.py
View file @
15cc2a9f
...
@@ -412,12 +412,6 @@ class FalconForCausalLM(nn.Module, SupportsPP):
...
@@ -412,12 +412,6 @@ class FalconForCausalLM(nn.Module, SupportsPP):
# BitandBytes specific attributes
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping
=
{}
bitsandbytes_stacked_params_mapping
=
{}
default_bitsandbytes_target_modules
=
[
".query_key_value."
,
".dense."
,
".dense_h_to_4h."
,
".dense_4h_to_h."
,
]
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
...
...
vllm/model_executor/models/gemma.py
View file @
15cc2a9f
...
@@ -350,15 +350,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -350,15 +350,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"down_proj"
,
"down_proj"
,
]
]
# BitandBytes specific attributes
# BitandBytes specific attributes
default_bitsandbytes_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
]
bitsandbytes_stacked_params_mapping
=
{
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"q_proj"
:
(
"qkv_proj"
,
0
),
...
...
vllm/model_executor/models/gemma2.py
View file @
15cc2a9f
...
@@ -386,15 +386,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -386,15 +386,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_padding_modules
=
[]
embedding_padding_modules
=
[]
# BitandBytes specific attributes
# BitandBytes specific attributes
default_bitsandbytes_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
]
bitsandbytes_stacked_params_mapping
=
{
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"q_proj"
:
(
"qkv_proj"
,
0
),
...
...
vllm/model_executor/models/idefics3.py
View file @
15cc2a9f
...
@@ -656,21 +656,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -656,21 +656,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
]
]
# BitandBytes specific attributes
# BitandBytes specific attributes
default_bitsandbytes_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
# vision_model
".fc1."
,
".fc2."
,
".out_proj."
,
# connector
".proj."
,
]
bitsandbytes_stacked_params_mapping
=
{
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"q_proj"
:
(
"qkv_proj"
,
0
),
...
...
vllm/model_executor/models/llama.py
View file @
15cc2a9f
...
@@ -463,15 +463,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -463,15 +463,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_padding_modules
=
[
"lm_head"
]
embedding_padding_modules
=
[
"lm_head"
]
# BitandBytes specific attributes
# BitandBytes specific attributes
default_bitsandbytes_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
]
bitsandbytes_stacked_params_mapping
=
{
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"q_proj"
:
(
"qkv_proj"
,
0
),
...
...
vllm/model_executor/models/minicpmv.py
View file @
15cc2a9f
...
@@ -822,25 +822,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
...
@@ -822,25 +822,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
]
]
# BitandBytes specific attributes
# BitandBytes specific attributes
default_bitsandbytes_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
# vision encoder
".fc1."
,
".fc2."
,
# Currently, vllm does not support BNB quantization for the `out_proj`
# of the resampler, so it's necessary to distinguish between the
# vision encoder and the resampler's out_proj. The same applies to
# MiniCPMV2_6.
".self_attn.out_proj."
,
# vision encoder out_proj
# resampler
".kv_proj."
,
]
bitsandbytes_stacked_params_mapping
=
{
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"q_proj"
:
(
"qkv_proj"
,
0
),
...
@@ -964,21 +945,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
...
@@ -964,21 +945,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
]
]
# BitandBytes specific attributes
# BitandBytes specific attributes
default_bitsandbytes_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
# vision encoder
".fc1."
,
".fc2."
,
".self_attn.out_proj."
,
# resampler
".kv_proj."
,
]
bitsandbytes_stacked_params_mapping
=
{
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"q_proj"
:
(
"qkv_proj"
,
0
),
...
...
vllm/model_executor/models/mllama.py
View file @
15cc2a9f
...
@@ -1104,20 +1104,6 @@ class MllamaForCausalLM(nn.Module):
...
@@ -1104,20 +1104,6 @@ class MllamaForCausalLM(nn.Module):
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_mllama
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_mllama
)
class
MllamaForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
class
MllamaForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
# BitandBytes specific attributes
# BitandBytes specific attributes
default_bitsandbytes_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
".fc1."
,
".fc2."
,
# The `multi_modal_projector` is at the top level of the model,
# so we can't add a dot in front of it.
"multi_modal_projector."
]
bitsandbytes_stacked_params_mapping
=
{
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"q_proj"
:
(
"qkv_proj"
,
0
),
...
...
vllm/model_executor/models/opt.py
View file @
15cc2a9f
...
@@ -337,9 +337,6 @@ class OPTForCausalLM(nn.Module, SupportsPP):
...
@@ -337,9 +337,6 @@ class OPTForCausalLM(nn.Module, SupportsPP):
"k_proj"
:
(
"qkv_proj"
,
1
),
"k_proj"
:
(
"qkv_proj"
,
1
),
"v_proj"
:
(
"qkv_proj"
,
2
),
"v_proj"
:
(
"qkv_proj"
,
2
),
}
}
default_bitsandbytes_target_modules
=
[
".q_proj."
,
".k_proj."
,
".v_proj."
,
".out_proj."
,
".fc1."
,
".fc2."
]
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
...
...
vllm/model_executor/models/phi.py
View file @
15cc2a9f
...
@@ -286,9 +286,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -286,9 +286,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"k_proj"
:
(
"qkv_proj"
,
1
),
"k_proj"
:
(
"qkv_proj"
,
1
),
"v_proj"
:
(
"qkv_proj"
,
2
),
"v_proj"
:
(
"qkv_proj"
,
2
),
}
}
default_bitsandbytes_target_modules
=
[
".q_proj."
,
".k_proj."
,
".v_proj."
,
".fc1."
,
".fc2."
,
".dense."
]
embedding_modules
=
{}
embedding_modules
=
{}
embedding_padding_modules
=
[]
embedding_padding_modules
=
[]
...
...
vllm/model_executor/models/phi3.py
View file @
15cc2a9f
...
@@ -16,11 +16,5 @@ class Phi3ForCausalLM(LlamaForCausalLM):
...
@@ -16,11 +16,5 @@ class Phi3ForCausalLM(LlamaForCausalLM):
}
}
# BitandBytes specific attributes
# BitandBytes specific attributes
default_bitsandbytes_target_modules
=
[
".gate_up_proj."
,
".down_proj."
,
".qkv_proj."
,
".o_proj."
,
]
# Initialize an empty dict when there is no stacked parameter mapping.
# Initialize an empty dict when there is no stacked parameter mapping.
bitsandbytes_stacked_params_mapping
=
{}
bitsandbytes_stacked_params_mapping
=
{}
vllm/model_executor/models/qwen.py
View file @
15cc2a9f
...
@@ -1028,12 +1028,7 @@ class QWenLLM(QWenBaseModel):
...
@@ -1028,12 +1028,7 @@ class QWenLLM(QWenBaseModel):
embedding_modules
=
{}
embedding_modules
=
{}
embedding_padding_modules
=
[]
embedding_padding_modules
=
[]
default_bitsandbytes_target_modules
=
[
# BitandBytes specific attributes
".c_attn."
,
".c_proj."
,
".w1."
,
".w2."
,
]
bitsandbytes_stacked_params_mapping
=
{
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
# shard_name, weight_name, index
"w2"
:
(
"gate_up_proj"
,
0
),
"w2"
:
(
"gate_up_proj"
,
0
),
...
...
vllm/model_executor/models/qwen2.py
View file @
15cc2a9f
...
@@ -419,15 +419,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -419,15 +419,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_padding_modules
=
[]
embedding_padding_modules
=
[]
# BitandBytes specific attributes
# BitandBytes specific attributes
default_bitsandbytes_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
]
bitsandbytes_stacked_params_mapping
=
{
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"q_proj"
:
(
"qkv_proj"
,
0
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment