ModelZoo / XrayGLM_pytorch · Commits

Commit 8306539e (unverified)
Authored May 30, 2023 by MPU王荣胜; committed via GitHub, May 30, 2023
Parent: 2def144b

Commit message: add QLoRA
Showing 2 changed files, with 265 additions and 2 deletions:

    finetune_XrayGLM.py   +8    -2
    lora_mixin.py         +257  -0
finetune_XrayGLM.py  (+8, -2)

@@ -6,7 +6,7 @@ from sat import mpu, get_args, get_tokenizer
 from sat.training.deepspeed_training import training_main
 from model import VisualGLMModel
 from sat.model.finetune import PTuningV2Mixin
-from sat.model.finetune.lora_mixin import LoraMixin
+from lora_mixin import LoraMixin

 class FineTuneVisualGLMModel(VisualGLMModel):
     def __init__(self, args, transformer=None, parallel_output=True, **kw_args):
@@ -17,6 +17,8 @@ class FineTuneVisualGLMModel(VisualGLMModel):
             # If you use lora on other "normal" Transformer, just use it with head_first=False (by default)
             self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, head_first=True, num_attention_heads=args.num_attention_heads, hidden_size_per_attention_head=args.hidden_size // args.num_attention_heads, layer_range=list(range(0, 28, 14))), reinit=True)
             # self.get_mixin("eva").model.glm_proj = replace_linear_with_lora(self.get_mixin("eva").model.glm_proj, LoraLinear, args.lora_rank)
+        elif args.use_qlora:
+            self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, head_first=True, num_attention_heads=args.num_attention_heads, hidden_size_per_attention_head=args.hidden_size // args.num_attention_heads, layer_range=list(range(0, 28, 14)), qlora=True), reinit=True)
         self.args = args

     @classmethod
@@ -26,13 +28,14 @@ class FineTuneVisualGLMModel(VisualGLMModel):
         group.add_argument('--lora_rank', type=int, default=10)
         group.add_argument('--use_ptuning', action="store_true")
         group.add_argument('--use_lora', action="store_true")
+        group.add_argument('--use_qlora', action="store_true")
         return super().add_model_specific_args(parser)

     def disable_untrainable_params(self):
         enable = []
         if self.args.use_ptuning:
             enable.extend(['ptuning'])
-        if self.args.use_lora:
+        if self.args.use_lora or self.args.use_qlora:
             enable.extend(['matrix_A', 'matrix_B'])
         for n, p in self.named_parameters():
             flag = False
@@ -169,9 +172,12 @@ if __name__ == '__main__':
     known, args_list = py_parser.parse_known_args()
     args = get_args(args_list)
     args = argparse.Namespace(**vars(args), **vars(known))
+    args.device = 'cpu'

     model_type = 'visualglm-6b'
     model, args = FineTuneVisualGLMModel.from_pretrained(model_type, args)
+    if torch.cuda.is_available():
+        model = model.to('cuda')
     tokenizer = get_tokenizer(args)
     label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     def data_collator(examples):
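Editor's note, not part of the commit: in both branches above, layer_range=list(range(0, 28, 14)) selects only two of the 28 ChatGLM layers in visualglm-6b, and the new --use_qlora flag follows the same path as --use_lora except that the mixin is built with qlora=True. The model is also built on the CPU first and only moved to CUDA afterwards, presumably because bitsandbytes quantizes the 4-bit layers when their weights reach the GPU. A quick check of the layer selection:

    # illustrative only
    layer_range = list(range(0, 28, 14))
    assert layer_range == [0, 14]      # adapters are injected into layers 0 and 14 only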
lora_mixin.py  (new file, mode 0 → 100644, +257)

"""
In this mixin, I use a different implementation than lora.py
I just use a fake linear layer to replace any model with lora mixin.
"""

import torch
import torch.nn as nn
from sat.model.base_model import BaseMixin
import math
from sat.helpers import print_all
from sat.model.transformer import RowParallelLinear, ColumnParallelLinear


class HackLinear(nn.Linear):
    # nn.Linear whose load hook copies weight/bias when present and silently skips missing keys,
    # so checkpoints saved from a plain Linear load cleanly into a LoRA-wrapped layer.
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        if prefix + 'weight' in state_dict:
            self.weight.data.copy_(state_dict[prefix + 'weight'])
        if prefix + 'bias' in state_dict:
            self.bias.data.copy_(state_dict[prefix + 'bias'])


try:
    from bitsandbytes.nn import LinearNF4

    def copy_nested_list(src, dst):
        for i in range(len(dst)):
            if type(dst[i]) is torch.Tensor:
                dst[i].copy_(src[i])
            elif type(dst[i]) is list:
                copy_nested_list(src[i], dst[i])
            else:
                dst[i] = src[i]

    class HackLinearNF4(LinearNF4):
        # 4-bit NF4 linear that also restores and saves the bitsandbytes quant_state
        # alongside the packed uint8 weight.
        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if prefix + 'weight' in state_dict:
                self.weight.data.copy_(state_dict[prefix + 'weight'])
                if self.weight.data.dtype == torch.uint8:
                    copy_nested_list(state_dict[prefix + 'quant_state'], self.weight.quant_state)
            if prefix + 'bias' in state_dict:
                self.bias.data.copy_(state_dict[prefix + 'bias'])
        def _save_to_state_dict(self, destination, prefix, keep_vars):
            super()._save_to_state_dict(destination, prefix, keep_vars)
            destination[prefix + 'quant_state'] = self.weight.quant_state
except Exception as exception:
    print_all("Failed to load bitsandbytes:" + str(exception), level='WARNING')
class HackParameterList(nn.ParameterList):
    # ParameterList whose load hook copies entries by index when present, skipping missing ones.
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        for i in range(len(self)):
            if prefix + str(i) in state_dict:
                self[i].data.copy_(state_dict[prefix + str(i)])


class LoraLinear(nn.Module):
    def __init__(self, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., qlora=False):
        super().__init__()
        if lora_dropout and lora_dropout > 0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        if qlora:
            self.original = HackLinearNF4(in_dim, out_dim)
        else:
            self.original = HackLinear(in_dim, out_dim)
        self.matrix_A = nn.Parameter(torch.empty((r, in_dim)))
        self.matrix_B = nn.Parameter(torch.empty((out_dim, r)))
        nn.init.kaiming_uniform_(self.matrix_A, a=math.sqrt(5))
        nn.init.zeros_(self.matrix_B)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # This is not a perfect version, because it doesn't handle errors and unexpected keys.
        if prefix + 'weight' in state_dict:
            # load from normal Linear
            self.original._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        else:
            # load from LoraLinear
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        return self.original(x) + (self.lora_dropout(x) @ self.matrix_A.T @ self.matrix_B.T) * self.scaling
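A quick sanity check on LoraLinear, as an editor's sketch rather than part of the committed file (it assumes lora_mixin.py is importable and SwissArmyTransformer is installed): since matrix_B is zero-initialized, a freshly wrapped layer reproduces the original linear exactly, and the forward pass is the frozen projection plus the scaled low-rank update.

    import torch
    from lora_mixin import LoraLinear

    layer = LoraLinear(16, 8, r=4)       # lora_alpha defaults to 1., so scaling = 0.25
    x = torch.randn(2, 16)
    assert torch.allclose(layer(x), layer.original(x))   # zero-init matrix_B: no change at start
    delta = (layer.lora_dropout(x) @ layer.matrix_A.T @ layer.matrix_B.T) * layer.scaling
    assert torch.allclose(layer(x), layer.original(x) + delta)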
class LoraQKV(nn.Module):
    def __init__(self, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., head_first=False, num_attention_heads=None, hidden_size_per_attention_head=None, qlora=False):
        """
        This layer is only safe to use as-is when the query_key_value projection outputs q, k, v in blocked order.
        If the projection uses a different, per-head interleaved order (as ChatGLM does), set head_first=True and
        pass num_attention_heads and hidden_size_per_attention_head.
        """
        super().__init__()
        if lora_dropout and lora_dropout > 0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        if qlora:
            self.original = HackLinearNF4(in_dim, out_dim)
        else:
            self.original = HackLinear(in_dim, out_dim)
        # one (A, B) pair per projection: query, key, value
        self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, in_dim))) for _ in range(3)])
        self.matrix_B = HackParameterList([nn.Parameter(torch.empty((out_dim // 3, r))) for _ in range(3)])
        for i in range(3):
            nn.init.kaiming_uniform_(self.matrix_A[i], a=math.sqrt(5))
            nn.init.zeros_(self.matrix_B[i])
        self.head_first = head_first
        if head_first:
            assert num_attention_heads is not None and hidden_size_per_attention_head is not None, "You should set num_attention_heads and hidden_size_per_attention_head if you use head_first=True!"
            self.num_attention_heads = num_attention_heads
            self.hidden_size_per_attention_head = hidden_size_per_attention_head

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # This is not a perfect version, because it doesn't handle errors and unexpected keys.
        if prefix + 'weight' in state_dict:
            # load from normal Linear
            self.original._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        else:
            # load from LoraLinear
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        mixed_raw_layer = self.original(x)
        lora_outputs = []
        for i in range(3):
            lora_outputs.append((self.lora_dropout(x) @ self.matrix_A[i].T @ self.matrix_B[i].T) * self.scaling)
        if self.head_first:
            new_tensor_shape = lora_outputs[0].size()[:-1] + (
                self.num_attention_heads,
                self.hidden_size_per_attention_head,
            )
            for i in range(3):
                lora_outputs[i] = lora_outputs[i].view(*new_tensor_shape)
            mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1).view(*mixed_raw_layer.size())
        else:
            mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1)
        return mixed_raw_layer
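With head_first=True, each of the three per-projection LoRA updates is reshaped to (..., num_heads, head_dim) before concatenation, which packs the result as [q_head0, k_head0, v_head0, q_head1, ...], the per-head interleaving a ChatGLM-style query_key_value output uses; with head_first=False the updates are simply concatenated block-wise. A stand-alone illustration of the packing (editor's sketch, not part of the committed file):

    import torch

    num_heads, head_dim = 2, 3
    hidden = num_heads * head_dim
    q, k, v = torch.zeros(1, hidden), torch.ones(1, hidden), torch.full((1, hidden), 2.)
    parts = [t.view(1, num_heads, head_dim) for t in (q, k, v)]
    packed = torch.cat(parts, -1).view(1, 3 * hidden)
    # per-head interleaving: [q_h0, k_h0, v_h0, q_h1, k_h1, v_h1]
    assert packed.tolist()[0][:9] == [0., 0., 0., 1., 1., 1., 2., 2., 2.]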
def replace_linear_with_lora(lin, base_cls, r, *args, **kw_args):
    # not supported for linear without bias for now
    out_dim, in_dim = lin.weight.shape
    return base_cls(in_dim, out_dim, r, *args, **kw_args)


def merge_linear_lora(lin):
    out_dim, in_dim = lin.original.weight.shape
    new_lin = nn.Linear(in_dim, out_dim)
    new_lin.bias.data = lin.original.bias.data
    new_lin.weight.data = lin.original.weight.data + (lin.matrix_A.data.T.float() @ lin.matrix_B.data.T.float() * lin.scaling).T.to(lin.original.weight.data.dtype)
    return new_lin


def merge_qkv_lora(lin):
    out_dim, in_dim = lin.original.weight.shape
    new_lin = nn.Linear(in_dim, out_dim)
    new_lin.bias.data = lin.original.bias.data
    new_qkv = []
    for i in range(3):
        new_qkv.append(lin.matrix_A[i].data.T.float() @ lin.matrix_B[i].data.T.float() * lin.scaling)
    if lin.head_first:
        ini_shape = new_qkv[0].shape
        new_qkv = [x.view(ini_shape[0], lin.num_attention_heads, -1) for x in new_qkv]
        new_qkv = torch.cat(new_qkv, -1).view(ini_shape[0], 3 * ini_shape[1])
    else:
        new_qkv = torch.cat(new_qkv, -1)
    new_lin.weight.data = lin.original.weight.data + new_qkv.T.to(lin.original.weight.data.dtype)
    return new_lin
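merge_linear_lora folds scaling * matrix_B @ matrix_A into the frozen weight, so the returned plain nn.Linear should match the adapter layer exactly (dropout aside). A small self-check, as an editor's sketch rather than part of the committed file:

    import torch
    from lora_mixin import LoraLinear, merge_linear_lora

    lora = LoraLinear(16, 8, r=4)
    torch.nn.init.normal_(lora.matrix_B)     # give the adapter a non-zero contribution
    merged = merge_linear_lora(lora)         # plain nn.Linear with the LoRA update baked in
    x = torch.randn(2, 16)
    assert torch.allclose(merged(x), lora(x), atol=1e-5)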
class LoraMixin(BaseMixin):
    def __init__(self, layer_num,
                 r: int = 0,
                 lora_alpha: int = 1,
                 lora_dropout: float = 0.,
                 layer_range=None,
                 head_first=False,
                 num_attention_heads=None,
                 hidden_size_per_attention_head=None,
                 qlora=False):
        super().__init__()
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout

        if layer_range is None:
            layer_range = [i for i in range(layer_num)]
        self.layer_range = layer_range

        self.scaling = self.lora_alpha / self.r
        self.head_first = head_first
        self.num_attention_heads = num_attention_heads
        self.hidden_size_per_attention_head = hidden_size_per_attention_head
        self.qlora = qlora

    def reinit(self, parent_model):
        """
        Only the self-attention projections (query_key_value and dense) are replaced;
        cross-attention is not supported for now.
        """
        for i in self.layer_range:
            print(f'replacing layer {i} with lora')
            parent_model.transformer.layers[i].attention.dense = replace_linear_with_lora(parent_model.transformer.layers[i].attention.dense, LoraLinear, self.r, self.lora_alpha, self.lora_dropout, self.qlora)
            parent_model.transformer.layers[i].attention.query_key_value = replace_linear_with_lora(parent_model.transformer.layers[i].attention.query_key_value, LoraQKV, self.r, self.lora_alpha, self.lora_dropout, head_first=self.head_first, num_attention_heads=self.num_attention_heads, hidden_size_per_attention_head=self.hidden_size_per_attention_head, qlora=self.qlora)
        if self.qlora:
            print('replacing chatglm linear layer with 4bit')
            def replace_linear_with_nf4(model, name=None, cache={}):
                if type(model) in (nn.Linear, RowParallelLinear, ColumnParallelLinear):
                    out_dim, in_dim = model.weight.shape
                    return HackLinearNF4(in_dim, out_dim)
                names = set()
                for name, child in model.named_children():
                    if name not in names:
                        if child in cache:
                            new_child = cache[child]
                        else:
                            new_child = replace_linear_with_nf4(child, name=name, cache=cache)
                            cache[child] = new_child
                        setattr(model, name, new_child)
                        names.add(name)
                flag = True
                while flag:
                    flag = False
                    for name, child in model.named_children():
                        if name not in names:
                            setattr(model, name, cache[child])
                            names.add(name)
                            flag = True
                return model
            replace_linear_with_nf4(parent_model.transformer, None, {})

    def merge_lora(self):
        for i in self.layer_range:
            print(f'merge layer {i} lora back to linear')
            self.transformer.layers[i].attention.dense = merge_linear_lora(self.transformer.layers[i].attention.dense)
            self.transformer.layers[i].attention.query_key_value = merge_qkv_lora(self.transformer.layers[i].attention.query_key_value)
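How the mixin is wired into a sat-based model, mirroring the use_lora/use_qlora branches in finetune_XrayGLM.py; an editor's sketch, not part of the committed file, with illustrative ChatGLM-6B sizes and an assumed existing model object:

    from lora_mixin import LoraMixin

    # `model` is assumed to be a sat-based VisualGLMModel instance
    model.add_mixin(
        "lora",
        LoraMixin(28, r=10, head_first=True,
                  num_attention_heads=32,
                  hidden_size_per_attention_head=4096 // 32,
                  layer_range=[0, 14], qlora=False),
        reinit=True,                       # reinit=True triggers LoraMixin.reinit, swapping in the Lora layers
    )
    # ... finetune; only matrix_A / matrix_B stay trainable via disable_untrainable_params ...
    model.get_mixin("lora").merge_lora()   # fold the adapters back into plain nn.Linear for inference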
if __name__ == '__main__':
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.child = nn.Linear(100, 200)
        def forward(self, x):
            return self.child(x)
    model = Model()
    torch.save(model.state_dict(), "linear.pt")
    x = torch.randn(2, 100)
    out1 = model(x)
    model.child = LoraLinear(100, 200, 10)
    model.load_state_dict(torch.load("linear.pt"), strict=False)
    out2 = model(x)
    torch.save(model.state_dict(), "lora.pt")
    ckpt = torch.load("lora.pt")
    breakpoint()
    model.load_state_dict(ckpt, strict=False)
    out3 = model(x)
    breakpoint()
\ No newline at end of file
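Because only the adapter matrices are trained (finetune_XrayGLM.py unfreezes parameters named matrix_A and matrix_B), a LoRA-only checkpoint stays small. One possible way to extract and reload such a checkpoint, as an editor's sketch rather than anything shipped in this commit:

    import torch

    lora_state = {k: v for k, v in model.state_dict().items()
                  if 'matrix_A' in k or 'matrix_B' in k}
    torch.save(lora_state, 'xrayglm_lora_only.pt')                            # hypothetical filename
    model.load_state_dict(torch.load('xrayglm_lora_only.pt'), strict=False)   # reload adapters only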