Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
UMT5_pytorch
Commits
dcb1010d
Commit
dcb1010d
authored
May 22, 2024
by
wanglch
Browse files
Update single_dcu_train.py
parent
938ed894
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
single_dcu_train.py
single_dcu_train.py
+4
-4
No files found.
single_dcu_train.py
View file @
dcb1010d
...
...
@@ -156,10 +156,10 @@ if __name__=='__main__':
# 如果不存在,则创建文件夹
os
.
makedirs
(
folder_path
)
train_data
=
LCSTS
(
'
/umt5
/data/lcsts_tsv/data1.tsv'
)
valid_data
=
LCSTS
(
'
/umt5
/data/lcsts_tsv/data2.tsv'
)
train_data
=
LCSTS
(
'
.
/data/lcsts_tsv/data1.tsv'
)
valid_data
=
LCSTS
(
'
.
/data/lcsts_tsv/data2.tsv'
)
model_checkpoint
=
"
/umt5
/umt5_base"
model_checkpoint
=
"
.
/umt5_base"
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_checkpoint
)
model
=
AutoModelForSeq2SeqLM
.
from_pretrained
(
model_checkpoint
)
...
...
@@ -194,7 +194,7 @@ if __name__=='__main__':
if
rouge_avg
>
best_avg_rouge
:
best_avg_rouge
=
rouge_avg
print
(
'saving new weights...
\n
'
)
weight_path
=
f
'
/utm5
/saves/train_dtk_weights/epoch_
{
t
+
1
}
_valid_rouge_
{
rouge_avg
:
0.4
f
}
_model_dtk_weights.bin'
weight_path
=
f
'
.
/saves/train_dtk_weights/epoch_
{
t
+
1
}
_valid_rouge_
{
rouge_avg
:
0.4
f
}
_model_dtk_weights.bin'
torch
.
save
(
model
.
state_dict
(),
weight_path
)
# 加载训练后的权重
state_dict
=
torch
.
load
(
weight_path
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment