ModelZoo / ChatGLM-6B_pytorch
Commit 95306eb9, authored Jun 07, 2023 by yuguo960516yuguo
Commit message: 1.0
Parent: a10b8407
Changes: 9 changed files, with 14635 additions and 2 deletions (+14635 -2)

README.md: +65 -2
model.properties: +8 -0
simple-pretrain/The-Lord-of-the-Rings-1.txt: +8091 -0
simple-pretrain/convert.py: +25 -0
simple-pretrain/modeling_chatglm.py: +1450 -0
simple-pretrain/ptuning/The-Lord-of-the-Rings-1.json: +4508 -0
simple-pretrain/ptuning/deepspeed.json: +31 -0
simple-pretrain/ptuning/ds_pretrain.sh: +26 -0
simple-pretrain/ptuning/main.py: +431 -0
README.md
@@ -102,10 +102,73 @@ ChatGLM-6B is an open-source, Chinese-English bilingual conversational language model released by Tsinghua University
Run the following command:
python cli_demo.py
The program starts an interactive dialogue in the command line. Type an instruction and press Enter to generate a reply, enter clear to clear the conversation history, and enter stop to terminate the program.
## Re-pretraining
Because the current [GLM-130B](https://github.com/THUDM/GLM-130B#news) is structurally very similar to ChatGLM, users who have already trained GLM-130B can bring the parameter count up to 130B by modifying ChatGLM's config.json and stacking parameters. To meet users' need to re-pretrain ChatGLM, this project adds a simple-pretrain directory, which aims to provide a pretraining example with minimal changes.
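As a rough guide (the field names here are an assumption based on the standard ChatGLM-6B config.json, not something this commit changes): "stacking parameters" means scaling entries such as num_layers, hidden_size, and num_attention_heads toward the GLM-130B sizes. The simple-pretrain example itself keeps the original 6B configuration and only changes how the weights are initialized.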
The pretraining steps are as follows:
1. Move the files under simple-pretrain/ptuning into this ptuning directory, replacing the corresponding files.
2. Move modeling_chatglm.py into the directory where the [ChatGLM model](https://huggingface.co/THUDM/chatglm-6b) is located, replacing the original modeling_chatglm.py.
3. In this ptuning directory, run:
```
bash ds_pretrain.sh
```
Note: convert.py converts the original txt data into the JSON-format dataset that ChatGLM can consume. This example uses the book The Lord of the Rings (Part 1) as the pretraining corpus.
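For illustration only (the text below is made up, not taken from the dataset): each non-empty line of the txt file becomes one JSON record, with the first character used as the prompt and the remainder as the response, e.g. the line `Hello world` becomes `{"prompt": "H", "response": "ello world"}`.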
### Experiment settings
```
LR=1e-5
MASTER_PORT=$(shuf -n 1 -i 10000-65535)
HIP_VISIBLE_DEVICES=0,1,2,3 deepspeed --num_gpus=4 --master_port $MASTER_PORT main.py \
--deepspeed deepspeed.json \
--do_train \
--train_file The-Lord-of-the-Rings-1.json \
--prompt_column prompt \
--response_column response \
--overwrite_cache \
--model_name_or_path THUDM/chatglm-6b \
--output_dir ./output/pretrain \
--overwrite_output_dir \
--max_source_length 8 \
--max_target_length 128 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--predict_with_generate \
--max_steps 2000 \
--logging_steps 5 \
--save_steps 1000 \
--learning_rate $LR \
--fp16
```
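For reference, these settings give an effective global batch size of 16 per device × 4 GPUs × 4 gradient-accumulation steps = 256 sequences per optimizer step. max_source_length is only 8, presumably because convert.py emits single-character prompts, so nearly all of the sequence budget goes to the 128-token targets.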
### Training loss convergence
Because the pretraining dataset in this example is small, the loss drops to a fairly low level, around 0.1.
## Reinforcement learning (RLHF) fine-tuning options
There are currently two options for fine-tuning ChatGLM with reinforcement learning on DCU:
- Use LoRA and update only the low-rank adapter layers; see the project: https://github.com/hiyouga/ChatGLM-Efficient-Tuning/blob/main/examples/covid_doctor.md
- Use the DeepSpeed-Chat approach for full-parameter fine-tuning; the adaptation is complete and you are welcome to try it: https://github.com/yuguo-Jack/ChatGLM-6B-in-DeepSpeed-Chat
## Source repository and issue reporting
https://developer.hpccube.com/codes/modelzoo/chatglm
## References
[THUDM/ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B/tree/main)
model.properties
0 → 100644
# Model name
modelName=ChatGLM-6B
# Model description
modelDescription=ChatGLM-6B based on the PyTorch framework
# Application scenarios (multiple tags separated by commas)
appScenario=训练,推理,train,inference,nlp,智能聊天助手
# Framework type (multiple tags separated by commas)
frameType=Pytorch
simple-pretrain/The-Lord-of-the-Rings-1.txt
0 → 100644
This diff is collapsed.
simple-pretrain/convert.py
0 → 100644
import json


def convert_txt_to_json(txt_file):
    json_data = []
    with open(txt_file, 'r', encoding='gb18030') as file:
        for line in file:
            line = line.strip()
            if line:
                prompt = line[0]
                response = line[1:]
                json_entry = {'prompt': prompt, 'response': response}
                json_data.append(json_entry)
    return json_data


txt_file = 'The-Lord-of-the-Rings-1.txt'  # replace with the path to your text file
json_data = convert_txt_to_json(txt_file)

# Write the JSON data to a file
json_file = 'The-Lord-of-the-Rings-1.json'  # path of the output JSON file
with open(json_file, 'w', encoding='utf-8') as file:
    for entry in json_data:
        json.dump(entry, file, ensure_ascii=False)
        file.write('\n')
simple-pretrain/modeling_chatglm.py
0 → 100644
This diff is collapsed.
simple-pretrain/ptuning/The-Lord-of-the-Rings-1.json
0 → 100644
This diff is collapsed.
simple-pretrain/ptuning/deepspeed.json
0 → 100644
{
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "zero_allow_untested_optimizer": true,
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "initial_scale_power": 16,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "contiguous_gradients": true,
    "stage3_gather_16bit_weights_on_model_save": true
  }
}
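A brief reading of this configuration (interpretation, not stated in the commit): the "auto" values are filled in by the Hugging Face Trainer from the command-line arguments (for example, train_micro_batch_size_per_gpu is taken from --per_device_train_batch_size), and ZeRO stage 3 with optimizer and parameter offload to CPU is what lets full-parameter training of the 6B model fit on the four devices used in ds_pretrain.sh.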
simple-pretrain/ptuning/ds_pretrain.sh
0 → 100644
LR=1e-5

MASTER_PORT=$(shuf -n 1 -i 10000-65535)

HIP_VISIBLE_DEVICES=0,1,2,3 deepspeed --num_gpus=4 --master_port $MASTER_PORT main.py \
    --deepspeed deepspeed.json \
    --do_train \
    --train_file The-Lord-of-the-Rings-1.json \
    --prompt_column prompt \
    --response_column response \
    --overwrite_cache \
    --model_name_or_path THUDM/chatglm-6b \
    --output_dir ./output/pretrain \
    --overwrite_output_dir \
    --max_source_length 8 \
    --max_target_length 128 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --predict_with_generate \
    --max_steps 2000 \
    --logging_steps 5 \
    --save_steps 1000 \
    --learning_rate $LR \
    --fp16
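Two details worth noting in this launcher: HIP_VISIBLE_DEVICES is the ROCm/HIP counterpart of CUDA_VISIBLE_DEVICES and selects which DCU devices DeepSpeed may use, and `shuf -n 1 -i 10000-65535` picks a random master port so that concurrent runs on the same node are unlikely to collide.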
simple-pretrain/ptuning/main.py
0 → 100644
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
import logging
import os
import sys
import json

import numpy as np
from datasets import load_dataset
import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import torch

import transformers
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    set_seed,
)
from trainer_seq2seq import Seq2SeqTrainer

from arguments import ModelArguments, DataTrainingArguments

logger = logging.getLogger(__name__)


def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    # datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Load dataset
    data_files = {}
    if data_args.train_file is not None:
        data_files["train"] = data_args.train_file
        extension = data_args.train_file.split(".")[-1]
    if data_args.validation_file is not None:
        data_files["validation"] = data_args.validation_file
        extension = data_args.validation_file.split(".")[-1]
    if data_args.test_file is not None:
        data_files["test"] = data_args.test_file
        extension = data_args.test_file.split(".")[-1]

    raw_datasets = load_dataset(
        extension,
        data_files=data_files,
        cache_dir=model_args.cache_dir,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    config.pre_seq_len = model_args.pre_seq_len
    config.prefix_projection = model_args.prefix_projection

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)

    if model_args.ptuning_checkpoint is not None:
        # Evaluation
        # Loading extra state dict of prefix encoder
        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
        prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    else:
        # model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
        model = AutoModel.from_config(config=config, trust_remote_code=True)

    if model_args.quantization_bit is not None:
        print(f"Quantized to {model_args.quantization_bit} bit")
        model = model.quantize(model_args.quantization_bit)
    if model_args.pre_seq_len is not None:
        # P-tuning v2
        model = model.half()
        model.transformer.prefix_encoder.float()
    else:
        # Finetune
        model = model.float()

    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    elif training_args.do_eval:
        column_names = raw_datasets["validation"].column_names
    elif training_args.do_predict:
        column_names = raw_datasets["test"].column_names
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

    # Get the column names for input/target.
    prompt_column = data_args.prompt_column
    response_column = data_args.response_column
    history_column = data_args.history_column

    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length

    def preprocess_function_eval(examples):
        inputs, targets = [], []
        for i in range(len(examples[prompt_column])):
            if examples[prompt_column][i] and examples[response_column][i]:
                query = examples[prompt_column][i]
                if history_column is None or len(examples[history_column][i]) == 0:
                    prompt = query
                else:
                    prompt = ""
                    history = examples[history_column][i]
                    for turn_idx, (old_query, response) in enumerate(history):
                        prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
                    prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
                inputs.append(prompt)
                targets.append(examples[response_column][i])

        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
        labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)

        if data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]
        model_inputs["labels"] = labels["input_ids"]

        return model_inputs

    def preprocess_function_train(examples):
        max_seq_length = data_args.max_source_length + data_args.max_target_length

        model_inputs = {
            "input_ids": [],
            "labels": [],
        }
        for i in range(len(examples[prompt_column])):
            if examples[prompt_column][i] and examples[response_column][i]:
                query, answer = examples[prompt_column][i], examples[response_column][i]

                if history_column is None:
                    prompt = query
                else:
                    prompt = ""
                    history = examples[history_column][i]
                    for turn_idx, (old_query, response) in enumerate(history):
                        prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
                    prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)

                prompt = prefix + prompt
                a_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
                b_ids = tokenizer.encode(text=answer, add_special_tokens=False)

                if len(a_ids) > data_args.max_source_length - 1:
                    a_ids = a_ids[: data_args.max_source_length - 1]

                if len(b_ids) > data_args.max_target_length - 2:
                    b_ids = b_ids[: data_args.max_target_length - 2]

                input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)

                context_length = input_ids.index(tokenizer.bos_token_id)
                mask_position = context_length - 1
                labels = [-100] * context_length + input_ids[mask_position + 1:]

                pad_len = max_seq_length - len(input_ids)
                input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
                labels = labels + [tokenizer.pad_token_id] * pad_len
                if data_args.ignore_pad_token_for_loss:
                    labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]

                model_inputs["input_ids"].append(input_ids)
                model_inputs["labels"].append(labels)

        return model_inputs

    def print_dataset_example(example):
        print("input_ids", example["input_ids"])
        print("inputs", tokenizer.decode(example["input_ids"]))
        print("label_ids", example["labels"])
        print("labels", tokenizer.decode(example["labels"]))

    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                preprocess_function_train,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on train dataset",
            )
        print_dataset_example(train_dataset[0])

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = raw_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
            eval_dataset = eval_dataset.map(
                preprocess_function_eval,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )
        print_dataset_example(eval_dataset[0])

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = raw_datasets["test"]
        if data_args.max_predict_samples is not None:
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))
        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
            predict_dataset = predict_dataset.map(
                preprocess_function_eval,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on prediction dataset",
            )
        print_dataset_example(predict_dataset[0])

    # Data collator
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=None,
        padding=False
    )

    # Metric
    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        if data_args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        score_dict = {
            "rouge-1": [],
            "rouge-2": [],
            "rouge-l": [],
            "bleu-4": []
        }
        for pred, label in zip(decoded_preds, decoded_labels):
            hypothesis = list(jieba.cut(pred))
            reference = list(jieba.cut(label))
            rouge = Rouge()
            scores = rouge.get_scores(' '.join(hypothesis), ' '.join(reference))
            result = scores[0]

            for k, v in result.items():
                score_dict[k].append(round(v["f"] * 100, 4))
            bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
            score_dict["bleu-4"].append(round(bleu_score * 100, 4))

        for k, v in score_dict.items():
            score_dict[k] = float(np.mean(v))
        return score_dict

    # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = (
        training_args.generation_max_length
        if training_args.generation_max_length is not None
        else data_args.val_max_target_length
    )
    training_args.generation_num_beams = (
        data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
    )
    # Initialize our Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
        save_prefixencoder=model_args.pre_seq_len is not None
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        # elif last_checkpoint is not None:
        #     checkpoint = last_checkpoint
        model.gradient_checkpointing_enable()
        model.enable_input_require_grads()
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        # trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95)
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")
        predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95)
        metrics = predict_results.metrics
        max_predict_samples = (
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
        )
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

        if trainer.is_world_process_zero():
            if training_args.predict_with_generate:
                predictions = tokenizer.batch_decode(
                    predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                predictions = [pred.strip() for pred in predictions]
                labels = tokenizer.batch_decode(
                    predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                labels = [label.strip() for label in labels]
                output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
                with open(output_prediction_file, "w", encoding="utf-8") as writer:
                    for p, l in zip(predictions, labels):
                        res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
                        writer.write(f"{res}\n")
    return results


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
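One point worth calling out about this main.py relative to the stock p-tuning script: for re-pretraining, the model is built from the configuration alone (the from_pretrained branch is commented out), so training starts from randomly initialized weights rather than from the released ChatGLM-6B checkpoint. A minimal sketch of that choice, using the same transformers calls as above:

```python
from transformers import AutoConfig, AutoModel

# Same pattern as main.py: load only the config of THUDM/chatglm-6b ...
config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# ... and instantiate the architecture with random weights (pretrain from scratch).
model = AutoModel.from_config(config=config, trust_remote_code=True)

# To continue from the released weights instead, main.py's commented-out branch would be:
# model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
```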