ModelZoo / XrayGLM_pytorch · Commits

Unverified commit 4c497230, authored May 25, 2023 by MPU王荣胜, committed by GitHub on May 25, 2023

add finetune

Parent: 62f5e989
Showing 1 changed file with 188 additions and 0 deletions (+188 −0).
finetune_XrayGLM.py (new file, 0 → 100644):
import os
import torch
import argparse

from sat import mpu, get_args, get_tokenizer
from sat.training.deepspeed_training import training_main
from model import VisualGLMModel
from sat.model.finetune import PTuningV2Mixin
from sat.model.finetune.lora_mixin import LoraMixin
class FineTuneVisualGLMModel(VisualGLMModel):
    def __init__(self, args, transformer=None, parallel_output=True, **kw_args):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kw_args)
        if args.use_ptuning:
            self.add_mixin("ptuning", PTuningV2Mixin(args.num_layers,
                                                     args.hidden_size // args.num_attention_heads,
                                                     args.num_attention_heads,
                                                     args.pre_seq_len))
        if args.use_lora:
            # If you use lora on other "normal" Transformer, just use it with head_first=False (by default)
            # layer_range=list(range(0, 28, 14)) == [0, 14]: LoRA is only injected into layers 0 and 14.
            self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank,
                                             head_first=True,
                                             num_attention_heads=args.num_attention_heads,
                                             hidden_size_per_attention_head=args.hidden_size // args.num_attention_heads,
                                             layer_range=list(range(0, 28, 14))), reinit=True)
            # self.get_mixin("eva").model.glm_proj = replace_linear_with_lora(self.get_mixin("eva").model.glm_proj, LoraLinear, args.lora_rank)
        self.args = args
    @classmethod
    def add_model_specific_args(cls, parser):
        group = parser.add_argument_group('VisualGLM-finetune', 'VisualGLM finetune Configurations')
        group.add_argument('--pre_seq_len', type=int, default=8)
        group.add_argument('--lora_rank', type=int, default=10)
        group.add_argument('--use_ptuning', action="store_true")
        group.add_argument('--use_lora', action="store_true")
        return super().add_model_specific_args(parser)
    def disable_untrainable_params(self):
        # Freeze everything except the p-tuning prompts and the LoRA factors (matrix_A / matrix_B).
        enable = []
        if self.args.use_ptuning:
            enable.extend(['ptuning'])
        if self.args.use_lora:
            enable.extend(['matrix_A', 'matrix_B'])
        for n, p in self.named_parameters():
            flag = False
            for e in enable:
                if e.lower() in n.lower():
                    flag = True
                    break
            if not flag:
                p.requires_grad_(False)
            else:
                print(n)
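For readers unfamiliar with the two mixins wired in above: PTuningV2Mixin prepends pre_seq_len trainable prompt vectors at every layer, and LoraMixin adds trainable low-rank updates to the frozen transformer weights. The snippet below is a minimal conceptual sketch of the LoRA idea only, not the sat implementation; the sizes mirror this script's defaults but are otherwise illustrative, and count_trainable is a hypothetical helper, not part of this commit.

import torch

# Conceptual LoRA sketch (not sat's LoraMixin): a frozen weight W is augmented
# with a trainable low-rank product B @ A. These factors correspond to the
# 'matrix_A' / 'matrix_B' names whitelisted in disable_untrainable_params().
d, r = 4096, 10                              # illustrative hidden size and lora_rank
W = torch.randn(d, d)                        # pretrained weight, stays frozen
A = torch.randn(r, d, requires_grad=True)    # low-rank factor A
B = torch.zeros(d, r, requires_grad=True)    # zero init, so the update starts at zero
W_eff = W + B @ A                            # effective weight used in the forward pass

# Hypothetical helper to check what disable_untrainable_params() left trainable.
def count_trainable(model):
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total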
def get_batch(data_iterator, args, timers):
    # Items and their type.
    keys = ['input_ids', 'labels']
    datatype = torch.int64

    # Broadcast data.
    timers('data loader').start()
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    timers('data loader').stop()

    data_b = mpu.broadcast_data(keys, data, datatype)
    data_i = mpu.broadcast_data(['image'], data, torch.float32)
    # Unpack.
    tokens = data_b['input_ids'].long()
    labels = data_b['labels'].long()
    img = data_i['image']
    if args.fp16:
        img = img.half()
    return tokens, labels, img, data['pre_image']
from torch.nn import CrossEntropyLoss


def forward_step(data_iterator, model, args, timers):
    """Forward step."""

    # Get the batch.
    timers('batch generator').start()
    tokens, labels, image, pre_image = get_batch(data_iterator, args, timers)
    timers('batch generator').stop()

    logits = model(input_ids=tokens, image=image, pre_image=pre_image)[0]
    dtype = logits.dtype
    lm_logits = logits.to(torch.float32)

    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    lm_logits = lm_logits.to(dtype)
    loss = loss.to(dtype)
    return loss, {'loss': loss}
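To make the shift-and-mask arithmetic in forward_step concrete, here is a tiny self-contained example with made-up shapes (not part of this commit): the logits at position i are scored against the token at position i+1, and any target set to -100 (the prompt and padding produced by the dataset below) is excluded from the loss.

import torch
from torch.nn import CrossEntropyLoss

# Toy check of the shifting in forward_step: batch of 1, 4 positions, vocab of 5.
lm_logits = torch.randn(1, 4, 5)
labels = torch.tensor([[-100, -100, 3, 1]])           # first two positions are prompt, masked

shift_logits = lm_logits[..., :-1, :].contiguous()    # predictions at positions 0..2
shift_labels = labels[..., 1:].contiguous()           # targets at positions 1..3
loss = CrossEntropyLoss(ignore_index=-100)(
    shift_logits.view(-1, shift_logits.size(-1)),
    shift_labels.view(-1),
)
# Only the pairs with targets 3 and 1 contribute; the remaining -100 target is ignored.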
from model.blip2 import BlipImageEvalProcessor
from torch.utils.data import Dataset
import json
from PIL import Image


class FewShotDataset(Dataset):
    def __init__(self, path, processor, tokenizer, args):
        max_seq_length = args.max_source_length + args.max_target_length
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.images = []
        self.input_ids = []
        self.labels = []
        for item in data:
            image = processor(Image.open(item['img']).convert('RGB'))
            input0 = tokenizer.encode("<img>", add_special_tokens=False)
            input1 = [tokenizer.pad_token_id] * args.image_length
            # "问:" / "答:" are the Chinese "Question:" / "Answer:" markers of the ChatGLM prompt format.
            input2 = tokenizer.encode("</img>问:" + item['prompt'] + "\n答:", add_special_tokens=False)
            a_ids = sum([input0, input1, input2], [])
            b_ids = tokenizer.encode(text=item['label'], add_special_tokens=False)
            if len(a_ids) > args.max_source_length - 1:
                a_ids = a_ids[: args.max_source_length - 1]
            if len(b_ids) > args.max_target_length - 2:
                b_ids = b_ids[: args.max_target_length - 2]
            pre_image = len(input0)
            input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)

            context_length = input_ids.index(tokenizer.bos_token_id)
            mask_position = context_length - 1
            # Positions before <bos> (the prompt) get label -100 so they are ignored by the loss.
            labels = [-100] * context_length + input_ids[mask_position + 1:]

            pad_len = max_seq_length - len(input_ids)
            input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
            labels = labels + [tokenizer.pad_token_id] * pad_len
            if args.ignore_pad_token_for_loss:
                labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
            self.images.append(image)
            self.input_ids.append(input_ids)
            self.labels.append(labels)
        self.pre_image = pre_image
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return {
            "image": self.images[idx],
            "input_ids": self.input_ids[idx],
            "labels": self.labels[idx],
            "pre_image": self.pre_image
        }
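FewShotDataset expects a JSON file containing a list of records with 'img', 'prompt' and 'label' keys (the names come straight from the indexing in __init__). A minimal, hypothetical annotation file could be written like this; the file name, image path and text are placeholders, not data shipped with this commit.

import json

# Hypothetical annotation file matching the keys FewShotDataset reads.
examples = [
    {
        "img": "data/Xray/0001.png",         # path handed to PIL.Image.open
        "prompt": "这张胸部X光片有什么异常?",   # question text ("what abnormality does this chest X-ray show?")
        "label": "未见明显异常。",              # target answer ("no obvious abnormality"), tokenized as b_ids
    },
]
with open("fewshot-data.json", "w", encoding="utf-8") as f:
    json.dump(examples, f, ensure_ascii=False, indent=2)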
def create_dataset_function(path, args):
    tokenizer = get_tokenizer(args)
    image_processor = BlipImageEvalProcessor(224)
    dataset = FewShotDataset(path, image_processor, tokenizer, args)
    return dataset
if __name__ == '__main__':
    py_parser = argparse.ArgumentParser(add_help=False)
    py_parser.add_argument('--max_source_length', type=int)
    py_parser.add_argument('--max_target_length', type=int)
    py_parser.add_argument('--ignore_pad_token_for_loss', type=bool, default=True)
    # py_parser.add_argument('--old_checkpoint', action="store_true")
    py_parser.add_argument('--source_prefix', type=str, default="")
    py_parser = FineTuneVisualGLMModel.add_model_specific_args(py_parser)
    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))

    model_type = 'visualglm-6b'
    model, args = FineTuneVisualGLMModel.from_pretrained(model_type, args)
    tokenizer = get_tokenizer(args)
    label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id

    def data_collator(examples):
        for example in examples:
            example['input_ids'] = torch.tensor(example['input_ids'], dtype=torch.long)
            example['labels'] = torch.tensor(example['labels'], dtype=torch.long)
        ret = {
            'input_ids': torch.stack([example['input_ids'] for example in examples]),
            'labels': torch.stack([example['labels'] for example in examples]),
            'image': torch.stack([example['image'] for example in examples]),
            'pre_image': example['pre_image']
        }
        return ret

    training_main(args, model_cls=model, forward_step_function=forward_step,
                  create_dataset_function=create_dataset_function, collate_fn=data_collator)
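A note on the argument plumbing in __main__: the flags defined here (plus the model-specific ones added by add_model_specific_args) are consumed by py_parser, everything else stays in args_list and is handed to sat's get_args, which builds the training and DeepSpeed configuration, and the two namespaces are then merged. Below is a toy illustration of that split with made-up values; '--train-iters' stands in for a sat-style training flag and its exact name should be checked against sat's parser. In the upstream VisualGLM-6B setup, scripts like this are typically launched through a deepspeed command in a wrapper shell script rather than run directly, but that launcher is not part of this commit.

import argparse

# Illustration only: parse_known_args() keeps the script's own flags and returns
# the rest untouched so they can be forwarded to sat's get_args(args_list).
p = argparse.ArgumentParser(add_help=False)
p.add_argument('--max_source_length', type=int)
p.add_argument('--lora_rank', type=int, default=10)

known, rest = p.parse_known_args(['--max_source_length', '64', '--train-iters', '300'])
# known.max_source_length == 64, known.lora_rank == 10 (default)
# rest == ['--train-iters', '300'], left for the sat argument parser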