Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
UMT5_pytorch
Commits
a5866d29
Commit
a5866d29
authored
Aug 22, 2024
by
wanglch
Browse files
Update train_single_dcu.py
parent
dd221315
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
5 deletions
+5
-5
train_single_dcu.py
train_single_dcu.py
+5
-5
No files found.
train_single_dcu.py
View file @
a5866d29
...
...
@@ -149,10 +149,10 @@ if __name__=='__main__':
beam_size
=
4
no_repeat_ngram_size
=
2
train_data
=
LCSTS
(
'
/home/wanglch/projects
/Umt5/data/lcsts_tsv/data1.tsv'
)
valid_data
=
LCSTS
(
'
/home/wanglch/projects
/Umt5/data/lcsts_tsv/data2.tsv'
)
train_data
=
LCSTS
(
'
..
/Umt5/data/lcsts_tsv/data1.tsv'
)
valid_data
=
LCSTS
(
'
..
/Umt5/data/lcsts_tsv/data2.tsv'
)
model_checkpoint
=
"
/home/wanglch/projects
/Umt5/umt5_base"
model_checkpoint
=
"
..
/Umt5/umt5_base"
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_checkpoint
,
trust_remote_code
=
True
)
model
=
AutoModelForSeq2SeqLM
.
from_pretrained
(
model_checkpoint
,
trust_remote_code
=
True
)
model
=
model
.
to
(
device
)
...
...
@@ -180,7 +180,7 @@ if __name__=='__main__':
if
rouge_avg
>
best_avg_rouge
:
best_avg_rouge
=
rouge_avg
print
(
'saving new weights...
\n
'
)
weight_path
=
f
'
/home/wanglch/projects
/saves/utm5/train_dtk_weights/epoch_
{
t
+
1
}
_valid_rouge_
{
rouge_avg
:
0.4
f
}
_model_dtk_weights.bin'
weight_path
=
f
'
..
/saves/utm5/train_dtk_weights/epoch_
{
t
+
1
}
_valid_rouge_
{
rouge_avg
:
0.4
f
}
_model_dtk_weights.bin'
torch
.
save
(
model
.
state_dict
(),
weight_path
)
# 加载训练后的权重
state_dict
=
torch
.
load
(
weight_path
)
...
...
@@ -188,7 +188,7 @@ if __name__=='__main__':
# 获取当前的日期和时间
now
=
datetime
.
now
()
timestamp
=
now
.
strftime
(
"%Y%m%d_%H%M%S"
)
new_model_path
=
f
'
/home/wanglch/projects
/saves/utm5/train_dtk_weights/umt5_
{
timestamp
}
'
new_model_path
=
f
'
..
/saves/utm5/train_dtk_weights/umt5_
{
timestamp
}
'
model
.
save_pretrained
(
new_model_path
)
tokenizer
.
save_pretrained
(
new_model_path
)
print
(
"Done!"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment