Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
b0ffe922
Unverified
Commit
b0ffe922
authored
Oct 21, 2024
by
Yu Zheng
Committed by
GitHub
Oct 22, 2024
Browse files
Update sd3 controlnet example (#9735)
* use make_image_grid in diffusers.utils * use checkpoint on the Hub
parent
1b64772b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
14 deletions
+3
-14
examples/controlnet/README_sd3.md
examples/controlnet/README_sd3.md
+1
-1
examples/controlnet/train_controlnet_sd3.py
examples/controlnet/train_controlnet_sd3.py
+2
-13
No files found.
examples/controlnet/README_sd3.md
View file @
b0ffe922
...
@@ -104,7 +104,7 @@ from diffusers.utils import load_image
...
@@ -104,7 +104,7 @@ from diffusers.utils import load_image
import
torch
import
torch
base_model_path
=
"stabilityai/stable-diffusion-3-medium-diffusers"
base_model_path
=
"stabilityai/stable-diffusion-3-medium-diffusers"
controlnet_path
=
"sd3-controlnet-out
/checkpoint-6500/controlnet
"
controlnet_path
=
"
DavyMorgan/
sd3-controlnet-out"
controlnet
=
SD3ControlNetModel
.
from_pretrained
(
controlnet_path
,
torch_dtype
=
torch
.
float16
)
controlnet
=
SD3ControlNetModel
.
from_pretrained
(
controlnet_path
,
torch_dtype
=
torch
.
float16
)
pipe
=
StableDiffusion3ControlNetPipeline
.
from_pretrained
(
pipe
=
StableDiffusion3ControlNetPipeline
.
from_pretrained
(
...
...
examples/controlnet/train_controlnet_sd3.py
View file @
b0ffe922
...
@@ -50,7 +50,7 @@ from diffusers import (
...
@@ -50,7 +50,7 @@ from diffusers import (
)
)
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.training_utils
import
compute_density_for_timestep_sampling
,
compute_loss_weighting_for_sd3
,
free_memory
from
diffusers.training_utils
import
compute_density_for_timestep_sampling
,
compute_loss_weighting_for_sd3
,
free_memory
from
diffusers.utils
import
check_min_version
,
is_wandb_available
from
diffusers.utils
import
check_min_version
,
is_wandb_available
,
make_image_grid
from
diffusers.utils.hub_utils
import
load_or_create_model_card
,
populate_model_card
from
diffusers.utils.hub_utils
import
load_or_create_model_card
,
populate_model_card
from
diffusers.utils.torch_utils
import
is_compiled_module
from
diffusers.utils.torch_utils
import
is_compiled_module
...
@@ -64,17 +64,6 @@ check_min_version("0.30.0.dev0")
...
@@ -64,17 +64,6 @@ check_min_version("0.30.0.dev0")
logger
=
get_logger
(
__name__
)
logger
=
get_logger
(
__name__
)
def
image_grid
(
imgs
,
rows
,
cols
):
assert
len
(
imgs
)
==
rows
*
cols
w
,
h
=
imgs
[
0
].
size
grid
=
Image
.
new
(
"RGB"
,
size
=
(
cols
*
w
,
rows
*
h
))
for
i
,
img
in
enumerate
(
imgs
):
grid
.
paste
(
img
,
box
=
(
i
%
cols
*
w
,
i
//
cols
*
h
))
return
grid
def
log_validation
(
controlnet
,
args
,
accelerator
,
weight_dtype
,
step
,
is_final_validation
=
False
):
def
log_validation
(
controlnet
,
args
,
accelerator
,
weight_dtype
,
step
,
is_final_validation
=
False
):
logger
.
info
(
"Running validation... "
)
logger
.
info
(
"Running validation... "
)
...
@@ -224,7 +213,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
...
@@ -224,7 +213,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
validation_image
.
save
(
os
.
path
.
join
(
repo_folder
,
"image_control.png"
))
validation_image
.
save
(
os
.
path
.
join
(
repo_folder
,
"image_control.png"
))
img_str
+=
f
"prompt:
{
validation_prompt
}
\n
"
img_str
+=
f
"prompt:
{
validation_prompt
}
\n
"
images
=
[
validation_image
]
+
images
images
=
[
validation_image
]
+
images
image_grid
(
images
,
1
,
len
(
images
)).
save
(
os
.
path
.
join
(
repo_folder
,
f
"images_
{
i
}
.png"
))
make_
image_grid
(
images
,
1
,
len
(
images
)).
save
(
os
.
path
.
join
(
repo_folder
,
f
"images_
{
i
}
.png"
))
img_str
+=
f
"
\n
"
img_str
+=
f
"
\n
"
model_description
=
f
"""
model_description
=
f
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment