ModelZoo / XrayGLM_pytorch / Commits / 8306539e
Unverified commit 8306539e
Authored May 30, 2023 by MPU王荣胜; committed by GitHub on May 30, 2023

add QLoRA

Parent: 2def144b
Showing 2 changed files with 265 additions and 2 deletions

finetune_XrayGLM.py   +8    -2
lora_mixin.py         +257  -0
finetune_XrayGLM.py  (view file @ 8306539e)

@@ -6,7 +6,7 @@ from sat import mpu, get_args, get_tokenizer
 from sat.training.deepspeed_training import training_main
 from model import VisualGLMModel
 from sat.model.finetune import PTuningV2Mixin
-from sat.model.finetune.lora_mixin import LoraMixin
+from lora_mixin import LoraMixin

 class FineTuneVisualGLMModel(VisualGLMModel):
     def __init__(self, args, transformer=None, parallel_output=True, **kw_args):
@@ -17,6 +17,8 @@ class FineTuneVisualGLMModel(VisualGLMModel):
             # If you use lora on other "normal" Transformer, just use it with head_first=False (by default)
             self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, head_first=True, num_attention_heads=args.num_attention_heads, hidden_size_per_attention_head=args.hidden_size // args.num_attention_heads, layer_range=list(range(0, 28, 14))), reinit=True)
             # self.get_mixin("eva").model.glm_proj = replace_linear_with_lora(self.get_mixin("eva").model.glm_proj, LoraLinear, args.lora_rank)
+        elif args.use_qlora:
+            self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, head_first=True, num_attention_heads=args.num_attention_heads, hidden_size_per_attention_head=args.hidden_size // args.num_attention_heads, layer_range=list(range(0, 28, 14)), qlora=True), reinit=True)
         self.args = args

     @classmethod
@@ -26,13 +28,14 @@ class FineTuneVisualGLMModel(VisualGLMModel):
         group.add_argument('--lora_rank', type=int, default=10)
         group.add_argument('--use_ptuning', action="store_true")
         group.add_argument('--use_lora', action="store_true")
+        group.add_argument('--use_qlora', action="store_true")
         return super().add_model_specific_args(parser)

     def disable_untrainable_params(self):
         enable = []
         if self.args.use_ptuning:
             enable.extend(['ptuning'])
-        if self.args.use_lora:
+        if self.args.use_lora or self.args.use_qlora:
             enable.extend(['matrix_A', 'matrix_B'])
         for n, p in self.named_parameters():
             flag = False
@@ -169,9 +172,12 @@ if __name__ == '__main__':
     known, args_list = py_parser.parse_known_args()
     args = get_args(args_list)
     args = argparse.Namespace(**vars(args), **vars(known))
+    args.device = 'cpu'

     model_type = 'visualglm-6b'
     model, args = FineTuneVisualGLMModel.from_pretrained(model_type, args)
+    if torch.cuda.is_available():
+        model = model.to('cuda')
     tokenizer = get_tokenizer(args)
     label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     def data_collator(examples):
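For reference, the new --use_qlora switch follows the same store_true pattern as --use_lora. A minimal sketch of how the flag reaches the model; this is illustrative only and not part of the commit:

# Hedged sketch, not part of the commit: with --use_qlora on the command line,
# add_model_specific_args sets args.use_qlora = True, the new elif branch attaches
# LoraMixin(..., qlora=True), and disable_untrainable_params keeps only the LoRA
# matrices (matrix_A / matrix_B) trainable.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use_lora', action="store_true")
parser.add_argument('--use_qlora', action="store_true")
args = parser.parse_args(['--use_qlora'])
print(args.use_lora, args.use_qlora)  # False True -> the QLoRA branch is taken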
lora_mixin.py  0 → 100644  (new file, view @ 8306539e)
"""
In this mixin, I use a different implementation than lora.py
I just use a fake linear layer to replace any model with lora mixin.
"""
import
torch
import
torch.nn
as
nn
from
sat.model.base_model
import
BaseMixin
import
math
from
sat.helpers
import
print_all
from
sat.model.transformer
import
RowParallelLinear
,
ColumnParallelLinear
class
HackLinear
(
nn
.
Linear
):
def
_load_from_state_dict
(
self
,
state_dict
,
prefix
,
local_metadata
,
strict
,
missing_keys
,
unexpected_keys
,
error_msgs
):
if
prefix
+
'weight'
in
state_dict
:
self
.
weight
.
data
.
copy_
(
state_dict
[
prefix
+
'weight'
])
if
prefix
+
'bias'
in
state_dict
:
self
.
bias
.
data
.
copy_
(
state_dict
[
prefix
+
'bias'
])
try:
    from bitsandbytes.nn import LinearNF4
    def copy_nested_list(src, dst):
        for i in range(len(dst)):
            if type(dst[i]) is torch.Tensor:
                dst[i].copy_(src[i])
            elif type(dst[i]) is list:
                copy_nested_list(src[i], dst[i])
            else:
                dst[i] = src[i]
    class HackLinearNF4(LinearNF4):
        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if prefix + 'weight' in state_dict:
                self.weight.data.copy_(state_dict[prefix + 'weight'])
                if self.weight.data.dtype == torch.uint8:
                    copy_nested_list(state_dict[prefix + 'quant_state'], self.weight.quant_state)
            if prefix + 'bias' in state_dict:
                self.bias.data.copy_(state_dict[prefix + 'bias'])
        def _save_to_state_dict(self, destination, prefix, keep_vars):
            super()._save_to_state_dict(destination, prefix, keep_vars)
            destination[prefix + 'quant_state'] = self.weight.quant_state
except Exception as exception:
    print_all("Failed to load bitsandbytes:" + str(exception), level='WARNING')
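HackLinearNF4 has to round-trip quant_state because bitsandbytes keeps NF4 weights as packed uint8 blocks plus separate quantization metadata that a plain weight copy would lose. A hedged illustration of that behaviour, assuming bitsandbytes (as of mid-2023) and a CUDA device; not part of the commit:

# Hedged sketch: bitsandbytes quantizes parameters when they are moved to the GPU;
# after that, weight.dtype is torch.uint8 and weight.quant_state holds the scales
# and block layout that HackLinearNF4 saves into / restores from the state dict.
import torch
from bitsandbytes.nn import LinearNF4

lin = LinearNF4(16, 32)
print(lin.weight.dtype)              # float dtype while still on CPU
if torch.cuda.is_available():
    lin = lin.cuda()                 # NF4 quantization happens here
    print(lin.weight.dtype)          # torch.uint8 (packed 4-bit blocks)
    print(lin.weight.quant_state is not None)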
class HackParameterList(nn.ParameterList):
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        for i in range(len(self)):
            if prefix + str(i) in state_dict:
                self[i].data.copy_(state_dict[prefix + str(i)])
class LoraLinear(nn.Module):
    def __init__(self, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., qlora=False):
        super().__init__()
        if lora_dropout and lora_dropout > 0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        if qlora:
            self.original = HackLinearNF4(in_dim, out_dim)
        else:
            self.original = HackLinear(in_dim, out_dim)
        self.matrix_A = nn.Parameter(torch.empty((r, in_dim)))
        self.matrix_B = nn.Parameter(torch.empty((out_dim, r)))
        nn.init.kaiming_uniform_(self.matrix_A, a=math.sqrt(5))
        nn.init.zeros_(self.matrix_B)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # This is not a perfect version, because it doesn't handle errors and unexpected keys.
        if prefix + 'weight' in state_dict:
            # load from a normal Linear
            self.original._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        else:
            # load from a LoraLinear
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        return self.original(x) + (self.lora_dropout(x) @ self.matrix_A.T @ self.matrix_B.T) * self.scaling
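Because matrix_B is zero-initialized, a freshly built LoraLinear reproduces the wrapped linear exactly; the low-rank update only contributes after training. A quick shape/identity check, assuming the classes above are importable (e.g. from lora_mixin import LoraLinear); not part of the commit:

# Hedged sketch: matrix_A is (r, in_dim), matrix_B is (out_dim, r), so
# x @ A.T @ B.T has shape (batch, out_dim); with B zero-initialized the delta is 0.
import torch

lora = LoraLinear(in_dim=100, out_dim=200, r=10)
x = torch.randn(2, 100)
delta = (x @ lora.matrix_A.T @ lora.matrix_B.T) * lora.scaling   # (2, 200), all zeros at init
assert torch.allclose(lora(x), lora.original(x) + delta)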
class LoraQKV(nn.Module):
    def __init__(self, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., head_first=False, num_attention_heads=None, hidden_size_per_attention_head=None, qlora=False):
        """
        You can safely use this layer ONLY WHEN the query_key_value output is in [query, key, value] order.
        If it uses a different order (like ChatGLM, which packs q/k/v per attention head), set head_first=True.
        """
        super().__init__()
        if lora_dropout and lora_dropout > 0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        if qlora:
            self.original = HackLinearNF4(in_dim, out_dim)
        else:
            self.original = HackLinear(in_dim, out_dim)
        self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, in_dim))) for _ in range(3)])
        self.matrix_B = HackParameterList([nn.Parameter(torch.empty((out_dim // 3, r))) for _ in range(3)])
        for i in range(3):
            nn.init.kaiming_uniform_(self.matrix_A[i], a=math.sqrt(5))
            nn.init.zeros_(self.matrix_B[i])
        self.head_first = head_first
        if head_first:
            assert num_attention_heads is not None and hidden_size_per_attention_head is not None, "You should set num_attention_heads and hidden_size_per_attention_head if you use head_first=True!"
            self.num_attention_heads = num_attention_heads
            self.hidden_size_per_attention_head = hidden_size_per_attention_head

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # This is not a perfect version, because it doesn't handle errors and unexpected keys.
        if prefix + 'weight' in state_dict:
            # load from a normal Linear
            self.original._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        else:
            # load from a LoraLinear
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        mixed_raw_layer = self.original(x)
        lora_outputs = []
        for i in range(3):
            lora_outputs.append((self.lora_dropout(x) @ self.matrix_A[i].T @ self.matrix_B[i].T) * self.scaling)
        if self.head_first:
            new_tensor_shape = lora_outputs[0].size()[:-1] + (
                self.num_attention_heads,
                self.hidden_size_per_attention_head,
            )
            for i in range(3):
                lora_outputs[i] = lora_outputs[i].view(*new_tensor_shape)
            mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1).view(*mixed_raw_layer.size())
        else:
            mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1)
        return mixed_raw_layer
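The head_first branch only reshuffles how the three LoRA deltas are laid out: with head_first=False the packed projection is [all q | all k | all v], while ChatGLM-style projections keep q/k/v together per attention head. A shape-only sketch of that reshuffle (dimensions are made up; not part of the commit):

# Hedged sketch: 2 heads x 4 dims per head, batch of 5.
import torch

num_heads, head_dim = 2, 4
q, k, v = (torch.randn(5, num_heads * head_dim) for _ in range(3))    # three (5, 8) LoRA deltas

flat = torch.cat([q, k, v], -1)                                        # head_first=False: (5, 24) = [q | k | v]

per_head = torch.cat([t.view(5, num_heads, head_dim) for t in (q, k, v)], -1)  # (5, 2, 12)
interleaved = per_head.view(5, 3 * num_heads * head_dim)               # (5, 24), per head: [q | k | v]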
def replace_linear_with_lora(lin, base_cls, r, *args, **kw_args):
    # not supported for linear without bias for now
    out_dim, in_dim = lin.weight.shape
    return base_cls(in_dim, out_dim, r, *args, **kw_args)

def merge_linear_lora(lin):
    out_dim, in_dim = lin.original.weight.shape
    new_lin = nn.Linear(in_dim, out_dim)
    new_lin.bias.data = lin.original.bias.data
    new_lin.weight.data = lin.original.weight.data + (lin.matrix_A.data.T.float() @ lin.matrix_B.data.T.float() * lin.scaling).T.to(lin.original.weight.data.dtype)
    return new_lin

def merge_qkv_lora(lin):
    out_dim, in_dim = lin.original.weight.shape
    new_lin = nn.Linear(in_dim, out_dim)
    new_lin.bias.data = lin.original.bias.data
    new_qkv = []
    for i in range(3):
        new_qkv.append(lin.matrix_A[i].data.T.float() @ lin.matrix_B[i].data.T.float() * lin.scaling)
    if lin.head_first:
        ini_shape = new_qkv[0].shape
        new_qkv = [x.view(ini_shape[0], lin.num_attention_heads, -1) for x in new_qkv]
        new_qkv = torch.cat(new_qkv, -1).view(ini_shape[0], 3 * ini_shape[1])
    else:
        new_qkv = torch.cat(new_qkv, -1)
    new_lin.weight.data = lin.original.weight.data + new_qkv.T.to(lin.original.weight.data.dtype)
    return new_lin
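merge_linear_lora folds the trained low-rank update back into a single nn.Linear so inference no longer pays for the extra matmuls; this only makes sense for the non-quantized HackLinear case, since an NF4 original cannot simply be read back as float weights. A hedged equivalence check (illustrative values, not part of the commit):

# Hedged sketch: after merging, the plain Linear matches the LoRA module's outputs.
import torch

lora = LoraLinear(in_dim=8, out_dim=6, r=4)
torch.nn.init.normal_(lora.matrix_B)      # give the LoRA branch a non-zero contribution
x = torch.randn(3, 8)
merged = merge_linear_lora(lora)          # nn.Linear with W + scaling * (A.T @ B.T).T folded in
assert torch.allclose(lora(x), merged(x), atol=1e-5)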
class LoraMixin(BaseMixin):
    def __init__(self,
                 layer_num,
                 r: int = 0,
                 lora_alpha: int = 1,
                 lora_dropout: float = 0.,
                 layer_range=None,
                 head_first=False,
                 num_attention_heads=None,
                 hidden_size_per_attention_head=None,
                 qlora=False):
        super().__init__()
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout

        if layer_range is None:
            layer_range = [i for i in range(layer_num)]
        self.layer_range = layer_range

        self.scaling = self.lora_alpha / self.r
        self.head_first = head_first
        self.num_attention_heads = num_attention_heads
        self.hidden_size_per_attention_head = hidden_size_per_attention_head
        self.qlora = qlora

    def reinit(self, parent_model):
        """
        only support self-attention part
        not supported for cross-attention for now
        """
        for i in self.layer_range:
            print(f'replacing layer {i} with lora')
            parent_model.transformer.layers[i].attention.dense = replace_linear_with_lora(parent_model.transformer.layers[i].attention.dense, LoraLinear, self.r, self.lora_alpha, self.lora_dropout, self.qlora)
            parent_model.transformer.layers[i].attention.query_key_value = replace_linear_with_lora(parent_model.transformer.layers[i].attention.query_key_value, LoraQKV, self.r, self.lora_alpha, self.lora_dropout, head_first=self.head_first, num_attention_heads=self.num_attention_heads, hidden_size_per_attention_head=self.hidden_size_per_attention_head, qlora=self.qlora)
        if self.qlora:
            print('replacing chatglm linear layer with 4bit')
            def replace_linear_with_nf4(model, name=None, cache={}):
                if type(model) in (nn.Linear, RowParallelLinear, ColumnParallelLinear):
                    out_dim, in_dim = model.weight.shape
                    return HackLinearNF4(in_dim, out_dim)
                names = set()
                for name, child in model.named_children():
                    if name not in names:
                        if child in cache:
                            new_child = cache[child]
                        else:
                            new_child = replace_linear_with_nf4(child, name=name, cache=cache)
                            cache[child] = new_child
                        setattr(model, name, new_child)
                        names.add(name)
                flag = True
                while flag:
                    flag = False
                    for name, child in model.named_children():
                        if name not in names:
                            setattr(model, name, cache[child])
                            names.add(name)
                            flag = True
                return model
            replace_linear_with_nf4(parent_model.transformer, None, {})

    def merge_lora(self):
        for i in self.layer_range:
            print(f'merge layer {i} lora back to linear')
            self.transformer.layers[i].attention.dense = merge_linear_lora(self.transformer.layers[i].attention.dense)
            self.transformer.layers[i].attention.query_key_value = merge_qkv_lora(self.transformer.layers[i].attention.query_key_value)
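reinit only assumes the SAT layout parent_model.transformer.layers[i].attention.{query_key_value, dense}; anything exposing those attributes can be patched. A toy sketch of that structural contract (module sizes are arbitrary; not part of the commit):

# Hedged sketch: a minimal module tree with the attributes reinit expects.
import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self, hidden=8):
        super().__init__()
        self.query_key_value = nn.Linear(hidden, 3 * hidden)
        self.dense = nn.Linear(hidden, hidden)

class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = ToyAttention()

class ToyTransformer(nn.Module):
    def __init__(self, n_layers=2):
        super().__init__()
        self.layers = nn.ModuleList([ToyLayer() for _ in range(n_layers)])

class ToyParent(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = ToyTransformer()

parent = ToyParent()
LoraMixin(layer_num=2, r=4).reinit(parent)
print(type(parent.transformer.layers[0].attention.dense).__name__)            # LoraLinear
print(type(parent.transformer.layers[0].attention.query_key_value).__name__)  # LoraQKV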
if __name__ == '__main__':
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.child = nn.Linear(100, 200)
        def forward(self, x):
            return self.child(x)
    model = Model()
    torch.save(model.state_dict(), "linear.pt")
    x = torch.randn(2, 100)
    out1 = model(x)
    model.child = LoraLinear(100, 200, 10)
    model.load_state_dict(torch.load("linear.pt"), strict=False)
    out2 = model(x)
    torch.save(model.state_dict(), "lora.pt")
    ckpt = torch.load("lora.pt")
    breakpoint()
    model.load_state_dict(ckpt, strict=False)
    out3 = model(x)
    breakpoint()
\ No newline at end of file