ModelZoo / ChatGLM-6B_pytorch

Commit 92060602, authored Sep 14, 2023 by zhaoying1
added lora training
parent 9f19b2a5

Showing 7 changed files with 415 additions and 12 deletions:

- README.md (+11, -1)
- infer_lora.py (+23, -0)
- ptuning/arguments.py (+3, -0)
- ptuning/lora_train.sh (+25, -0)
- ptuning/main.py (+43, -11)
- ptuning/trainer_lora.py (+88, -0)
- ptuning/trainer_seq2seq.py (+222, -0)
README.md (view file @ 92060602)

````diff
@@ -124,7 +124,17 @@ Hugging Face model download address:
 bash evaluate_ft.sh
 ```
-### Results
+## LoRA Fine-tuning
+### Single-node multi-GPU training
+```
+cd ptuning
+bash lora_train.sh
+```
+### LoRA Inference
+```
+python infer_lora.py
+```
+## Results
 - Training loss
 <div align="center">
 <img src="./ptuning/media/6B_ds_ft_bs32_accum1_4cards_zero3_5e-5.jpg" width="400" height="300">
````
infer_lora.py (new file, 0 → 100644; view file @ 92060602)

```python
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel, PeftConfig
import torch
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Path to the original base model
model_name_or_path = "/chatglm/ChatGLM2-6B-main/ChatGLM2-6B-main/zero_nlp-main/pretrained_model"
# Path where the trained LoRA adapter was saved
peft_model_id = "output-chatglm1/adgen-chatglm2-6b-lora_version/checkpoint-2"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    device_map='auto',
    torch_dtype=torch.float16,
)  # .half().cuda()
model = PeftModel.from_pretrained(model, peft_model_id)
model = model.eval()

response, history = model.chat(
    tokenizer,
    "类型#上衣*材质#牛仔布*颜色#白色*风格#简约*图案#刺绣*衣样式#外套*衣款式#破洞",
    history=[],
)
print(response)
```
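A possible follow-up, not part of this commit: once the adapter behaves as expected, it can be folded into the base weights so the merged model can be served without peft installed. The snippet below continues from the `model` and `tokenizer` objects built in infer_lora.py; `merge_and_unload()` is a peft API, but its support for `trust_remote_code` architectures such as ChatGLM varies by peft version, so treat this as a sketch. The output directory name is illustrative.

```python
# Sketch only: fold the LoRA weights into the base model and export a standalone checkpoint.
# Assumes `model` is the PeftModel and `tokenizer` the AutoTokenizer from infer_lora.py above.
merged = model.merge_and_unload()                    # base model with the LoRA deltas merged in
merged.save_pretrained("chatglm-6b-lora-merged")     # illustrative output path
tokenizer.save_pretrained("chatglm-6b-lora-merged")
```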
ptuning/arguments.py (view file @ 92060602)

```diff
@@ -59,6 +59,9 @@ class ModelArguments:
     prefix_projection: bool = field(default=False)
+
+    lora_r: Optional[int] = field(default=None)
+
 
 @dataclass
```
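For context on how the new field is consumed: main.py parses ModelArguments with transformers' HfArgumentParser, which turns each dataclass field into a CLI flag, so `--lora_r 32` in lora_train.sh ends up in `model_args.lora_r`. A minimal, self-contained sketch of that mechanism follows; `ToyModelArguments` is a stand-in, not the repository's class.

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ToyModelArguments:  # stand-in for ModelArguments in ptuning/arguments.py
    prefix_projection: bool = field(default=False)
    lora_r: Optional[int] = field(default=None)


parser = HfArgumentParser(ToyModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(args=["--lora_r", "32"])
print(model_args.lora_r)  # 32; when the flag is omitted it stays None and main.py skips the LoRA branch
```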
ptuning/lora_train.sh (new file, 0 → 100644; view file @ 92060602)

```bash
export CUDA_VISIBLE_DEVICES=1,2,3,4

deepspeed main.py \
    --deepspeed deepspeed.json \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
    --preprocessing_num_workers 10 \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path THUDM/chatglm-6b \
    --output_dir ./output_lora/adgen-chatglm2-6b-lora_version \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 10 \
    --logging_steps 10 \
    --save_steps 500 \
    --learning_rate 2e-5 \
    --lora_r 32
```
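The script hands DeepSpeed a `deepspeed.json` that is not part of this diff. With the four visible GPUs, the flags above give an effective batch size of 4 GPUs x 4 per device x 16 accumulation steps = 256 samples per optimizer step, and `--max_steps 10` makes this a short smoke-test run. Purely as a point of reference, a minimal ZeRO stage-2 config consistent with those flags might look like the sketch below; every value is an assumption, not the repository's actual deepspeed.json.

```python
# Hypothetical deepspeed.json, generated from Python so the numbers stay in sync with lora_train.sh.
import json

ds_config = {
    "train_micro_batch_size_per_gpu": 4,  # matches --per_device_train_batch_size
    "gradient_accumulation_steps": 16,    # matches --gradient_accumulation_steps
    "fp16": {"enabled": True},            # the model is cast to float16 in main.py
    "zero_optimization": {"stage": 2},    # assumed stage; the real config may differ
}

with open("deepspeed.json", "w") as f:
    json.dump(ds_config, f, indent=2)
```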
ptuning/main.py (view file @ 92060602)

```diff
@@ -42,7 +42,7 @@ from transformers import (
     Seq2SeqTrainingArguments,
     set_seed,
 )
 
-from trainer_seq2seq import Seq2SeqTrainer
+from trainer_seq2seq import Seq2SeqTrainer, Seq2SeqTrainerLora
 
 from arguments import ModelArguments, DataTrainingArguments
```

```diff
@@ -132,6 +132,26 @@ def main():
         # P-tuning v2
         model = model.half()
         model.transformer.prefix_encoder.float()
+    elif model_args.lora_r is not None:
+        from peft import LoraConfig, get_peft_model
+        LORA_R = model_args.lora_r
+        LORA_ALPHA = 16
+        LORA_DROPOUT = 0.05
+        TARGET_MODULES = [
+            "query_key_value",
+        ]
+        config = LoraConfig(
+            r=LORA_R,
+            lora_alpha=LORA_ALPHA,
+            target_modules=TARGET_MODULES,
+            lora_dropout=LORA_DROPOUT,
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+        model = model.to(torch.float16)
+        model = get_peft_model(model, config)
+        model.print_trainable_parameters()
     else:
         # Finetune
         model = model.float()
```

```diff
@@ -348,16 +368,28 @@ def main():
         data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
     )
     # Initialize our Trainer
-    trainer = Seq2SeqTrainer(
-        model=model,
-        args=training_args,
-        train_dataset=train_dataset if training_args.do_train else None,
-        eval_dataset=eval_dataset if training_args.do_eval else None,
-        tokenizer=tokenizer,
-        data_collator=data_collator,
-        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
-        save_prefixencoder=model_args.pre_seq_len is not None
-    )
+    if model_args.lora_r is None:
+        trainer = Seq2SeqTrainer(
+            model=model,
+            args=training_args,
+            train_dataset=train_dataset if training_args.do_train else None,
+            eval_dataset=eval_dataset if training_args.do_eval else None,
+            tokenizer=tokenizer,
+            data_collator=data_collator,
+            compute_metrics=compute_metrics if training_args.predict_with_generate else None,
+            save_prefixencoder=model_args.pre_seq_len is not None
+        )
+    else:
+        trainer = Seq2SeqTrainerLora(
+            model=model,
+            args=training_args,
+            train_dataset=train_dataset if training_args.do_train else None,
+            eval_dataset=eval_dataset if training_args.do_eval else None,
+            tokenizer=tokenizer,
+            data_collator=data_collator,
+            compute_metrics=compute_metrics if training_args.predict_with_generate else None,
+            save_lora_model=True if model_args.lora_r is not None else False
+        )
 
     # Training
     if training_args.do_train:
```
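To make the effect of `--lora_r` concrete: with `target_modules=["query_key_value"]`, LoRA attaches a rank-r pair of matrices to each fused attention projection, so the count printed by `print_trainable_parameters()` stays tiny compared to the 6B frozen base. The arithmetic below is a back-of-the-envelope sketch; the hidden size, layer count, and fused projection width are assumptions about ChatGLM-6B, not values taken from this commit.

```python
# Rough count of the extra trainable parameters LoRA adds (assumed ChatGLM-6B shapes).
hidden_size = 4096          # assumed hidden size
num_layers = 28             # assumed number of transformer layers
qkv_out = 3 * hidden_size   # assumed width of the fused query_key_value projection
lora_r = 32                 # matches --lora_r 32 in lora_train.sh

# Each adapted Linear(in_features, out_features) gains A (r x in) and B (out x r).
params_per_layer = lora_r * hidden_size + qkv_out * lora_r
total = params_per_layer * num_layers
print(f"~{total / 1e6:.1f}M trainable LoRA parameters")  # on the order of 15M vs ~6.2B frozen
```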
ptuning/trainer_lora.py (new file, 0 → 100644; view file @ 92060602)

```python
# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
"""
import os
from typing import Optional

import torch
from transformers import Trainer
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from transformers.utils import logging

logger = logging.get_logger(__name__)

WEIGHTS_NAME = "pytorch_model.bin"
TRAINING_ARGS_NAME = "training_args.bin"


class PrefixTrainer(Trainer):
    def __init__(self, *args, save_changed=False, save_lora_model=False, **kwargs):
        self.save_changed = save_changed
        self.save_lora_model = save_lora_model
        super().__init__(*args, **kwargs)

    def _save_lora(self, output_dir: Optional[str] = None) -> None:
        logger.info(f"Saving LORA model checkpoint to {output_dir}")
        self.model.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        save_lora = self.save_lora_model
        if save_lora:
            self._save_lora(output_dir)
        else:
            logger.info(f"Saving model checkpoint to {output_dir}")
            # Save a trained model and configuration using `save_pretrained()`.
            # They can then be reloaded using `from_pretrained()`
            if not isinstance(self.model, PreTrainedModel):
                if isinstance(unwrap_model(self.model), PreTrainedModel):
                    if state_dict is None:
                        state_dict = self.model.state_dict()
                    unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
                else:
                    logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                    if state_dict is None:
                        state_dict = self.model.state_dict()
                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
            else:
                if self.save_changed:
                    print("Saving PrefixEncoder")
                    state_dict = self.model.state_dict()
                    filtered_state_dict = {}
                    for k, v in self.model.named_parameters():
                        if v.requires_grad:
                            filtered_state_dict[k] = state_dict[k]
                    self.model.save_pretrained(output_dir, state_dict=filtered_state_dict)
                else:
                    print("Saving the whole model")
                    self.model.save_pretrained(output_dir, state_dict=state_dict)

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
```
ptuning/trainer_seq2seq.py (view file @ 92060602)

```diff
@@ -20,6 +20,7 @@ from torch.utils.data import Dataset
 from transformers.deepspeed import is_deepspeed_zero3_enabled
 from trainer import Trainer
+from trainer_lora import PrefixTrainer
 from transformers.trainer_utils import PredictionOutput
 from transformers.utils import logging
```

The second hunk (@@ -245,3 +246,224 @@) appends a new `Seq2SeqTrainerLora` class after the closing lines of the existing `Seq2SeqTrainer._pad_tensors_to_max_len`:

```python
class Seq2SeqTrainerLora(PrefixTrainer):

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        **gen_kwargs
    ) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init `compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
                not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
                method.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is `"eval"` (default)
            max_length (`int`, *optional*):
                The maximum target length to use when predicting with the generate method.
            num_beams (`int`, *optional*):
                Number of beams for beam search that will be used when predicting with the generate method. 1 means no
                beam search.
            gen_kwargs:
                Additional `generate` specific kwargs.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
        """
        gen_kwargs = gen_kwargs.copy()
        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
            gen_kwargs["max_length"] = self.args.generation_max_length
        gen_kwargs["num_beams"] = (
            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
        )
        self._gen_kwargs = gen_kwargs

        return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

    def predict(
        self,
        test_dataset: Dataset,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "test",
        **gen_kwargs
    ) -> PredictionOutput:
        """
        Run prediction and returns predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
        will also return metrics, like in `evaluate()`.

        Args:
            test_dataset (`Dataset`):
                Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
                `model.forward()` method are automatically removed. Has to implement the method `__len__`
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is `"eval"` (default)
            max_length (`int`, *optional*):
                The maximum target length to use when predicting with the generate method.
            num_beams (`int`, *optional*):
                Number of beams for beam search that will be used when predicting with the generate method. 1 means no
                beam search.
            gen_kwargs:
                Additional `generate` specific kwargs.

        <Tip>

        If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
        padding in a token classification task) the predictions will be padded (on the right) to allow for
        concatenation into one array. The padding index is -100.

        </Tip>

        Returns: *NamedTuple* A namedtuple with the following keys:

            - predictions (`np.ndarray`): The predictions on `test_dataset`.
            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
              labels).
        """
        gen_kwargs = gen_kwargs.copy()
        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
            gen_kwargs["max_length"] = self.args.generation_max_length
        gen_kwargs["num_beams"] = (
            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
        )
        self._gen_kwargs = gen_kwargs

        return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).
        """

        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)

        # XXX: adapt synced_gpus for fairscale as well
        gen_kwargs = self._gen_kwargs.copy()
        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
            gen_kwargs["max_length"] = self.model.config.max_length
        gen_kwargs["num_beams"] = (
            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
        )
        default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
        gen_kwargs["synced_gpus"] = (
            gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
        )

        if "attention_mask" in inputs:
            gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
        if "position_ids" in inputs:
            gen_kwargs["position_ids"] = inputs.get("position_ids", None)
        if "global_attention_mask" in inputs:
            gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)

        # prepare generation inputs
        # some encoder-decoder models can have varying encoder's and thus
        # varying model input names
        if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
            generation_inputs = inputs[self.model.encoder.main_input_name]
        else:
            generation_inputs = inputs[self.model.main_input_name]

        gen_kwargs["input_ids"] = generation_inputs
        generated_tokens = self.model.generate(**gen_kwargs)
        generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:]

        # in case the batch is shorter than max length, the output should be padded
        if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
        elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
            gen_kwargs["max_new_tokens"] + 1
        ):
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)

        loss = None

        if self.args.prediction_loss_only:
            return (loss, None, None)

        if has_labels:
            labels = inputs["labels"]
            if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
                labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
            elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
                gen_kwargs["max_new_tokens"] + 1
            ):
                labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
        else:
            labels = None

        return (loss, generated_tokens, labels)

    def _pad_tensors_to_max_len(self, tensor, max_length):
        if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
            # If PAD token is not defined at least EOS token has to be defined
            pad_token_id = (
                self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
            )
        else:
            if self.model.config.pad_token_id is not None:
                pad_token_id = self.model.config.pad_token_id
            else:
                raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")

        padded_tensor = pad_token_id * torch.ones(
            (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
        )
        padded_tensor[:, : tensor.shape[-1]] = tensor
        return padded_tensor
```
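As a quick sanity check of the padding behaviour implemented in `_pad_tensors_to_max_len` above, the same logic can be exercised standalone; the helper below re-implements it outside the Trainer so it runs without the repository (the function name and default pad id are illustrative).

```python
import torch


def pad_to_max_len(tensor, max_length, pad_token_id=0):
    # Right-pad each row with pad_token_id up to max_length, as _pad_tensors_to_max_len does.
    padded = pad_token_id * torch.ones(
        (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
    )
    padded[:, : tensor.shape[-1]] = tensor
    return padded


batch = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(pad_to_max_len(batch, 5))
# tensor([[1, 2, 3, 0, 0],
#         [4, 5, 6, 0, 0]])
```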