ModelZoo / UMT5_pytorch
Commit 012579df, authored Aug 22, 2024 by wanglch
Delete multi_dcu_train.py
Parent: 512350b9

Showing 1 changed file with 0 additions and 209 deletions (+0, -209).

multi_dcu_train.py (deleted, 100644 → 0)
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import AdamW, get_scheduler
from tqdm.auto import tqdm
from rouge import Rouge
import random
import numpy as np
import os
import json
from torch import nn
from datetime import datetime


def seed_everything(seed=1029):
    # Fix all random seeds so runs are reproducible.
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


class LCSTS(Dataset):
    # LCSTS summarization dataset: each line is "<title>!=!<content>".
    def __init__(self, data_file):
        self.data = self.load_data(data_file)

    def load_data(self, data_file):
        Data = {}
        with open(data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if idx >= max_dataset_size:
                    break
                items = line.strip().split('!=!')
                assert len(items) == 2
                Data[idx] = {
                    'title': items[0],
                    'content': items[1]
                }
        return Data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def collate_fn(batch_samples):
    # Tokenize contents as encoder inputs and titles as targets;
    # label tokens after the first EOS are set to -100 so the loss ignores them.
    batch_inputs, batch_targets = [], []
    for sample in batch_samples:
        batch_inputs.append(sample['content'])
        batch_targets.append(sample['title'])
    batch_data = tokenizer(
        batch_inputs,
        padding=True,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt"
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            batch_targets,
            padding=True,
            max_length=max_target_length,
            truncation=True,
            return_tensors="pt"
        )["input_ids"]
        # The model is wrapped in DataParallel, so reach the underlying model via model.module.
        batch_data['decoder_input_ids'] = model.module.prepare_decoder_input_ids_from_labels(labels)
        end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]
        for idx, end_idx in enumerate(end_token_index):
            labels[idx][end_idx + 1:] = -100
        batch_data['labels'] = labels
    return batch_data


def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
    # One training epoch; the progress bar shows the running average loss.
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f'loss: {0:>7f}')
    finish_batch_num = (epoch - 1) * len(dataloader)

    model.train()
    for batch, batch_data in enumerate(dataloader, start=1):
        batch_data = {k: v.to(device) for k, v in batch_data.items()}
        outputs = model(**batch_data)
        loss = outputs.loss
        # DataParallel returns one loss per device; average them into a scalar.
        loss = loss.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()
        progress_bar.set_description(f'loss: {total_loss / (finish_batch_num + batch):>7f}')
        progress_bar.update(1)
    return total_loss


def test_loop(dataloader, model):
    # Evaluate with beam-search generation and report character-level ROUGE.
    preds, labels = [], []

    model.eval()
    for batch_data in tqdm(dataloader):
        batch_data = {k: v.to(device) for k, v in batch_data.items()}
        with torch.no_grad():
            # If the model is wrapped in DataParallel, the original model is available as model.module.
            generated_tokens = model.module.generate(
                batch_data["input_ids"],
                attention_mask=batch_data["attention_mask"],
                max_length=max_target_length,
                num_beams=beam_size,
                no_repeat_ngram_size=no_repeat_ngram_size,
            ).cpu().numpy()
        if isinstance(generated_tokens, tuple):
            generated_tokens = generated_tokens[0]
        label_tokens = batch_data["labels"].cpu().numpy()

        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)

        # Join characters with spaces so ROUGE scores at the character level (suitable for Chinese).
        preds += [' '.join(pred.strip()) for pred in decoded_preds]
        labels += [' '.join(label.strip()) for label in decoded_labels]

    scores = rouge.get_scores(hyps=preds, refs=labels, avg=True)
    result = {key: value['f'] * 100 for key, value in scores.items()}
    result['avg'] = np.mean(list(result.values()))
    print(f"Rouge1: {result['rouge-1']:>0.2f} Rouge2: {result['rouge-2']:>0.2f} RougeL: {result['rouge-l']:>0.2f}\n")
    return result


if __name__ == '__main__':
    # Select the DCU cards to use (HIP is the ROCm counterpart of CUDA_VISIBLE_DEVICES).
    os.environ["HIP_VISIBLE_DEVICES"] = "4,5"

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Using {device} device')

    seed_everything(5)
    rouge = Rouge()

    max_dataset_size = 200000
    max_input_length = 512
    max_target_length = 32
    batch_size = 16
    learning_rate = 1e-5
    epoch_num = 1
    beam_size = 4
    no_repeat_ngram_size = 2

    folder_path = "./saves/train_dtk_weights"
    # Create the checkpoint folder if it does not exist yet.
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)

    train_data = LCSTS('./data/lcsts_tsv/data1.tsv')
    valid_data = LCSTS('./data/lcsts_tsv/data2.tsv')

    model_checkpoint = "./umt5_base"
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

    # If more than one GPU is visible, wrap the model with nn.DataParallel.
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model).to(device)

    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    optimizer = AdamW(model.parameters(), lr=learning_rate)
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=epoch_num * len(train_dataloader),
    )

    total_loss = 0.
    best_avg_rouge = 0.
    for t in range(epoch_num):
        print(f"Epoch {t + 1}/{epoch_num}\n-------------------------------")
        total_loss = train_loop(train_dataloader, model, optimizer, lr_scheduler, t + 1, total_loss)
        valid_rouge = test_loop(valid_dataloader, model)
        rouge_avg = valid_rouge['avg']
        if rouge_avg > best_avg_rouge:
            best_avg_rouge = rouge_avg
            print('saving new weights...\n')
            weight_path = f'./saves/train_dtk_weights/epoch_{t + 1}_valid_rouge_{rouge_avg:0.4f}_model_dtk_weights.bin'
            torch.save(model.state_dict(), weight_path)

    # Reload the best weights saved during training.
    state_dict = torch.load(weight_path)
    model.load_state_dict(state_dict)

    # Get the current date and time to name the exported model directory.
    now = datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    new_model_path = f'./saves/umt5_{timestamp}'
    model.module.save_pretrained(new_model_path)
    tokenizer.save_pretrained(new_model_path)

    print("Done!")