Unverified Commit d67eba0f authored by Sayak Paul, committed by GitHub

[Utility] adds an image grid utility (#4576)

* add: utility for image grid.

* add: return type.

* change necessary places.

* add to utility page.
parent 714bfed8
@@ -20,4 +20,8 @@ Utility and helper functions for working with 🤗 Diffusers.
## export_to_video
[[autodoc]] utils.testing_utils.export_to_video
\ No newline at end of file
[[autodoc]] utils.testing_utils.export_to_video
## make_image_grid
[[autodoc]] utils.pil_utils.make_image_grid
\ No newline at end of file
@@ -152,26 +152,13 @@ def get_inputs(batch_size=1):
    return {"prompt": prompts, "generator": generator, "num_inference_steps": num_inference_steps}
```
You'll also need a function that'll display each batch of images:
```python
from PIL import Image
def image_grid(imgs, rows=2, cols=2):
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))

    return grid
```
Start with `batch_size=4` and see how much memory you've consumed:
```python
from diffusers.utils import make_image_grid
images = pipeline(**get_inputs(batch_size=4)).images
image_grid(images)
make_image_grid(images, 2, 2)
```
Unless you have a GPU with more RAM, the code above probably returned an `OOM` error! Most of the memory is taken up by the cross-attention layers. Instead of running this operation in a batch, you can run it sequentially to save a significant amount of memory. All you have to do is configure the pipeline to use the [`~DiffusionPipeline.enable_attention_slicing`] function:
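A minimal sketch of that call, assuming the `pipeline` object created earlier:

```python
# Compute attention in sequential slices instead of one large batch,
# trading a little speed for a significant reduction in peak memory.
pipeline.enable_attention_slicing()
```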
@@ -184,7 +171,7 @@ Now try increasing the `batch_size` to 8!
```python
images = pipeline(**get_inputs(batch_size=8)).images
image_grid(images, rows=2, cols=4)
make_image_grid(images, rows=2, cols=4)
```
<div class="flex justify-center">
@@ -213,7 +200,7 @@ from diffusers import AutoencoderKL
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
pipeline.vae = vae
images = pipeline(**get_inputs(batch_size=8)).images
image_grid(images, rows=2, cols=4)
make_image_grid(images, rows=2, cols=4)
```
<div class="flex justify-center">
@@ -238,7 +225,7 @@ Generate a batch of images with the new prompt:
```python
images = pipeline(**get_inputs(batch_size=8)).images
image_grid(images, rows=2, cols=4)
make_image_grid(images, rows=2, cols=4)
```
<div class="flex justify-center">
@@ -257,7 +244,7 @@ prompts = [
generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
image_grid(images)
make_image_grid(images, 2, 2)
```
<div class="flex justify-center">
......
@@ -252,18 +252,11 @@ Then, you'll need a way to evaluate the model. For evaluation, you can use the [
```py
>>> from diffusers import DDPMPipeline
>>> from diffusers.utils import make_image_grid
>>> import math
>>> import os
>>> def make_grid(images, rows, cols):
...     w, h = images[0].size
...     grid = Image.new("RGB", size=(cols * w, rows * h))
...     for i, image in enumerate(images):
...         grid.paste(image, box=(i % cols * w, i // cols * h))
...     return grid

>>> def evaluate(config, epoch, pipeline):
...     # Sample some images from random noise (this is the backward diffusion process).
...     # The default pipeline output type is `List[PIL.Image]`
@@ -273,7 +266,7 @@ Then, you'll need a way to evaluate the model. For evaluation, you can use the [
...     ).images
...     # Make a grid out of the images
...     image_grid = make_grid(images, rows=4, cols=4)
...     image_grid = make_image_grid(images, rows=4, cols=4)
...     # Save the images
...     test_dir = os.path.join(config.output_dir, "samples")
......
@@ -175,22 +175,12 @@ images = pipeline(
).images
```
Finally, create a helper function to display the images:
Display the images:
```py
from PIL import Image
from diffusers.utils import make_image_grid
def image_grid(imgs, rows=2, cols=2):
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))

    return grid
image_grid(images)
make_image_grid(images, 2, 2)
```
<div class="flex justify-center">
......
@@ -153,19 +153,10 @@ images = pipeline.numpy_to_pil(images)
### Visualization
Let's create a helper function to display images in a grid.
```python
def image_grid(imgs, rows, cols):
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))

    return grid
```
```python
from diffusers.utils import make_image_grid

image_grid(images, 2, 4)
make_image_grid(images, 2, 4)
```
![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_38_output_0.jpeg)
@@ -198,7 +189,7 @@ images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
image_grid(images, 2, 4)
make_image_grid(images, 2, 4)
```
![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_43_output_0.jpeg)
......
@@ -14,7 +14,7 @@ from huggingface_hub import notebook_login
notebook_login()
```
Import the necessary libraries, and create a helper function to visualize the generated images:
Import the necessary libraries:
```py
import os
@@ -24,19 +24,8 @@ import PIL
from PIL import Image
from diffusers import StableDiffusionPipeline
from diffusers.utils import make_image_grid
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))

    return grid
```
Pick a Stable Diffusion checkpoint and a pre-learned concept from the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer):
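A minimal sketch of that step (the checkpoint and the `sd-concepts-library/cat-toy` concept below are illustrative choices, not part of this commit):

```python
import torch

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Load the pre-learned concept; its placeholder token (here `<cat-toy>`)
# can then be used directly inside prompts.
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
prompt = "a <cat-toy> themed lunchbox"
```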
@@ -73,7 +62,7 @@ for _ in range(num_rows):
    images = pipe(prompt, num_images_per_prompt=num_samples, num_inference_steps=50, guidance_scale=7.5).images
    all_images.extend(images)
grid = image_grid(all_images, num_samples, num_rows)
grid = make_image_grid(all_images, num_samples, num_rows)
grid
```
......
@@ -47,7 +47,7 @@ from diffusers import (
FlaxStableDiffusionControlNetPipeline,
FlaxUNet2DConditionModel,
)
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
# To prevent an error that occurs when there are abnormally large compressed data chunks in the PNG image
@@ -64,18 +64,6 @@ check_min_version("0.20.0.dev0")
logger = logging.getLogger(__name__)
def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))

    return grid


def log_validation(pipeline, pipeline_params, controlnet_params, tokenizer, args, rng, weight_dtype):
    logger.info("Running validation...")
@@ -154,7 +142,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
        validation_image.save(os.path.join(repo_folder, "image_control.png"))
        img_str += f"prompt: {validation_prompt}\n"
        images = [validation_image] + images
        image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
        make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
        img_str += f"![images_{i}](./images_{i}.png)\n"
yaml = f"""
......
@@ -50,7 +50,7 @@ from diffusers import (
UniPCMultistepScheduler,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.import_utils import is_xformers_available
@@ -63,17 +63,6 @@ check_min_version("0.20.0.dev0")
logger = get_logger(__name__)
def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))

    return grid


def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step):
    logger.info("Running validation... ")
@@ -205,7 +194,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
        validation_image.save(os.path.join(repo_folder, "image_control.png"))
        img_str += f"prompt: {validation_prompt}\n"
        images = [validation_image] + images
        image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
        make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
        img_str += f"![images_{i}](./images_{i}.png)\n"
yaml = f"""
......
@@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.utils import make_image_grid
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
@@ -427,19 +428,6 @@ def freeze_params(params):
        param.requires_grad = False
def image_grid(imgs, rows, cols):
    if not len(imgs) == rows * cols:
        raise ValueError("The specified number of rows and columns are not correct.")

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))

    return grid
def generate_images(pipeline, prompt="", guidance_scale=7.5, num_inference_steps=50, num_images_per_prompt=1, seed=42):
    generator = torch.Generator(pipeline.device).manual_seed(seed)
    images = pipeline(
@@ -450,7 +438,7 @@ def generate_images(pipeline, prompt="", guidance_scale=7.5, num_inference_steps
        num_images_per_prompt=num_images_per_prompt,
    ).images

    _rows = int(math.sqrt(num_images_per_prompt))
    grid = image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows)
    grid = make_image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows)
    return grid
......
@@ -35,7 +35,6 @@ from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
@@ -45,7 +44,7 @@ import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
from diffusers.utils.import_utils import is_xformers_available
@@ -63,17 +62,6 @@ DATASET_NAME_MAPPING = {
}
def make_image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))

    return grid


def save_model_card(
    args,
    repo_id: str,
......
@@ -51,19 +51,11 @@ EXAMPLE_DOC_STRING = """
>>> import jax.numpy as jnp
>>> from flax.jax_utils import replicate
>>> from flax.training.common_utils import shard
>>> from diffusers.utils import load_image
>>> from diffusers.utils import load_image, make_image_grid
>>> from PIL import Image
>>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
>>> def image_grid(imgs, rows, cols):
...     w, h = imgs[0].size
...     grid = Image.new("RGB", size=(cols * w, rows * h))
...     for i, img in enumerate(imgs):
...         grid.paste(img, box=(i % cols * w, i // cols * h))
...     return grid

>>> def create_key(seed=0):
...     return jax.random.PRNGKey(seed)
@@ -110,7 +102,7 @@ EXAMPLE_DOC_STRING = """
... ).images
>>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
>>> output_images = image_grid(output_images, num_samples // 4, 4)
>>> output_images = make_image_grid(output_images, num_samples // 4, 4)
>>> output_images.save("generated_image.png")
```
"""
......
@@ -79,7 +79,7 @@ from .import_utils import (
)
from .logging import get_logger
from .outputs import BaseOutput
from .pil_utils import PIL_INTERPOLATION, numpy_to_pil, pt_to_pil
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
from .torch_utils import is_compiled_module, randn_tensor
......
from typing import List
import PIL.Image
import PIL.ImageOps
from packaging import version
@@ -46,3 +48,20 @@ def numpy_to_pil(images):
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images
def make_image_grid(images: List[PIL.Image.Image], rows: int, cols: int, resize: int = None) -> PIL.Image.Image:
    """
    Prepares a single grid of images. Useful for visualization purposes.
    """
    assert len(images) == rows * cols

    if resize is not None:
        images = [img.resize((resize, resize)) for img in images]

    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(images):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
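A quick usage sketch of the new helper (the solid-color placeholder images below are for illustration only):

```python
from PIL import Image

from diffusers.utils import make_image_grid

# Any list of equally sized PIL images works; rows * cols must match its length.
images = [Image.new("RGB", (64, 64), color) for color in ("red", "green", "blue", "white")]

# Arrange the images in a 2x2 grid, resizing each tile to 32x32 first.
grid = make_image_grid(images, rows=2, cols=2, resize=32)
grid.save("grid.png")
```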