Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2f4117c3
Unverified
Commit
2f4117c3
authored
Oct 08, 2024
by
chenqianfzh
Committed by
GitHub
Oct 08, 2024
Browse files
support bitsandbytes quantization with more models (#9148)
parent
9ba0bd6a
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
164 additions
and
27 deletions
+164
-27
tests/quantization/test_bitsandbytes.py
tests/quantization/test_bitsandbytes.py
+7
-6
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+25
-1
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+2
-2
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+44
-18
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+11
-0
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+22
-0
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+13
-0
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+13
-0
vllm/model_executor/models/opt.py
vllm/model_executor/models/opt.py
+13
-0
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+14
-0
No files found.
tests/quantization/test_bitsandbytes.py
View file @
2f4117c3
...
@@ -9,22 +9,22 @@ import pytest
...
@@ -9,22 +9,22 @@ import pytest
import
torch
import
torch
from
tests.quantization.utils
import
is_quant_method_supported
from
tests.quantization.utils
import
is_quant_method_supported
from
tests.utils
import
fork_new_process_for_each_test
from
..utils
import
fork_new_process_for_each_test
models_4bit_to_test
=
[
models_4bit_to_test
=
[
(
'huggyllama/llama-7b'
,
'
quantize model inflight
'
),
(
"facebook/opt-125m"
,
"
quantize
opt
model inflight
"
),
]
]
models_pre_qaunt_4bit_to_test
=
[
models_pre_qaunt_4bit_to_test
=
[
(
'lllyasviel/omost-llama-3-8b-4bits'
,
'read pre-quantized 4-bit NF4 model'
),
(
'PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed'
,
(
'PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed'
,
'read pre-quantized 4-bit FP4 model'
),
'read pre-quantized 4-bit FP4 model'
),
(
'poedator/opt-125m-bnb-4bit'
,
'read pre-quantized 4-bit NF4 opt model'
),
]
]
models_pre_quant_8bit_to_test
=
[
models_pre_quant_8bit_to_test
=
[
(
'meta-llama/Llama-Guard-3-8B-INT8'
,
'read pre-quantized 8-bit model'
),
(
'meta-llama/Llama-Guard-3-8B-INT8'
,
'read pre-quantized llama 8-bit model'
),
(
"yec019/fbopt-350m-8bit"
,
"read pre-quantized 8-bit opt model"
),
]
]
...
@@ -133,6 +133,7 @@ def validate_generated_texts(hf_runner,
...
@@ -133,6 +133,7 @@ def validate_generated_texts(hf_runner,
hf_str
=
hf_log
[
"generated_text"
]
hf_str
=
hf_log
[
"generated_text"
]
vllm_str
=
vllm_log
[
"generated_text"
]
vllm_str
=
vllm_log
[
"generated_text"
]
prompt
=
hf_log
[
"prompt"
]
prompt
=
hf_log
[
"prompt"
]
assert
hf_str
==
vllm_str
,
(
f
"Model:
{
model_name
}
"
assert
hf_str
==
vllm_str
,
(
f
"Model:
{
model_name
}
"
f
"Mismatch between HF and vLLM outputs:
\n
"
f
"Mismatch between HF and vLLM outputs:
\n
"
f
"Prompt:
{
prompt
}
\n
"
f
"Prompt:
{
prompt
}
\n
"
...
...
vllm/model_executor/layers/linear.py
View file @
2f4117c3
...
@@ -336,8 +336,12 @@ class ColumnParallelLinear(LinearBase):
...
@@ -336,8 +336,12 @@ class ColumnParallelLinear(LinearBase):
if
is_gguf_weight
and
isinstance
(
param
,
UninitializedParameter
):
if
is_gguf_weight
and
isinstance
(
param
,
UninitializedParameter
):
param
.
materialize
(
loaded_weight
.
shape
,
dtype
=
loaded_weight
.
dtype
)
param
.
materialize
(
loaded_weight
.
shape
,
dtype
=
loaded_weight
.
dtype
)
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
param_data
=
param
.
data
param_data
=
param
.
data
if
output_dim
is
not
None
:
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if
output_dim
is
not
None
and
not
use_bitsandbytes_4bit
:
shard_size
=
param_data
.
shape
[
output_dim
]
shard_size
=
param_data
.
shape
[
output_dim
]
start_idx
=
tp_rank
*
shard_size
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
...
@@ -821,6 +825,9 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -821,6 +825,9 @@ class QKVParallelLinear(ColumnParallelLinear):
(
"v"
,
(
self
.
total_num_heads
+
self
.
total_num_kv_heads
)
*
(
"v"
,
(
self
.
total_num_heads
+
self
.
total_num_kv_heads
)
*
self
.
head_size
,
self
.
total_num_kv_heads
*
self
.
head_size
),
self
.
head_size
,
self
.
total_num_kv_heads
*
self
.
head_size
),
]
]
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
for
shard_id
,
shard_offset
,
shard_size
in
shard_offsets
:
for
shard_id
,
shard_offset
,
shard_size
in
shard_offsets
:
# Special case for Quantized Weights.
# Special case for Quantized Weights.
...
@@ -834,6 +841,23 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -834,6 +841,23 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size
,
shard_offset
=
adjust_marlin_shard
(
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
param
,
shard_size
,
shard_offset
)
if
use_bitsandbytes_4bit
:
orig_qkv_offsets
=
{
"q"
:
(
0
,
self
.
total_num_heads
*
self
.
head_size
),
"k"
:
(
self
.
total_num_heads
*
self
.
head_size
,
self
.
total_num_kv_heads
*
self
.
head_size
),
"v"
:
((
self
.
total_num_heads
+
self
.
total_num_kv_heads
)
*
self
.
head_size
,
self
.
total_num_kv_heads
*
self
.
head_size
),
"total"
:
((
self
.
total_num_heads
+
2
*
self
.
total_num_kv_heads
)
*
self
.
head_size
,
0
)
}
shard_size
,
shard_offset
=
adjust_bitsandbytes_4bit_shard
(
param
,
orig_qkv_offsets
,
shard_id
)
loaded_weight_shard
=
loaded_weight
.
narrow
(
loaded_weight_shard
=
loaded_weight
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
output_dim
,
shard_offset
,
shard_size
)
self
.
weight_loader
(
param
,
loaded_weight_shard
,
shard_id
)
self
.
weight_loader
(
param
,
loaded_weight_shard
,
shard_id
)
...
...
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
2f4117c3
...
@@ -108,7 +108,7 @@ class BitsAndBytesConfig(QuantizationConfig):
...
@@ -108,7 +108,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return
None
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[
"gelu"
,
"gelu_fast"
,
"gelu_new"
,
"gelu_pytorch_tanh"
]
return
[]
class
BitsAndBytesLinearMethod
(
LinearMethodBase
):
class
BitsAndBytesLinearMethod
(
LinearMethodBase
):
...
@@ -236,7 +236,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
...
@@ -236,7 +236,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
if
generation
==
0
or
generation
==
1
:
if
generation
==
0
or
generation
==
1
:
matmul_states
[
i
]
=
MatmulLtState
()
matmul_states
[
i
]
=
MatmulLtState
()
matmul_states
[
i
].
CB
=
qweight
[
offsets
[
i
]:
offsets
[
i
+
1
]]
matmul_states
[
i
].
CB
=
qweight
[
offsets
[
i
]:
offsets
[
i
+
1
]]
matmul_states
[
i
].
SCB
=
quant_states
[
i
]
matmul_states
[
i
].
SCB
=
quant_states
[
i
]
.
to
(
x
.
device
)
matmul_states
[
i
].
threshold
=
(
matmul_states
[
i
].
threshold
=
(
self
.
quant_config
.
llm_int8_threshold
)
self
.
quant_config
.
llm_int8_threshold
)
matmul_states
[
i
].
has_fp16_weights
=
(
matmul_states
[
i
].
has_fp16_weights
=
(
...
...
vllm/model_executor/model_loader/loader.py
View file @
2f4117c3
...
@@ -736,15 +736,26 @@ class ShardedStateLoader(BaseModelLoader):
...
@@ -736,15 +736,26 @@ class ShardedStateLoader(BaseModelLoader):
class
BitsAndBytesModelLoader
(
BaseModelLoader
):
class
BitsAndBytesModelLoader
(
BaseModelLoader
):
"""Model loader to load model weights with BitAndBytes quantization."""
"""Model loader to load model weights with BitAndBytes quantization."""
# TODO: these modu
le
names
are for Llama only,
possible_config_fi
le
_
names
=
[
"adapter_config.json"
]
# change so that it works with other models as well
default_target_modules
=
[
default_target_modules
=
[
"gate_proj"
,
"down_proj"
,
"up_proj"
,
"q_proj"
,
"k_proj"
,
"v_proj"
,
".gate_proj."
,
"o_proj"
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
'.fc1.'
,
'.fc2.'
,
'.dense.'
,
'.query_key_value.'
,
'.qkv_proj.'
,
'.dense_h_to_4h.'
,
'.dense_4h_to_h.'
,
'.out_proj.'
,
]
]
possible_config_file_names
=
[
"adapter_config.json"
]
def
__init__
(
self
,
load_config
:
LoadConfig
):
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
super
().
__init__
(
load_config
)
...
@@ -754,7 +765,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -754,7 +765,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if
(
not
load_config
.
model_loader_extra_config
if
(
not
load_config
.
model_loader_extra_config
or
"qlora_adapter_name_or_path"
or
"qlora_adapter_name_or_path"
not
in
load_config
.
model_loader_extra_config
):
not
in
load_config
.
model_loader_extra_config
):
self
.
target_modules
=
self
.
default_target_modules
self
.
target_modules
=
[]
return
return
qlora_adapter
=
load_config
.
model_loader_extra_config
[
qlora_adapter
=
load_config
.
model_loader_extra_config
[
...
@@ -901,10 +912,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -901,10 +912,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
hf_weights_files
,
use_safetensors
):
if
not
weight_name
.
endswith
(
".weight"
):
if
not
weight_name
.
endswith
(
(
".weight"
,
".bias"
)
):
continue
continue
qweight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
qweight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
if
qweight_name
in
quant_state_dict
:
if
qweight_name
in
quant_state_dict
:
set_weight_attrs
(
weight_tensor
,
{
"load_in_8bit"
:
True
})
set_weight_attrs
(
weight_tensor
,
{
"load_in_8bit"
:
True
})
yield
qweight_name
,
weight_tensor
yield
qweight_name
,
weight_tensor
...
@@ -920,7 +932,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -920,7 +932,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
use_safetensors
)
use_safetensors
)
temp_state_dict
=
{}
temp_state_dict
=
{}
for
weight_name
,
weight_tensor
in
weight_iterator
:
for
weight_name
,
weight_tensor
in
weight_iterator
:
if
weight_name
.
endswith
(
".weight"
):
if
weight_name
.
endswith
(
(
".weight"
,
".bias"
)
):
continue
continue
# bitsandbytes library requires
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__* in CPU
# weight.quant_state.bitsandbytes__* in CPU
...
@@ -943,9 +955,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -943,9 +955,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# pre quantized weights would have a quant_state
# pre quantized weights would have a quant_state
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
hf_weights_files
,
use_safetensors
):
# Filter out all weights whose suffix is not ".weight"
if
not
weight_name
.
endswith
(
".weight"
):
if
not
weight_name
.
endswith
(
(
".weight"
,
".bias"
)
):
continue
continue
if
(
f
"
{
weight_name
}
.quant_state.bitsandbytes__nf4"
\
if
(
f
"
{
weight_name
}
.quant_state.bitsandbytes__nf4"
\
in
temp_state_dict
)
or
\
in
temp_state_dict
)
or
\
(
f
"
{
weight_name
}
.quant_state.bitsandbytes__fp4"
\
(
f
"
{
weight_name
}
.quant_state.bitsandbytes__fp4"
\
...
@@ -965,15 +978,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -965,15 +978,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
for
weight_name
,
weight_tensor
in
self
.
_hf_weight_iter
(
hf_weights_files
,
use_safetensors
):
hf_weights_files
,
use_safetensors
):
if
any
(
target_module
in
weight_name
for
target_module
in
self
.
target_modules
):
if
any
(
target_module
in
weight_name
for
target_module
in
self
.
target_modules
)
and
weight_name
.
endswith
(
".weight"
):
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
weight_name
=
weight_name
.
replace
(
".weight"
,
".qweight"
)
# weight partitions of different modules occur at
if
any
(
module
in
weight_name
# different dimensions
for
module
in
self
.
column_parallel_weights_modules
):
# TODO: these module names are for Llama only,
# change so that it works with other models as well
if
'down_proj'
in
weight_name
or
'o_proj'
in
weight_name
:
total_size
=
weight_tensor
.
size
(
-
1
)
total_size
=
weight_tensor
.
size
(
-
1
)
start_index
=
total_size
//
tp_size
*
tp_rank
start_index
=
total_size
//
tp_size
*
tp_rank
end_index
=
total_size
//
tp_size
*
(
tp_rank
+
1
)
end_index
=
total_size
//
tp_size
*
(
tp_rank
+
1
)
...
@@ -1022,6 +1034,20 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -1022,6 +1034,20 @@ class BitsAndBytesModelLoader(BaseModelLoader):
f
"Model
{
type
(
model
).
__name__
}
does not support BitsAndBytes "
f
"Model
{
type
(
model
).
__name__
}
does not support BitsAndBytes "
"quantization yet."
)
"quantization yet."
)
if
len
(
self
.
target_modules
)
==
0
:
if
hasattr
(
model
,
'default_bitsandbytes_target_modules'
):
self
.
target_modules
=
model
.
default_bitsandbytes_target_modules
else
:
self
.
target_modules
=
self
.
default_target_modules
if
hasattr
(
model
,
'column_parallel_weights_modules'
):
self
.
column_parallel_weights_modules
=
\
model
.
column_parallel_weights_modules
else
:
self
.
column_parallel_weights_modules
=
[]
self
.
model_type
=
type
(
model
).
__name__
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
logger
.
info
(
"Loading weights with BitsAndBytes quantization. "
" May take a while ..."
)
" May take a while ..."
)
...
...
vllm/model_executor/models/falcon.py
View file @
2f4117c3
...
@@ -391,6 +391,17 @@ class FalconModel(nn.Module):
...
@@ -391,6 +391,17 @@ class FalconModel(nn.Module):
class
FalconForCausalLM
(
nn
.
Module
,
SupportsPP
):
class
FalconForCausalLM
(
nn
.
Module
,
SupportsPP
):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping
=
{}
default_bitsandbytes_target_modules
=
[
".query_key_value."
,
".dense."
,
".dense_h_to_4h."
,
".dense_4h_to_h."
,
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules
=
[
".dense_4h_to_h."
,
".dense."
]
def
__init__
(
def
__init__
(
self
,
self
,
config
:
FalconConfig
,
config
:
FalconConfig
,
...
...
vllm/model_executor/models/gemma.py
View file @
2f4117c3
...
@@ -332,6 +332,28 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -332,6 +332,28 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"gate_up_proj"
,
"gate_up_proj"
,
"down_proj"
,
"down_proj"
,
]
]
# BitandBytes specific attributes
default_bitsandbytes_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules
=
[
".down_proj."
,
".o_proj."
]
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"k_proj"
:
(
"qkv_proj"
,
1
),
"v_proj"
:
(
"qkv_proj"
,
2
),
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
# Gemma does not apply LoRA to the embedding layer.
# Gemma does not apply LoRA to the embedding layer.
embedding_modules
=
{}
embedding_modules
=
{}
embedding_padding_modules
=
[]
embedding_padding_modules
=
[]
...
...
vllm/model_executor/models/gemma2.py
View file @
2f4117c3
...
@@ -375,6 +375,19 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -375,6 +375,19 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# Gemma does not apply LoRA to the embedding layer.
# Gemma does not apply LoRA to the embedding layer.
embedding_modules
=
{}
embedding_modules
=
{}
embedding_padding_modules
=
[]
embedding_padding_modules
=
[]
# BitandBytes specific attributes
default_bitsandbytes_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules
=
[
".down_proj."
,
".o_proj."
]
bitsandbytes_stacked_params_mapping
=
{
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"q_proj"
:
(
"qkv_proj"
,
0
),
...
...
vllm/model_executor/models/llama.py
View file @
2f4117c3
...
@@ -449,6 +449,19 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -449,6 +449,19 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"lm_head"
:
"output_embeddings"
"lm_head"
:
"output_embeddings"
}
}
embedding_padding_modules
=
[
"lm_head"
]
embedding_padding_modules
=
[
"lm_head"
]
# BitandBytes specific attributes
default_bitsandbytes_target_modules
=
[
".gate_proj."
,
".down_proj."
,
".up_proj."
,
".q_proj."
,
".k_proj."
,
".v_proj."
,
".o_proj."
,
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules
=
[
".down_proj."
,
".o_proj."
]
bitsandbytes_stacked_params_mapping
=
{
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"q_proj"
:
(
"qkv_proj"
,
0
),
...
...
vllm/model_executor/models/opt.py
View file @
2f4117c3
...
@@ -315,6 +315,19 @@ class OPTModel(nn.Module):
...
@@ -315,6 +315,19 @@ class OPTModel(nn.Module):
class
OPTForCausalLM
(
nn
.
Module
,
SupportsPP
):
class
OPTForCausalLM
(
nn
.
Module
,
SupportsPP
):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"k_proj"
:
(
"qkv_proj"
,
1
),
"v_proj"
:
(
"qkv_proj"
,
2
),
}
default_bitsandbytes_target_modules
=
[
".q_proj."
,
".k_proj."
,
".v_proj."
,
".out_proj."
,
".fc1."
,
".fc2."
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules
=
[
".out_proj."
,
".fc2."
]
def
__init__
(
def
__init__
(
self
,
self
,
config
:
OPTConfig
,
config
:
OPTConfig
,
...
...
vllm/model_executor/models/phi.py
View file @
2f4117c3
...
@@ -260,6 +260,20 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -260,6 +260,20 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"fc1"
,
"fc1"
,
"fc2"
,
"fc2"
,
]
]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"k_proj"
:
(
"qkv_proj"
,
1
),
"v_proj"
:
(
"qkv_proj"
,
2
),
}
default_bitsandbytes_target_modules
=
[
".q_proj."
,
".k_proj."
,
".v_proj."
,
".fc1."
,
".fc2."
,
".dense."
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules
=
[
".fc2."
,
".dense."
]
embedding_modules
=
{}
embedding_modules
=
{}
embedding_padding_modules
=
[]
embedding_padding_modules
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment