ModelZoo / CIRI-deep_tensorflow · Commits · 4c497230

Unverified commit 4c497230, authored May 25, 2023 by MPU王荣胜, committed by GitHub on May 25, 2023.

    add finetune

Parent: 62f5e989

Showing 1 changed file with 188 additions and 0 deletions.

finetune_XrayGLM.py (+188, -0), new file (mode 100644)
import os
import torch
import argparse
from sat import mpu, get_args, get_tokenizer
from sat.training.deepspeed_training import training_main
from model import VisualGLMModel
from sat.model.finetune import PTuningV2Mixin
from sat.model.finetune.lora_mixin import LoraMixin


class FineTuneVisualGLMModel(VisualGLMModel):
    """VisualGLM wrapper that attaches P-Tuning v2 and/or LoRA mixins for finetuning."""

    def __init__(self, args, transformer=None, parallel_output=True, **kw_args):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kw_args)
        if args.use_ptuning:
            self.add_mixin("ptuning", PTuningV2Mixin(
                args.num_layers,
                args.hidden_size // args.num_attention_heads,
                args.num_attention_heads,
                args.pre_seq_len))
        if args.use_lora:
            # If you use LoRA on another "normal" Transformer, just use it with head_first=False (the default).
            self.add_mixin("lora", LoraMixin(
                args.num_layers,
                args.lora_rank,
                head_first=True,
                num_attention_heads=args.num_attention_heads,
                hidden_size_per_attention_head=args.hidden_size // args.num_attention_heads,
                layer_range=list(range(0, 28, 14))), reinit=True)
            # self.get_mixin("eva").model.glm_proj = replace_linear_with_lora(self.get_mixin("eva").model.glm_proj, LoraLinear, args.lora_rank)
        self.args = args

    @classmethod
    def add_model_specific_args(cls, parser):
        group = parser.add_argument_group('VisualGLM-finetune', 'VisualGLM finetune Configurations')
        group.add_argument('--pre_seq_len', type=int, default=8)
        group.add_argument('--lora_rank', type=int, default=10)
        group.add_argument('--use_ptuning', action="store_true")
        group.add_argument('--use_lora', action="store_true")
        return super().add_model_specific_args(parser)

    def disable_untrainable_params(self):
        # Freeze everything except parameters belonging to the enabled mixins.
        enable = []
        if self.args.use_ptuning:
            enable.extend(['ptuning'])
        if self.args.use_lora:
            enable.extend(['matrix_A', 'matrix_B'])
        for n, p in self.named_parameters():
            flag = False
            for e in enable:
                if e.lower() in n.lower():
                    flag = True
                    break
            if not flag:
                p.requires_grad_(False)
            else:
                print(n)
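# Note on the LoRA layer selection above: list(range(0, 28, 14)) == [0, 14],
# so LoRA adapters are injected only into transformer layers 0 and 14 of the
# 28-layer model, not all layers (which would be range(28)). Presumably this
# is a deliberate trade-off to cut trainable parameters; widen the range to
# adapt every layer.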
def get_batch(data_iterator, args, timers):
    # Items and their type.
    keys = ['input_ids', 'labels']
    datatype = torch.int64

    # Broadcast data.
    timers('data loader').start()
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    timers('data loader').stop()
    data_b = mpu.broadcast_data(keys, data, datatype)
    data_i = mpu.broadcast_data(['image'], data, torch.float32)

    # Unpack.
    tokens = data_b['input_ids'].long()
    labels = data_b['labels'].long()
    img = data_i['image']
    if args.fp16:
        img = img.half()
    return tokens, labels, img, data['pre_image']
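# Note: unlike 'input_ids'/'labels'/'image', 'pre_image' is read directly from
# `data` rather than broadcast via mpu, so this function assumes every rank
# receives a non-None data_iterator (otherwise `data['pre_image']` would fail
# on ranks where `data` is None).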
from torch.nn import CrossEntropyLoss


def forward_step(data_iterator, model, args, timers):
    """Forward step."""
    # Get the batch.
    timers('batch generator').start()
    tokens, labels, image, pre_image = get_batch(data_iterator, args, timers)
    timers('batch generator').stop()

    logits = model(input_ids=tokens, image=image, pre_image=pre_image)[0]
    dtype = logits.dtype
    lm_logits = logits.to(torch.float32)

    # Shift so that tokens < n predict n.
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens.
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    lm_logits = lm_logits.to(dtype)
    loss = loss.to(dtype)
    return loss, {'loss': loss}
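# Worked example of the shift above: for a sequence [t0, t1, t2, t3], the
# logits at position i are trained to predict token i+1, so
#   shift_logits keeps positions [0, 1, 2]  (predictions for t1, t2, t3)
#   shift_labels keeps tokens    [t1, t2, t3]
# Positions whose label is -100 (the prompt/context and padding) are skipped
# by CrossEntropyLoss via ignore_index=-100.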
from model.blip2 import BlipImageEvalProcessor
from torch.utils.data import Dataset
import json
from PIL import Image


class FewShotDataset(Dataset):
    def __init__(self, path, processor, tokenizer, args):
        max_seq_length = args.max_source_length + args.max_target_length
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.images = []
        self.input_ids = []
        self.labels = []
        for item in data:
            image = processor(Image.open(item['img']).convert('RGB'))
            # Build the source sequence: the <img> tag, image-token placeholders,
            # then the Chinese Q/A template ("问" = question, "答" = answer).
            input0 = tokenizer.encode("<img>", add_special_tokens=False)
            input1 = [tokenizer.pad_token_id] * args.image_length
            input2 = tokenizer.encode("</img>问:" + item['prompt'] + "\n答:", add_special_tokens=False)
            a_ids = sum([input0, input1, input2], [])
            b_ids = tokenizer.encode(text=item['label'], add_special_tokens=False)
            # Truncate source and target to their length budgets.
            if len(a_ids) > args.max_source_length - 1:
                a_ids = a_ids[: args.max_source_length - 1]
            if len(b_ids) > args.max_target_length - 2:
                b_ids = b_ids[: args.max_target_length - 2]
            pre_image = len(input0)
            input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)

            # Mask out the context: only tokens after the BOS position contribute to the loss.
            context_length = input_ids.index(tokenizer.bos_token_id)
            mask_position = context_length - 1
            labels = [-100] * context_length + input_ids[mask_position + 1:]

            # Pad both sequences to the fixed length.
            pad_len = max_seq_length - len(input_ids)
            input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
            labels = labels + [tokenizer.pad_token_id] * pad_len
            if args.ignore_pad_token_for_loss:
                labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
            self.images.append(image)
            self.input_ids.append(input_ids)
            self.labels.append(labels)
        self.pre_image = pre_image

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return {
            "image": self.images[idx],
            "input_ids": self.input_ids[idx],
            "labels": self.labels[idx],
            "pre_image": self.pre_image,
        }
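# The dataset above expects a JSON file shaped like the following (a minimal
# sketch inferred from the keys read above; the path and texts are hypothetical):
# [
#   {"img": "data/images/0001.png", "prompt": "...", "label": "..."},
#   ...
# ]
# where 'img' is an image path, 'prompt' is the question inserted after the
# image tokens, and 'label' is the target answer.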
def create_dataset_function(path, args):
    tokenizer = get_tokenizer(args)
    image_processor = BlipImageEvalProcessor(224)
    dataset = FewShotDataset(path, image_processor, tokenizer, args)
    return dataset
if __name__ == '__main__':
    py_parser = argparse.ArgumentParser(add_help=False)
    py_parser.add_argument('--max_source_length', type=int)
    py_parser.add_argument('--max_target_length', type=int)
    py_parser.add_argument('--ignore_pad_token_for_loss', type=bool, default=True)
    # py_parser.add_argument('--old_checkpoint', action="store_true")
    py_parser.add_argument('--source_prefix', type=str, default="")
    py_parser = FineTuneVisualGLMModel.add_model_specific_args(py_parser)
    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))

    model_type = 'visualglm-6b'
    model, args = FineTuneVisualGLMModel.from_pretrained(model_type, args)
    tokenizer = get_tokenizer(args)
    label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id

    def data_collator(examples):
        for example in examples:
            example['input_ids'] = torch.tensor(example['input_ids'], dtype=torch.long)
            example['labels'] = torch.tensor(example['labels'], dtype=torch.long)
        ret = {
            'input_ids': torch.stack([example['input_ids'] for example in examples]),
            'labels': torch.stack([example['labels'] for example in examples]),
            'image': torch.stack([example['image'] for example in examples]),
            'pre_image': example['pre_image'],
        }
        return ret

    training_main(args, model_cls=model, forward_step_function=forward_step,
                  create_dataset_function=create_dataset_function, collate_fn=data_collator)
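# A hypothetical launch command (a minimal sketch; flags other than the ones
# defined in this file come from sat's get_args and the DeepSpeed launcher,
# so verify the names against your sat version before use):
#   deepspeed --include localhost:0 finetune_XrayGLM.py \
#       --max_source_length 64 --max_target_length 256 \
#       --use_lora --lora_rank 10 \
#       --train-data /path/to/train.json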