renzhc / diffusers_dcu · Commits · a9288b49

Unverified commit a9288b49, authored Jan 19, 2024 by SangKim, committed by GitHub on Jan 19, 2024
Parent: c5441965

Modularize InstructPix2Pix SDXL inferencing during and after training in examples (#6569)

Showing 1 changed file with 75 additions and 74 deletions:
examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py (+75, -74)

examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @ a9288b49
@@ -55,6 +55,9 @@ from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
+if is_wandb_available():
+    import wandb
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.26.0.dev0")
@@ -67,6 +70,57 @@ WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
 TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
 
 
+def log_validation(
+    pipeline,
+    args,
+    accelerator,
+    generator,
+    global_step,
+    is_final_validation=False,
+):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    val_save_dir = os.path.join(args.output_dir, "validation_images")
+    if not os.path.exists(val_save_dir):
+        os.makedirs(val_save_dir)
+
+    original_image = (
+        lambda image_url_or_path: load_image(image_url_or_path)
+        if urlparse(image_url_or_path).scheme
+        else Image.open(image_url_or_path).convert("RGB")
+    )(args.val_image_url_or_path)
+
+    with torch.autocast(
+        str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
+    ):
+        edited_images = []
+        # Run inference
+        for val_img_idx in range(args.num_validation_images):
+            a_val_img = pipeline(
+                args.validation_prompt,
+                image=original_image,
+                num_inference_steps=20,
+                image_guidance_scale=1.5,
+                guidance_scale=7,
+                generator=generator,
+            ).images[0]
+            edited_images.append(a_val_img)
+            # Save validation images
+            a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "wandb":
+            wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
+            for edited_image in edited_images:
+                wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
+            logger_name = "test" if is_final_validation else "validation"
+            tracker.log({logger_name: wandb_table})
+
+
 def import_model_class_from_model_name_or_path(
     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
 ):
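Note that the new log_validation helper above resolves args.val_image_url_or_path by checking whether the string has a URL scheme before deciding how to open it. As a standalone illustration of that dispatch (a minimal sketch; the open_image name and the example inputs are hypothetical, while load_image comes from diffusers.utils exactly as in the script):

import torch
from urllib.parse import urlparse

from PIL import Image
from diffusers.utils import load_image


def open_image(image_url_or_path):
    # Same dispatch as the immediately invoked lambda in log_validation:
    # anything with a URL scheme (http://, https://, ...) goes through
    # diffusers' load_image, a plain local path is opened with PIL and
    # converted to RGB.
    if urlparse(image_url_or_path).scheme:
        return load_image(image_url_or_path)
    return Image.open(image_url_or_path).convert("RGB")


# Hypothetical inputs; both calls return a PIL.Image in RGB mode.
remote = open_image("https://example.com/input.png")
local = open_image("input.png")

The script keeps this as an inline lambda; naming it, as here, is purely for readability.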
@@ -447,11 +501,6 @@ def main():
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
 
-    if args.report_to == "wandb":
-        if not is_wandb_available():
-            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb
-
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -1111,11 +1160,6 @@ def main():
                 ### BEGIN: Perform validation every `validation_epochs` steps
                 if global_step % args.validation_steps == 0:
                     if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
-                        logger.info(
-                            f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                            f" {args.validation_prompt}."
-                        )
                         # create pipeline
                         if args.use_ema:
                             # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
@@ -1135,44 +1179,16 @@ def main():
                             variant=args.variant,
                             torch_dtype=weight_dtype,
                         )
-                        pipeline = pipeline.to(accelerator.device)
-                        pipeline.set_progress_bar_config(disable=True)
-
-                        # run inference
-                        # Save validation images
-                        val_save_dir = os.path.join(args.output_dir, "validation_images")
-                        if not os.path.exists(val_save_dir):
-                            os.makedirs(val_save_dir)
-
-                        original_image = (
-                            lambda image_url_or_path: load_image(image_url_or_path)
-                            if urlparse(image_url_or_path).scheme
-                            else Image.open(image_url_or_path).convert("RGB")
-                        )(args.val_image_url_or_path)
-
-                        with torch.autocast(
-                            str(accelerator.device).replace(":0", ""),
-                            enabled=accelerator.mixed_precision == "fp16",
-                        ):
-                            edited_images = []
-                            for val_img_idx in range(args.num_validation_images):
-                                a_val_img = pipeline(
-                                    args.validation_prompt,
-                                    image=original_image,
-                                    num_inference_steps=20,
-                                    image_guidance_scale=1.5,
-                                    guidance_scale=7,
-                                    generator=generator,
-                                ).images[0]
-                                edited_images.append(a_val_img)
-                                a_val_img.save(
-                                    os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png")
-                                )
-
-                        for tracker in accelerator.trackers:
-                            if tracker.name == "wandb":
-                                wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
-                                for edited_image in edited_images:
-                                    wandb_table.add_data(
-                                        wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
-                                    )
-                                tracker.log({"validation": wandb_table})
+                        log_validation(
+                            pipeline,
+                            args,
+                            accelerator,
+                            generator,
+                            global_step,
+                            is_final_validation=False,
+                        )
 
                         if args.use_ema:
                             # Switch back to the original UNet parameters.
                             ema_unet.restore(unet.parameters())
@@ -1187,7 +1203,6 @@ def main():
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = unwrap_model(unet)
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())
@@ -1198,10 +1213,11 @@ def main():
             tokenizer=tokenizer_1,
             tokenizer_2=tokenizer_2,
             vae=vae,
-            unet=unet,
+            unet=unwrap_model(unet),
             revision=args.revision,
             variant=args.variant,
         )
         pipeline.save_pretrained(args.output_dir)
 
         if args.push_to_hub:
@@ -1212,30 +1228,15 @@ def main():
                 ignore_patterns=["step_*", "epoch_*"],
             )
 
-        if args.validation_prompt is not None:
-            edited_images = []
-            pipeline = pipeline.to(accelerator.device)
-            with torch.autocast(str(accelerator.device).replace(":0", "")):
-                for _ in range(args.num_validation_images):
-                    edited_images.append(
-                        pipeline(
-                            args.validation_prompt,
-                            image=original_image,
-                            num_inference_steps=20,
-                            image_guidance_scale=1.5,
-                            guidance_scale=7,
-                            generator=generator,
-                        ).images[0]
-                    )
-            for tracker in accelerator.trackers:
-                if tracker.name == "wandb":
-                    wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
-                    for edited_image in edited_images:
-                        wandb_table.add_data(
-                            wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
-                        )
-                    tracker.log({"test": wandb_table})
+        if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
+            log_validation(
+                pipeline,
+                args,
+                accelerator,
+                generator,
+                global_step,
+                is_final_validation=True,
+            )
 
     accelerator.end_training()
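After this final hunk, the trained pipeline has been written out with pipeline.save_pretrained(args.output_dir) and the post-training validation is handled by log_validation(..., is_final_validation=True). For running inference on the saved checkpoint outside the training script, a minimal sketch, assuming the saved pipeline is a StableDiffusionXLInstructPix2PixPipeline (the pipeline class this example trains, not visible in the excerpted hunks) and that the path, image URL, and prompt below are placeholders:

import torch
from diffusers import StableDiffusionXLInstructPix2PixPipeline
from diffusers.utils import load_image

# Load the pipeline saved by the training script (placeholder path).
pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
    "path/to/output_dir", torch_dtype=torch.float16
).to("cuda")

# Placeholder input image and edit prompt.
image = load_image("https://example.com/input.png").resize((1024, 1024))
edited = pipeline(
    "make the sky stormy",
    image=image,
    num_inference_steps=20,       # same generation settings as log_validation
    image_guidance_scale=1.5,
    guidance_scale=7,
).images[0]
edited.save("edited.png")

The generation arguments mirror those log_validation passes during training, so the output should be comparable to the logged validation images.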