Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c49f0407
Unverified
Commit
c49f0407
authored
Nov 04, 2024
by
Jee Jee Li
Committed by
GitHub
Nov 04, 2024
Browse files
[Bugfix] Fix MiniCPMV and Mllama BNB bug (#9917)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
91c9ebbb
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
145 additions
and
65 deletions
+145
-65
vllm/model_executor/layers/resampler.py
vllm/model_executor/layers/resampler.py
+28
-21
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+28
-6
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+83
-37
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+6
-1
No files found.
vllm/model_executor/layers/resampler.py
View file @
c49f0407
...
@@ -41,6 +41,7 @@ from torch import nn
...
@@ -41,6 +41,7 @@ from torch import nn
from
torch.nn.init
import
trunc_normal_
from
torch.nn.init
import
trunc_normal_
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
DEFAULT_LN
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
DEFAULT_LN
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
...
@@ -154,15 +155,15 @@ class BaseResampler(nn.Module):
...
@@ -154,15 +155,15 @@ class BaseResampler(nn.Module):
A tensor with the shape of (grid_size**2, embed_dim)
A tensor with the shape of (grid_size**2, embed_dim)
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
num_queries
:
int
,
num_queries
:
int
,
embed_dim
:
int
,
embed_dim
:
int
,
num_heads
:
int
,
num_heads
:
int
,
kv_dim
:
Optional
[
int
]
=
None
,
kv_dim
:
Optional
[
int
]
=
None
,
norm_layer
:
Callable
[[
int
],
nn
.
LayerNorm
]
=
DEFAULT_LN
,
norm_layer
:
Callable
[[
int
],
nn
.
LayerNorm
]
=
DEFAULT_LN
,
do_post_projection
:
bool
=
True
,
do_post_projection
:
bool
=
True
,
)
->
None
:
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
num_queries
=
num_queries
self
.
num_queries
=
num_queries
...
@@ -172,7 +173,11 @@ class BaseResampler(nn.Module):
...
@@ -172,7 +173,11 @@ class BaseResampler(nn.Module):
self
.
query
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_queries
,
embed_dim
))
self
.
query
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_queries
,
embed_dim
))
trunc_normal_
(
self
.
query
,
std
=
0.02
)
trunc_normal_
(
self
.
query
,
std
=
0.02
)
if
kv_dim
is
not
None
and
kv_dim
!=
embed_dim
:
if
kv_dim
is
not
None
and
kv_dim
!=
embed_dim
:
self
.
kv_proj
=
ReplicatedLinear
(
kv_dim
,
embed_dim
,
bias
=
False
)
self
.
kv_proj
=
ReplicatedLinear
(
kv_dim
,
embed_dim
,
bias
=
False
,
quant_config
=
quant_config
,
prefix
=
prefix
)
else
:
else
:
# Maintain the same return value with ReplicatedLinear.forward
# Maintain the same return value with ReplicatedLinear.forward
self
.
kv_proj
=
lambda
*
args
,
**
kwargs
:
(
# type: ignore # noqa
self
.
kv_proj
=
lambda
*
args
,
**
kwargs
:
(
# type: ignore # noqa
...
@@ -209,8 +214,7 @@ class Resampler2(BaseResampler):
...
@@ -209,8 +214,7 @@ class Resampler2(BaseResampler):
present in minicpmv2.0, but not qwen-vl.
present in minicpmv2.0, but not qwen-vl.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
grid_size
:
int
,
grid_size
:
int
,
embed_dim
:
int
,
embed_dim
:
int
,
num_heads
:
int
,
num_heads
:
int
,
...
@@ -218,13 +222,16 @@ class Resampler2(BaseResampler):
...
@@ -218,13 +222,16 @@ class Resampler2(BaseResampler):
norm_layer
:
Callable
[[
int
],
nn
.
LayerNorm
]
=
DEFAULT_LN
,
norm_layer
:
Callable
[[
int
],
nn
.
LayerNorm
]
=
DEFAULT_LN
,
adaptive
:
bool
=
False
,
adaptive
:
bool
=
False
,
do_post_projection
:
bool
=
True
,
do_post_projection
:
bool
=
True
,
)
->
None
:
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
)
->
None
:
super
().
__init__
(
grid_size
**
2
,
super
().
__init__
(
grid_size
**
2
,
embed_dim
,
embed_dim
,
num_heads
,
num_heads
,
kv_dim
,
kv_dim
,
norm_layer
,
norm_layer
,
do_post_projection
=
do_post_projection
)
do_post_projection
=
do_post_projection
,
quant_config
=
quant_config
,
prefix
=
prefix
)
self
.
adaptive
=
adaptive
self
.
adaptive
=
adaptive
pos_embed_arr
=
get_2d_sincos_pos_embed
(
embed_dim
,
pos_embed_arr
=
get_2d_sincos_pos_embed
(
embed_dim
,
...
...
vllm/model_executor/model_loader/loader.py
View file @
c49f0407
...
@@ -28,6 +28,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -28,6 +28,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.model_loader.tensorizer
import
(
from
vllm.model_executor.model_loader.tensorizer
import
(
...
@@ -771,6 +772,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -771,6 +772,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
with
open
(
config_file_path
,
"r"
)
as
f
:
with
open
(
config_file_path
,
"r"
)
as
f
:
config
=
json
.
load
(
f
)
config
=
json
.
load
(
f
)
self
.
target_modules
=
config
[
"target_modules"
]
self
.
target_modules
=
config
[
"target_modules"
]
# Save the module names without sharding.
self
.
unsharded_weights_modules
:
List
[
str
]
=
[]
def
_get_config_file
(
self
,
qlora_adapter
:
str
)
->
str
:
def
_get_config_file
(
self
,
qlora_adapter
:
str
)
->
str
:
is_local
=
os
.
path
.
isdir
(
qlora_adapter
)
is_local
=
os
.
path
.
isdir
(
qlora_adapter
)
...
@@ -990,8 +993,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -990,8 +993,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if
any
(
target_module
in
weight_name
for
target_module
in
if
any
(
target_module
in
weight_name
for
target_module
in
self
.
target_modules
)
and
weight_name
.
endswith
(
".weight"
):
self
.
target_modules
)
and
weight_name
.
endswith
(
".weight"
):
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
# Without sharding
if
any
(
module
in
weight_name
if
any
(
weight_name
.
startswith
(
module
)
for
module
in
self
.
unsharded_weights_modules
):
weight_sub_tensor
=
weight_tensor
# Shard by column
elif
any
(
module
in
weight_name
for
module
in
self
.
column_parallel_weights_modules
):
for
module
in
self
.
column_parallel_weights_modules
):
total_size
=
weight_tensor
.
size
(
-
1
)
total_size
=
weight_tensor
.
size
(
-
1
)
...
@@ -999,7 +1007,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -999,7 +1007,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
end_index
=
total_size
//
tp_size
*
(
tp_rank
+
1
)
end_index
=
total_size
//
tp_size
*
(
tp_rank
+
1
)
weight_sub_tensor
=
weight_tensor
[...,
weight_sub_tensor
=
weight_tensor
[...,
start_index
:
end_index
]
start_index
:
end_index
]
# Shard by row
else
:
else
:
total_size
=
weight_tensor
.
size
(
0
)
total_size
=
weight_tensor
.
size
(
0
)
start_index
=
total_size
//
tp_size
*
tp_rank
start_index
=
total_size
//
tp_size
*
tp_rank
...
@@ -1053,7 +1061,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -1053,7 +1061,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):
model
.
column_parallel_weights_modules
model
.
column_parallel_weights_modules
else
:
else
:
self
.
column_parallel_weights_modules
=
[]
self
.
column_parallel_weights_modules
=
[]
# Some modules like `ReplicatedLinear` should not have their weights
# sharded. The reason for implementing it this way is to avoid new
# static variable in the model implementation.
# TODO: Can we reduce the static variables needed for BNB based on
# model information?
self
.
unsharded_weights_modules
=
[
name
for
name
,
module
in
model
.
named_modules
()
if
isinstance
(
module
,
(
ReplicatedLinear
,
))
]
self
.
model_type
=
type
(
model
).
__name__
self
.
model_type
=
type
(
model
).
__name__
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
...
@@ -1100,7 +1116,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -1100,7 +1116,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for
shard_name
,
(
for
shard_name
,
(
weight_name
,
index
weight_name
,
index
)
in
model
.
bitsandbytes_stacked_params_mapping
.
items
():
)
in
model
.
bitsandbytes_stacked_params_mapping
.
items
():
if
shard_name
in
quant_param_name
:
shard_pos
=
quant_param_name
.
find
(
shard_name
)
# Some models, such as MiniCPM V2.5/2.6, contain both
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
# from being incorrectly identified as being present in
# 'vpm.encoder.layers.0.self_attn.qkv_proj.qweight
if
shard_pos
>
0
and
quant_param_name
[
shard_pos
-
1
]
==
"."
:
shard_index
=
index
shard_index
=
index
quant_param_name
=
quant_param_name
.
replace
(
quant_param_name
=
quant_param_name
.
replace
(
shard_name
,
weight_name
)
shard_name
,
weight_name
)
...
...
vllm/model_executor/models/minicpmv.py
View file @
c49f0407
...
@@ -131,16 +131,22 @@ DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
...
@@ -131,16 +131,22 @@ DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
class
Resampler2_5
(
BaseResampler
):
class
Resampler2_5
(
BaseResampler
):
def
__init__
(
def
__init__
(
self
,
self
,
num_queries
:
int
,
num_queries
:
int
,
embed_dim
:
int
,
embed_dim
:
int
,
num_heads
:
int
,
num_heads
:
int
,
kv_dim
:
Optional
[
int
]
=
None
,
kv_dim
:
Optional
[
int
]
=
None
,
norm_layer
:
Callable
[[
int
],
nn
.
LayerNorm
]
=
DEFAULT_LN
,
norm_layer
:
Callable
[[
int
],
nn
.
LayerNorm
]
=
DEFAULT_LN
,
max_size
:
Tuple
[
int
,
int
]
=
(
70
,
70
),
max_size
:
Tuple
[
int
,
int
]
=
(
70
,
70
),
)
->
None
:
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
super
().
__init__
(
num_queries
,
embed_dim
,
num_heads
,
kv_dim
,
norm_layer
)
prefix
:
str
=
""
)
->
None
:
super
().
__init__
(
num_queries
,
embed_dim
,
num_heads
,
kv_dim
,
norm_layer
,
quant_config
=
quant_config
,
prefix
=
prefix
)
self
.
max_size
=
max_size
self
.
max_size
=
max_size
self
.
_set_2d_pos_cache
(
self
.
max_size
)
self
.
_set_2d_pos_cache
(
self
.
max_size
)
...
@@ -404,7 +410,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -404,7 +410,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self
.
vision_dim
=
(
self
.
vpm
.
embed_dim
if
self
.
version
==
(
2
,
0
)
else
self
.
vision_dim
=
(
self
.
vpm
.
embed_dim
if
self
.
version
==
(
2
,
0
)
else
self
.
vpm
.
embeddings
.
embed_dim
)
self
.
vpm
.
embeddings
.
embed_dim
)
self
.
embed_dim
=
self
.
config
.
hidden_size
self
.
embed_dim
=
self
.
config
.
hidden_size
self
.
resampler
=
self
.
init_resampler
(
self
.
embed_dim
,
self
.
vision_dim
)
self
.
resampler
=
self
.
init_resampler
(
self
.
embed_dim
,
self
.
vision_dim
,
quant_config
=
quant_config
,
prefix
=
"resampler"
)
self
.
resampler
.
to
(
device
=
"cuda"
,
dtype
=
param_dtype
)
self
.
resampler
.
to
(
device
=
"cuda"
,
dtype
=
param_dtype
)
# TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
# TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
...
@@ -666,7 +675,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -666,7 +675,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
)
->
nn
.
Module
:
)
->
nn
.
Module
:
raise
NotImplementedError
raise
NotImplementedError
def
init_resampler
(
self
,
embed_dim
:
int
,
vision_dim
:
int
)
->
nn
.
Module
:
def
init_resampler
(
self
,
embed_dim
:
int
,
vision_dim
:
int
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
)
->
nn
.
Module
:
raise
NotImplementedError
raise
NotImplementedError
def
get_vision_embedding
(
def
get_vision_embedding
(
...
@@ -743,16 +756,21 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
...
@@ -743,16 +756,21 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
embed_tokens
(
input_ids
)
return
self
.
model
.
embed_tokens
(
input_ids
)
def
init_resampler
(
self
,
embed_dim
:
int
,
vision_dim
:
int
)
->
nn
.
Module
:
def
init_resampler
(
self
,
embed_dim
:
int
,
vision_dim
:
int
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
torch
.
float16
):
with
set_default_torch_dtype
(
torch
.
float16
):
resampler
=
Resampler2
(
resampler
=
Resampler2
(
embed_dim
=
embed_dim
,
embed_dim
=
embed_dim
,
num_heads
=
embed_dim
//
128
,
num_heads
=
embed_dim
//
128
,
grid_size
=
int
(
math
.
sqrt
(
self
.
config
.
query_num
)),
grid_size
=
int
(
math
.
sqrt
(
self
.
config
.
query_num
)),
kv_dim
=
vision_dim
,
kv_dim
=
vision_dim
,
adaptive
=
False
,
adaptive
=
False
,
do_post_projection
=
True
,
do_post_projection
=
True
,
)
quant_config
=
quant_config
,
prefix
=
prefix
)
return
resampler
return
resampler
...
@@ -825,9 +843,21 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
...
@@ -825,9 +843,21 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
".k_proj."
,
".k_proj."
,
".v_proj."
,
".v_proj."
,
".o_proj."
,
".o_proj."
,
# vision encoder
".fc1."
,
".fc2."
,
# Currently, vllm does not support BNB quantization for the `out_proj`
# of the resampler, so it's necessary to distinguish between the
# vision encoder and the resampler's out_proj. The same applies to
# MiniCPMV2_6.
".self_attn.out_proj."
,
# vision encoder out_proj
# resampler
".kv_proj."
,
]
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules
=
[
".down_proj."
,
".o_proj."
]
column_parallel_weights_modules
=
[
".down_proj."
,
".o_proj."
,
".self_attn.out_proj."
,
".fc2."
]
bitsandbytes_stacked_params_mapping
=
{
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"q_proj"
:
(
"qkv_proj"
,
0
),
...
@@ -877,14 +907,18 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
...
@@ -877,14 +907,18 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
model
.
encoder
.
layers
=
model
.
encoder
.
layers
[:
-
1
]
model
.
encoder
.
layers
=
model
.
encoder
.
layers
[:
-
1
]
return
model
return
model
def
init_resampler
(
self
,
embed_dim
:
int
,
vision_dim
:
int
)
->
nn
.
Module
:
def
init_resampler
(
self
,
embed_dim
:
int
,
vision_dim
:
int
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
torch
.
float16
):
with
set_default_torch_dtype
(
torch
.
float16
):
resampler
=
Resampler2_5
(
resampler
=
Resampler2_5
(
num_queries
=
self
.
config
.
query_num
,
num_queries
=
self
.
config
.
query_num
,
embed_dim
=
embed_dim
,
embed_dim
=
embed_dim
,
num_heads
=
embed_dim
//
128
,
num_heads
=
embed_dim
//
128
,
kv_dim
=
vision_dim
,
kv_dim
=
vision_dim
,
)
quant_config
=
quant_config
,
prefix
=
prefix
)
return
resampler
return
resampler
def
get_vision_embedding
(
def
get_vision_embedding
(
...
@@ -967,9 +1001,17 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
...
@@ -967,9 +1001,17 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
".k_proj."
,
".k_proj."
,
".v_proj."
,
".v_proj."
,
".o_proj."
,
".o_proj."
,
# vision encoder
".fc1."
,
".fc2."
,
".self_attn.out_proj."
,
# resampler
".kv_proj."
,
]
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules
=
[
".down_proj."
,
".o_proj."
]
column_parallel_weights_modules
=
[
".down_proj."
,
".o_proj."
,
".self_attn.out_proj."
,
".fc2."
]
bitsandbytes_stacked_params_mapping
=
{
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"q_proj"
:
(
"qkv_proj"
,
0
),
...
@@ -1019,15 +1061,19 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
...
@@ -1019,15 +1061,19 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
model
.
encoder
.
layers
=
model
.
encoder
.
layers
[:
-
1
]
model
.
encoder
.
layers
=
model
.
encoder
.
layers
[:
-
1
]
return
model
return
model
def
init_resampler
(
self
,
embed_dim
:
int
,
vision_dim
:
int
)
->
nn
.
Module
:
def
init_resampler
(
self
,
embed_dim
:
int
,
vision_dim
:
int
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
)
->
nn
.
Module
:
with
set_default_torch_dtype
(
torch
.
float16
):
with
set_default_torch_dtype
(
torch
.
float16
):
# The resampler in 2.6 remains consistent with the one in 2.5.
# The resampler in 2.6 remains consistent with the one in 2.5.
resampler
=
Resampler2_5
(
resampler
=
Resampler2_5
(
num_queries
=
self
.
config
.
query_num
,
num_queries
=
self
.
config
.
query_num
,
embed_dim
=
embed_dim
,
embed_dim
=
embed_dim
,
num_heads
=
embed_dim
//
128
,
num_heads
=
embed_dim
//
128
,
kv_dim
=
vision_dim
,
kv_dim
=
vision_dim
,
)
quant_config
=
quant_config
,
prefix
=
prefix
)
return
resampler
return
resampler
def
get_vision_embedding
(
def
get_vision_embedding
(
...
...
vllm/model_executor/models/mllama.py
View file @
c49f0407
...
@@ -1056,9 +1056,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -1056,9 +1056,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
".k_proj."
,
".k_proj."
,
".v_proj."
,
".v_proj."
,
".o_proj."
,
".o_proj."
,
".fc1."
,
".fc2."
,
# The `multi_modal_projector` is at the top level of the model,
# so we can't add a dot in front of it.
"multi_modal_projector."
]
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules
=
[
".down_proj."
,
".o_proj."
]
column_parallel_weights_modules
=
[
".down_proj."
,
".o_proj."
,
".fc2."
]
bitsandbytes_stacked_params_mapping
=
{
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"q_proj"
:
(
"qkv_proj"
,
0
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment