Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
UMT5_pytorch
Commits
57fe5c63
Commit
57fe5c63
authored
May 21, 2024
by
wanglch
Browse files
Update multi_dcu_train.py
parent
3d7eb65b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
2 deletions
+2
-2
multi_dcu_train.py
multi_dcu_train.py
+2
-2
No files found.
multi_dcu_train.py
View file @
57fe5c63
...
@@ -81,7 +81,7 @@ def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
...
@@ -81,7 +81,7 @@ def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
model
.
train
()
model
.
train
()
for
batch
,
batch_data
in
enumerate
(
dataloader
,
start
=
1
):
for
batch
,
batch_data
in
enumerate
(
dataloader
,
start
=
1
):
batch_data
=
batch_data
.
to
(
device
)
batch_data
=
{
k
:
v
.
to
(
device
)
for
k
,
v
in
batch_data
.
items
()}
outputs
=
model
(
**
batch_data
)
outputs
=
model
(
**
batch_data
)
loss
=
outputs
.
loss
loss
=
outputs
.
loss
loss
=
loss
.
mean
()
loss
=
loss
.
mean
()
...
@@ -101,7 +101,7 @@ def test_loop(dataloader, model):
...
@@ -101,7 +101,7 @@ def test_loop(dataloader, model):
model
.
eval
()
model
.
eval
()
for
batch_data
in
tqdm
(
dataloader
):
for
batch_data
in
tqdm
(
dataloader
):
batch_data
=
batch_data
.
to
(
device
)
batch_data
=
{
k
:
v
.
to
(
device
)
for
k
,
v
in
batch_data
.
items
()}
with
torch
.
no_grad
():
with
torch
.
no_grad
():
# 如果你使用了 DataParallel,你可以通过访问 model.module 来获取原始模型
# 如果你使用了 DataParallel,你可以通过访问 model.module 来获取原始模型
generated_tokens
=
model
.
module
.
generate
(
generated_tokens
=
model
.
module
.
generate
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment