renzhc / diffusers_dcu
"docs/vscode:/vscode.git/clone" did not exist on "95ddbc19fe68ac355fbc0bc393062b1d44a080ae"
Commit 1b1d6444 (Unverified)
Authored Sep 01, 2022 by Suraj Patil; committed by GitHub, Sep 01, 2022
[train_unconditional] fix gradient accumulation. (#308)
fix grad accum
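In short: the script previously advanced the progress bar and `global_step` once per dataloader batch, but with gradient accumulation the optimizer only steps every `gradient_accumulation_steps` batches. This commit passes `gradient_accumulation_steps` to the `Accelerator`, sizes the per-epoch progress bar by `num_update_steps_per_epoch` instead of `len(train_dataloader)`, and advances both counters only when `accelerator.sync_gradients` indicates a real optimization step was performed.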
parent 47242509
Showing 1 changed file with 10 additions and 3 deletions (+10 −3)
examples/unconditional_image_generation/train_unconditional.py
```diff
@@ ... @@
 import argparse
+import math
 import os
 
 import torch
```
```diff
@@ -29,6 +30,7 @@ logger = get_logger(__name__)
 def main(args):
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
     accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with="tensorboard",
         logging_dir=logging_dir,
```
```diff
@@ -105,6 +107,8 @@ def main(args):
         model, optimizer, train_dataloader, lr_scheduler
     )
 
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+
     ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
 
     if args.push_to_hub:
```
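With accumulation, the number of optimizer updates per epoch is the batch count divided by the accumulation factor, rounded up. As an illustration (numbers assumed, not the script's defaults): 700 batches per epoch with `--gradient_accumulation_steps 4` gives `math.ceil(700 / 4) = 175` updates, which is the correct progress-bar total.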
```diff
@@ -117,7 +121,7 @@ def main(args):
     global_step = 0
     for epoch in range(args.num_epochs):
         model.train()
-        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
+        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
         progress_bar.set_description(f"Epoch {epoch}")
         for step, batch in enumerate(train_dataloader):
             clean_images = batch["input"]
```
```diff
@@ -146,13 +150,16 @@ def main(args):
                     ema_model.step(model)
                 optimizer.zero_grad()
 
-            progress_bar.update(1)
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                global_step += 1
+
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
             if args.use_ema:
                 logs["ema_decay"] = ema_model.decay
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
-            global_step += 1
         progress_bar.close()
 
         accelerator.wait_for_everyone()
```
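For context, below is a minimal, self-contained sketch of the counting pattern this commit introduces. It assumes accelerate's gradient-accumulation support (`Accelerator(gradient_accumulation_steps=...)`, `accelerator.accumulate(...)`, and `accelerator.sync_gradients`, available since accelerate 0.11); the model, optimizer, and dataloader are toy stand-ins for the training objects in train_unconditional.py, and the loss is a dummy.

```python
import math

import torch
from accelerate import Accelerator

# Toy stand-ins for the UNet, optimizer, and dataloader built in
# train_unconditional.py; the loss below is a dummy for illustration.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
train_dataloader = torch.utils.data.DataLoader(torch.randn(32, 8), batch_size=4)

gradient_accumulation_steps = 2
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

# One optimizer update happens only every `gradient_accumulation_steps`
# batches, so the progress-bar total should count updates, not batches.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)

global_step = 0
for batch in train_dataloader:
    # Inside `accumulate`, gradient sync and the real optimizer step are
    # skipped on accumulation-only batches.
    with accelerator.accumulate(model):
        loss = model(batch).pow(2).mean()
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    # True only on batches where an optimization step actually ran,
    # which is exactly the condition the commit gates `global_step` on.
    if accelerator.sync_gradients:
        global_step += 1

# 8 batches / 2 accumulation steps -> 4 updates; both counters agree.
print(global_step, num_update_steps_per_epoch)
```

Run as a single process, this prints `4 4`. Without the `sync_gradients` gate, `global_step` would reach 8 and disagree with the progress-bar total of 4, which is the mismatch this commit fixes.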