Unverified Commit b0ffe922 authored by Yu Zheng's avatar Yu Zheng Committed by GitHub
Browse files

Update sd3 controlnet example (#9735)

* use make_image_grid in diffusers.utils

* use checkpoint on the Hub
parent 1b64772b
......@@ -104,7 +104,7 @@ from diffusers.utils import load_image
import torch
base_model_path = "stabilityai/stable-diffusion-3-medium-diffusers"
controlnet_path = "sd3-controlnet-out/checkpoint-6500/controlnet"
controlnet_path = "DavyMorgan/sd3-controlnet-out"
controlnet = SD3ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
......
......@@ -50,7 +50,7 @@ from diffusers import (
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
......@@ -64,17 +64,6 @@ check_min_version("0.30.0.dev0")
logger = get_logger(__name__)
def image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols
w, h = imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
logger.info("Running validation... ")
......@@ -224,7 +213,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
validation_image.save(os.path.join(repo_folder, "image_control.png"))
img_str += f"prompt: {validation_prompt}\n"
images = [validation_image] + images
image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
img_str += f"![images_{i})](./images_{i}.png)\n"
model_description = f"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment