Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
a9288b49
Unverified
Commit
a9288b49
authored
Jan 19, 2024
by
SangKim
Committed by
GitHub
Jan 19, 2024
Browse files
Modularize InstructPix2Pix SDXL inferencing during and after training in examples (#6569)
parent
c5441965
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
75 additions
and
74 deletions
+75
-74
examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+75
-74
No files found.
examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
View file @
a9288b49
...
@@ -55,6 +55,9 @@ from diffusers.utils.import_utils import is_xformers_available
...
@@ -55,6 +55,9 @@ from diffusers.utils.import_utils import is_xformers_available
from
diffusers.utils.torch_utils
import
is_compiled_module
from
diffusers.utils.torch_utils
import
is_compiled_module
if
is_wandb_available
():
import
wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version
(
"0.26.0.dev0"
)
check_min_version
(
"0.26.0.dev0"
)
...
@@ -67,6 +70,57 @@ WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
...
@@ -67,6 +70,57 @@ WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
# Maps the user-facing mixed-precision flag values to their torch dtypes.
TORCH_DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}
def log_validation(
    pipeline,
    args,
    accelerator,
    generator,
    global_step,
    is_final_validation=False,
):
    """Run validation inference with ``pipeline`` and log/save the results.

    Generates ``args.num_validation_images`` edited images from the source
    image at ``args.val_image_url_or_path`` using ``args.validation_prompt``,
    saves each image under ``<args.output_dir>/validation_images``, and logs
    an image table to any attached W&B tracker.

    Args:
        pipeline: InstructPix2Pix-style diffusion pipeline (project type);
            moved to ``accelerator.device`` here.
        args: Parsed training arguments; reads ``num_validation_images``,
            ``validation_prompt``, ``val_image_url_or_path``, ``output_dir``.
        accelerator: Accelerate object providing the device, the
            mixed-precision mode, and the experiment trackers.
        generator: ``torch.Generator`` seeding the sampling for reproducibility.
        global_step: Current training step; embedded in the saved file names.
        is_final_validation: When True, log under the "test" key instead of
            "validation" so post-training results are kept separate.
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )

    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    val_save_dir = os.path.join(args.output_dir, "validation_images")
    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    os.makedirs(val_save_dir, exist_ok=True)

    # Anything with a URL scheme goes through `load_image`; bare filesystem
    # paths are opened directly and normalized to RGB.
    if urlparse(args.val_image_url_or_path).scheme:
        original_image = load_image(args.val_image_url_or_path)
    else:
        original_image = Image.open(args.val_image_url_or_path).convert("RGB")

    # Autocast only under fp16 mixed precision; `:0` is stripped because
    # torch.autocast expects a device type, not a device index.
    with torch.autocast(
        str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
    ):
        edited_images = []
        # Run inference
        for val_img_idx in range(args.num_validation_images):
            a_val_img = pipeline(
                args.validation_prompt,
                image=original_image,
                num_inference_steps=20,
                image_guidance_scale=1.5,
                guidance_scale=7,
                generator=generator,
            ).images[0]
            edited_images.append(a_val_img)
            # Save validation images
            a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))

    for tracker in accelerator.trackers:
        if tracker.name == "wandb":
            wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
            for edited_image in edited_images:
                wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
            logger_name = "test" if is_final_validation else "validation"
            tracker.log({logger_name: wandb_table})
def
import_model_class_from_model_name_or_path
(
def
import_model_class_from_model_name_or_path
(
pretrained_model_name_or_path
:
str
,
revision
:
str
,
subfolder
:
str
=
"text_encoder"
pretrained_model_name_or_path
:
str
,
revision
:
str
,
subfolder
:
str
=
"text_encoder"
):
):
...
@@ -447,11 +501,6 @@ def main():
...
@@ -447,11 +501,6 @@ def main():
generator
=
torch
.
Generator
(
device
=
accelerator
.
device
).
manual_seed
(
args
.
seed
)
generator
=
torch
.
Generator
(
device
=
accelerator
.
device
).
manual_seed
(
args
.
seed
)
if
args
.
report_to
==
"wandb"
:
if
not
is_wandb_available
():
raise
ImportError
(
"Make sure to install wandb if you want to use it for logging during training."
)
import
wandb
# Make one log on every process with the configuration for debugging.
# Make one log on every process with the configuration for debugging.
logging
.
basicConfig
(
logging
.
basicConfig
(
format
=
"%(asctime)s - %(levelname)s - %(name)s - %(message)s"
,
format
=
"%(asctime)s - %(levelname)s - %(name)s - %(message)s"
,
...
@@ -1111,11 +1160,6 @@ def main():
...
@@ -1111,11 +1160,6 @@ def main():
### BEGIN: Perform validation every `validation_epochs` steps
### BEGIN: Perform validation every `validation_epochs` steps
if
global_step
%
args
.
validation_steps
==
0
:
if
global_step
%
args
.
validation_steps
==
0
:
if
(
args
.
val_image_url_or_path
is
not
None
)
and
(
args
.
validation_prompt
is
not
None
):
if
(
args
.
val_image_url_or_path
is
not
None
)
and
(
args
.
validation_prompt
is
not
None
):
logger
.
info
(
f
"Running validation...
\n
Generating
{
args
.
num_validation_images
}
images with prompt:"
f
"
{
args
.
validation_prompt
}
."
)
# create pipeline
# create pipeline
if
args
.
use_ema
:
if
args
.
use_ema
:
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
...
@@ -1135,44 +1179,16 @@ def main():
...
@@ -1135,44 +1179,16 @@ def main():
variant
=
args
.
variant
,
variant
=
args
.
variant
,
torch_dtype
=
weight_dtype
,
torch_dtype
=
weight_dtype
,
)
)
pipeline
=
pipeline
.
to
(
accelerator
.
device
)
pipeline
.
set_progress_bar_config
(
disable
=
True
)
log_validation
(
pipeline
,
# run inference
args
,
# Save validation images
accelerator
,
val_save_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"validation_images"
)
generator
,
if
not
os
.
path
.
exists
(
val_save_dir
):
global_step
,
os
.
makedirs
(
val_save_dir
)
is_final_validation
=
False
,
)
original_image
=
(
lambda
image_url_or_path
:
load_image
(
image_url_or_path
)
if
urlparse
(
image_url_or_path
).
scheme
else
Image
.
open
(
image_url_or_path
).
convert
(
"RGB"
)
)(
args
.
val_image_url_or_path
)
with
torch
.
autocast
(
str
(
accelerator
.
device
).
replace
(
":0"
,
""
),
enabled
=
accelerator
.
mixed_precision
==
"fp16"
):
edited_images
=
[]
for
val_img_idx
in
range
(
args
.
num_validation_images
):
a_val_img
=
pipeline
(
args
.
validation_prompt
,
image
=
original_image
,
num_inference_steps
=
20
,
image_guidance_scale
=
1.5
,
guidance_scale
=
7
,
generator
=
generator
,
).
images
[
0
]
edited_images
.
append
(
a_val_img
)
a_val_img
.
save
(
os
.
path
.
join
(
val_save_dir
,
f
"step_
{
global_step
}
_val_img_
{
val_img_idx
}
.png"
))
for
tracker
in
accelerator
.
trackers
:
if
tracker
.
name
==
"wandb"
:
wandb_table
=
wandb
.
Table
(
columns
=
WANDB_TABLE_COL_NAMES
)
for
edited_image
in
edited_images
:
wandb_table
.
add_data
(
wandb
.
Image
(
original_image
),
wandb
.
Image
(
edited_image
),
args
.
validation_prompt
)
tracker
.
log
({
"validation"
:
wandb_table
})
if
args
.
use_ema
:
if
args
.
use_ema
:
# Switch back to the original UNet parameters.
# Switch back to the original UNet parameters.
ema_unet
.
restore
(
unet
.
parameters
())
ema_unet
.
restore
(
unet
.
parameters
())
...
@@ -1187,7 +1203,6 @@ def main():
...
@@ -1187,7 +1203,6 @@ def main():
# Create the pipeline using the trained modules and save it.
# Create the pipeline using the trained modules and save it.
accelerator
.
wait_for_everyone
()
accelerator
.
wait_for_everyone
()
if
accelerator
.
is_main_process
:
if
accelerator
.
is_main_process
:
unet
=
unwrap_model
(
unet
)
if
args
.
use_ema
:
if
args
.
use_ema
:
ema_unet
.
copy_to
(
unet
.
parameters
())
ema_unet
.
copy_to
(
unet
.
parameters
())
...
@@ -1198,10 +1213,11 @@ def main():
...
@@ -1198,10 +1213,11 @@ def main():
tokenizer
=
tokenizer_1
,
tokenizer
=
tokenizer_1
,
tokenizer_2
=
tokenizer_2
,
tokenizer_2
=
tokenizer_2
,
vae
=
vae
,
vae
=
vae
,
unet
=
unet
,
unet
=
unwrap_model
(
unet
)
,
revision
=
args
.
revision
,
revision
=
args
.
revision
,
variant
=
args
.
variant
,
variant
=
args
.
variant
,
)
)
pipeline
.
save_pretrained
(
args
.
output_dir
)
pipeline
.
save_pretrained
(
args
.
output_dir
)
if
args
.
push_to_hub
:
if
args
.
push_to_hub
:
...
@@ -1212,30 +1228,15 @@ def main():
...
@@ -1212,30 +1228,15 @@ def main():
ignore_patterns
=
[
"step_*"
,
"epoch_*"
],
ignore_patterns
=
[
"step_*"
,
"epoch_*"
],
)
)
if
args
.
validation_prompt
is
not
None
:
if
(
args
.
val_image_url_or_path
is
not
None
)
and
(
args
.
validation_prompt
is
not
None
):
edited_images
=
[]
log_validation
(
pipeline
=
pipeline
.
to
(
accelerator
.
device
)
pipeline
,
with
torch
.
autocast
(
str
(
accelerator
.
device
).
replace
(
":0"
,
""
)):
args
,
for
_
in
range
(
args
.
num_validation_images
):
accelerator
,
edited_images
.
append
(
generator
,
pipeline
(
global_step
,
args
.
validation_prompt
,
is_final_validation
=
True
,
image
=
original_image
,
)
num_inference_steps
=
20
,
image_guidance_scale
=
1.5
,
guidance_scale
=
7
,
generator
=
generator
,
).
images
[
0
]
)
for
tracker
in
accelerator
.
trackers
:
if
tracker
.
name
==
"wandb"
:
wandb_table
=
wandb
.
Table
(
columns
=
WANDB_TABLE_COL_NAMES
)
for
edited_image
in
edited_images
:
wandb_table
.
add_data
(
wandb
.
Image
(
original_image
),
wandb
.
Image
(
edited_image
),
args
.
validation_prompt
)
tracker
.
log
({
"test"
:
wandb_table
})
accelerator
.
end_training
()
accelerator
.
end_training
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment