Unverified Commit b0ffe922 authored by Yu Zheng's avatar Yu Zheng Committed by GitHub
Browse files

Update sd3 controlnet example (#9735)

* use make_image_grid in diffusers.utils

* use checkpoint on the Hub
parent 1b64772b
...@@ -104,7 +104,7 @@ from diffusers.utils import load_image ...@@ -104,7 +104,7 @@ from diffusers.utils import load_image
import torch import torch
base_model_path = "stabilityai/stable-diffusion-3-medium-diffusers" base_model_path = "stabilityai/stable-diffusion-3-medium-diffusers"
controlnet_path = "sd3-controlnet-out/checkpoint-6500/controlnet" controlnet_path = "DavyMorgan/sd3-controlnet-out"
controlnet = SD3ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) controlnet = SD3ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained( pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
......
...@@ -50,7 +50,7 @@ from diffusers import ( ...@@ -50,7 +50,7 @@ from diffusers import (
) )
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module from diffusers.utils.torch_utils import is_compiled_module
...@@ -64,17 +64,6 @@ check_min_version("0.30.0.dev0") ...@@ -64,17 +64,6 @@ check_min_version("0.30.0.dev0")
logger = get_logger(__name__) logger = get_logger(__name__)
def image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols
w, h = imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_validation=False): def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
logger.info("Running validation... ") logger.info("Running validation... ")
...@@ -224,7 +213,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N ...@@ -224,7 +213,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
validation_image.save(os.path.join(repo_folder, "image_control.png")) validation_image.save(os.path.join(repo_folder, "image_control.png"))
img_str += f"prompt: {validation_prompt}\n" img_str += f"prompt: {validation_prompt}\n"
images = [validation_image] + images images = [validation_image] + images
image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
img_str += f"![images_{i})](./images_{i}.png)\n" img_str += f"![images_{i})](./images_{i}.png)\n"
model_description = f""" model_description = f"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment