Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
ebd44957
Unverified
Commit
ebd44957
authored
Mar 14, 2023
by
Will Berman
Committed by
GitHub
Mar 14, 2023
Browse files
image generation main process checks (#2631)
parent
e2d9a9be
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
25 additions
and
24 deletions
+25
-24
examples/dreambooth/train_dreambooth.py
examples/dreambooth/train_dreambooth.py
+5
-4
examples/research_projects/mulit_token_textual_inversion/textual_inversion.py
...ojects/mulit_token_textual_inversion/textual_inversion.py
+1
-1
examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
...ojects/onnxruntime/textual_inversion/textual_inversion.py
+1
-1
examples/text_to_image/train_text_to_image_lora.py
examples/text_to_image/train_text_to_image_lora.py
+13
-14
examples/textual_inversion/textual_inversion.py
examples/textual_inversion/textual_inversion.py
+5
-4
No files found.
examples/dreambooth/train_dreambooth.py
View file @
ebd44957
...
...
@@ -1000,13 +1000,14 @@ def main(args):
progress_bar
.
update
(
1
)
global_step
+=
1
if
global_step
%
args
.
checkpointing_steps
==
0
:
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
if
global_step
%
args
.
checkpointing_steps
==
0
:
save_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"checkpoint-
{
global_step
}
"
)
accelerator
.
save_state
(
save_path
)
logger
.
info
(
f
"Saved state to
{
save_path
}
"
)
if
args
.
validation_prompt
is
not
None
and
global_step
%
args
.
validation_steps
==
0
:
log_validation
(
text_encoder
,
tokenizer
,
unet
,
vae
,
args
,
accelerator
,
weight_dtype
,
epoch
)
if
args
.
validation_prompt
is
not
None
and
global_step
%
args
.
validation_steps
==
0
:
log_validation
(
text_encoder
,
tokenizer
,
unet
,
vae
,
args
,
accelerator
,
weight_dtype
,
epoch
)
logs
=
{
"loss"
:
loss
.
detach
().
item
(),
"lr"
:
lr_scheduler
.
get_last_lr
()[
0
]}
progress_bar
.
set_postfix
(
**
logs
)
...
...
examples/research_projects/mulit_token_textual_inversion/textual_inversion.py
View file @
ebd44957
...
...
@@ -864,7 +864,7 @@ def main():
if
global_step
>=
args
.
max_train_steps
:
break
if
args
.
validation_prompt
is
not
None
and
epoch
%
args
.
validation_epochs
==
0
:
if
accelerator
.
is_main_process
and
args
.
validation_prompt
is
not
None
and
epoch
%
args
.
validation_epochs
==
0
:
logger
.
info
(
f
"Running validation...
\n
Generating
{
args
.
num_validation_images
}
images with prompt:"
f
"
{
args
.
validation_prompt
}
."
...
...
examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
View file @
ebd44957
...
...
@@ -790,7 +790,7 @@ def main():
if
global_step
>=
args
.
max_train_steps
:
break
if
args
.
validation_prompt
is
not
None
and
epoch
%
args
.
validation_epochs
==
0
:
if
accelerator
.
is_main_process
and
args
.
validation_prompt
is
not
None
and
epoch
%
args
.
validation_epochs
==
0
:
logger
.
info
(
f
"Running validation...
\n
Generating
{
args
.
num_validation_images
}
images with prompt:"
f
"
{
args
.
validation_prompt
}
."
...
...
examples/text_to_image/train_text_to_image_lora.py
View file @
ebd44957
...
...
@@ -800,20 +800,19 @@ def main():
pipeline
(
args
.
validation_prompt
,
num_inference_steps
=
30
,
generator
=
generator
).
images
[
0
]
)
if
accelerator
.
is_main_process
:
for
tracker
in
accelerator
.
trackers
:
if
tracker
.
name
==
"tensorboard"
:
np_images
=
np
.
stack
([
np
.
asarray
(
img
)
for
img
in
images
])
tracker
.
writer
.
add_images
(
"validation"
,
np_images
,
epoch
,
dataformats
=
"NHWC"
)
if
tracker
.
name
==
"wandb"
:
tracker
.
log
(
{
"validation"
:
[
wandb
.
Image
(
image
,
caption
=
f
"
{
i
}
:
{
args
.
validation_prompt
}
"
)
for
i
,
image
in
enumerate
(
images
)
]
}
)
for
tracker
in
accelerator
.
trackers
:
if
tracker
.
name
==
"tensorboard"
:
np_images
=
np
.
stack
([
np
.
asarray
(
img
)
for
img
in
images
])
tracker
.
writer
.
add_images
(
"validation"
,
np_images
,
epoch
,
dataformats
=
"NHWC"
)
if
tracker
.
name
==
"wandb"
:
tracker
.
log
(
{
"validation"
:
[
wandb
.
Image
(
image
,
caption
=
f
"
{
i
}
:
{
args
.
validation_prompt
}
"
)
for
i
,
image
in
enumerate
(
images
)
]
}
)
del
pipeline
torch
.
cuda
.
empty_cache
()
...
...
examples/textual_inversion/textual_inversion.py
View file @
ebd44957
...
...
@@ -843,13 +843,14 @@ def main():
save_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"learned_embeds-steps-
{
global_step
}
.bin"
)
save_progress
(
text_encoder
,
placeholder_token_id
,
accelerator
,
args
,
save_path
)
if
global_step
%
args
.
checkpointing_steps
==
0
:
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
if
global_step
%
args
.
checkpointing_steps
==
0
:
save_path
=
os
.
path
.
join
(
args
.
output_dir
,
f
"checkpoint-
{
global_step
}
"
)
accelerator
.
save_state
(
save_path
)
logger
.
info
(
f
"Saved state to
{
save_path
}
"
)
if
args
.
validation_prompt
is
not
None
and
global_step
%
args
.
validation_steps
==
0
:
log_validation
(
text_encoder
,
tokenizer
,
unet
,
vae
,
args
,
accelerator
,
weight_dtype
,
epoch
)
if
args
.
validation_prompt
is
not
None
and
global_step
%
args
.
validation_steps
==
0
:
log_validation
(
text_encoder
,
tokenizer
,
unet
,
vae
,
args
,
accelerator
,
weight_dtype
,
epoch
)
logs
=
{
"loss"
:
loss
.
detach
().
item
(),
"lr"
:
lr_scheduler
.
get_last_lr
()[
0
]}
progress_bar
.
set_postfix
(
**
logs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment