Unverified Commit d67eba0f authored by Sayak Paul, committed by GitHub

[Utility] adds an image grid utility (#4576)

* add: utility for image grid.

* add: return type.

* change necessary places.

* add to utility page.
parent 714bfed8
@@ -20,4 +20,8 @@ Utility and helper functions for working with 🤗 Diffusers.
## export_to_video
[[autodoc]] utils.testing_utils.export_to_video
\ No newline at end of file
+## make_image_grid
+[[autodoc]] utils.pil_utils.make_image_grid
\ No newline at end of file
@@ -152,26 +152,13 @@ def get_inputs(batch_size=1):
    return {"prompt": prompts, "generator": generator, "num_inference_steps": num_inference_steps}
```
-You'll also need a function that'll display each batch of images:
-```python
-from PIL import Image
-
-def image_grid(imgs, rows=2, cols=2):
-    w, h = imgs[0].size
-    grid = Image.new("RGB", size=(cols * w, rows * h))
-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
-    return grid
-```
Start with `batch_size=4` and see how much memory you've consumed:
```python
+from diffusers.utils import make_image_grid
images = pipeline(**get_inputs(batch_size=4)).images
-image_grid(images)
+make_image_grid(images, 2, 2)
```
Unless you have a GPU with more RAM, the code above probably returned an `OOM` error! Most of the memory is taken up by the cross-attention layers. Instead of running this operation in a batch, you can run it sequentially to save a significant amount of memory. All you have to do is configure the pipeline to use the [`~DiffusionPipeline.enable_attention_slicing`] function:
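For context, the attention-slicing step referred to above is unchanged by this commit and therefore not shown in the diff. A minimal sketch of that step, assuming the `pipeline` and `get_inputs` objects defined earlier in the tutorial (this code is not part of the commit):

```python
from diffusers.utils import make_image_grid

# Compute the cross-attention in slices over the input instead of all at once,
# trading a little speed for a much lower peak memory footprint.
pipeline.enable_attention_slicing()

# The larger batch that previously ran out of memory should now fit.
images = pipeline(**get_inputs(batch_size=8)).images
make_image_grid(images, rows=2, cols=4)
```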
@@ -184,7 +171,7 @@ Now try increasing the `batch_size` to 8!
```python
images = pipeline(**get_inputs(batch_size=8)).images
-image_grid(images, rows=2, cols=4)
+make_image_grid(images, rows=2, cols=4)
```
<div class="flex justify-center"> <div class="flex justify-center">
...@@ -213,7 +200,7 @@ from diffusers import AutoencoderKL ...@@ -213,7 +200,7 @@ from diffusers import AutoencoderKL
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda") vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
pipeline.vae = vae pipeline.vae = vae
images = pipeline(**get_inputs(batch_size=8)).images images = pipeline(**get_inputs(batch_size=8)).images
image_grid(images, rows=2, cols=4) make_image_grid(images, rows=2, cols=4)
``` ```
<div class="flex justify-center"> <div class="flex justify-center">
...@@ -238,7 +225,7 @@ Generate a batch of images with the new prompt: ...@@ -238,7 +225,7 @@ Generate a batch of images with the new prompt:
```python ```python
images = pipeline(**get_inputs(batch_size=8)).images images = pipeline(**get_inputs(batch_size=8)).images
image_grid(images, rows=2, cols=4) make_image_grid(images, rows=2, cols=4)
``` ```
<div class="flex justify-center"> <div class="flex justify-center">
...@@ -257,7 +244,7 @@ prompts = [ ...@@ -257,7 +244,7 @@ prompts = [
generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))] generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
image_grid(images) make_image_grid(images, 2, 2)
``` ```
<div class="flex justify-center"> <div class="flex justify-center">
......
@@ -252,18 +252,11 @@ Then, you'll need a way to evaluate the model. For evaluation, you can use the [
```py
>>> from diffusers import DDPMPipeline
+>>> from diffusers.utils import make_image_grid
>>> import math
>>> import os
->>> def make_grid(images, rows, cols):
-...     w, h = images[0].size
-...     grid = Image.new("RGB", size=(cols * w, rows * h))
-...     for i, image in enumerate(images):
-...         grid.paste(image, box=(i % cols * w, i // cols * h))
-...     return grid
>>> def evaluate(config, epoch, pipeline):
...     # Sample some images from random noise (this is the backward diffusion process).
...     # The default pipeline output type is `List[PIL.Image]`
@@ -273,7 +266,7 @@ Then, you'll need a way to evaluate the model. For evaluation, you can use the [
...     ).images
...     # Make a grid out of the images
-...     image_grid = make_grid(images, rows=4, cols=4)
+...     image_grid = make_image_grid(images, rows=4, cols=4)
...     # Save the images
...     test_dir = os.path.join(config.output_dir, "samples")
...
@@ -175,22 +175,12 @@ images = pipeline(
).images
```
-Finally, create a helper function to display the images:
+Display the images:
```py
-from PIL import Image
-def image_grid(imgs, rows=2, cols=2):
-    w, h = imgs[0].size
-    grid = Image.new("RGB", size=(cols * w, rows * h))
-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
-    return grid
-image_grid(images)
+from diffusers.utils import make_image_grid
+make_image_grid(images, 2, 2)
```
<div class="flex justify-center">
...
@@ -153,19 +153,10 @@ images = pipeline.numpy_to_pil(images)
### Visualization
-Let's create a helper function to display images in a grid.
```python
-def image_grid(imgs, rows, cols):
-    w, h = imgs[0].size
-    grid = Image.new("RGB", size=(cols * w, rows * h))
-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
-    return grid
-```
-```python
-image_grid(images, 2, 4)
+from diffusers import make_image_grid
+make_image_grid(images, 2, 4)
```
![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_38_output_0.jpeg)
@@ -198,7 +189,7 @@ images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
-image_grid(images, 2, 4)
+make_image_grid(images, 2, 4)
```
![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_43_output_0.jpeg)
...
@@ -14,7 +14,7 @@ from huggingface_hub import notebook_login
notebook_login()
```
-Import the necessary libraries, and create a helper function to visualize the generated images:
+Import the necessary libraries:
```py
import os
@@ -24,19 +24,8 @@ import PIL
from PIL import Image
from diffusers import StableDiffusionPipeline
+from diffusers.utils import make_image_grid
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
-def image_grid(imgs, rows, cols):
-    assert len(imgs) == rows * cols
-    w, h = imgs[0].size
-    grid = Image.new("RGB", size=(cols * w, rows * h))
-    grid_w, grid_h = grid.size
-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
-    return grid
```
Pick a Stable Diffusion checkpoint and a pre-learned concept from the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer):
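The loading step itself is untouched by this commit, so it is elided from the diff. Loading a pre-learned concept usually amounts to something like the sketch below; the checkpoint, concept name, and prompt are illustrative only and not part of the change:

```python
import torch
from diffusers import StableDiffusionPipeline

# Illustrative names only: pick any checkpoint and any concept from the Conceptualizer.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_textual_inversion("sd-concepts-library/cat-toy")

# The learned concept is referenced by its placeholder token in the prompt.
prompt = "a <cat-toy> sitting on a desk"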
@@ -73,7 +62,7 @@ for _ in range(num_rows):
    images = pipe(prompt, num_images_per_prompt=num_samples, num_inference_steps=50, guidance_scale=7.5).images
    all_images.extend(images)
-grid = image_grid(all_images, num_samples, num_rows)
+grid = make_image_grid(all_images, num_samples, num_rows)
grid
```
...
@@ -47,7 +47,7 @@ from diffusers import (
    FlaxStableDiffusionControlNetPipeline,
    FlaxUNet2DConditionModel,
)
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
# To prevent an error that occurs when there are abnormally large compressed data chunk in the png image
@@ -64,18 +64,6 @@ check_min_version("0.20.0.dev0")
logger = logging.getLogger(__name__)
-def image_grid(imgs, rows, cols):
-    assert len(imgs) == rows * cols
-    w, h = imgs[0].size
-    grid = Image.new("RGB", size=(cols * w, rows * h))
-    grid_w, grid_h = grid.size
-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
-    return grid
def log_validation(pipeline, pipeline_params, controlnet_params, tokenizer, args, rng, weight_dtype):
    logger.info("Running validation...")
@@ -154,7 +142,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
            validation_image.save(os.path.join(repo_folder, "image_control.png"))
            img_str += f"prompt: {validation_prompt}\n"
            images = [validation_image] + images
-            image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+            make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
            img_str += f"![images_{i})](./images_{i}.png)\n"
    yaml = f"""
...
@@ -50,7 +50,7 @@ from diffusers import (
    UniPCMultistepScheduler,
)
from diffusers.optimization import get_scheduler
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.import_utils import is_xformers_available
@@ -63,17 +63,6 @@ check_min_version("0.20.0.dev0")
logger = get_logger(__name__)
-def image_grid(imgs, rows, cols):
-    assert len(imgs) == rows * cols
-    w, h = imgs[0].size
-    grid = Image.new("RGB", size=(cols * w, rows * h))
-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
-    return grid
def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step):
    logger.info("Running validation... ")
@@ -205,7 +194,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
            validation_image.save(os.path.join(repo_folder, "image_control.png"))
            img_str += f"prompt: {validation_prompt}\n"
            images = [validation_image] + images
-            image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+            make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
            img_str += f"![images_{i})](./images_{i}.png)\n"
    yaml = f"""
...
@@ -24,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
+from diffusers.utils import make_image_grid
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
@@ -427,19 +428,6 @@ def freeze_params(params):
        param.requires_grad = False
-def image_grid(imgs, rows, cols):
-    if not len(imgs) == rows * cols:
-        raise ValueError("The specified number of rows and columns are not correct.")
-    w, h = imgs[0].size
-    grid = Image.new("RGB", size=(cols * w, rows * h))
-    grid_w, grid_h = grid.size
-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
-    return grid
def generate_images(pipeline, prompt="", guidance_scale=7.5, num_inference_steps=50, num_images_per_prompt=1, seed=42):
    generator = torch.Generator(pipeline.device).manual_seed(seed)
    images = pipeline(
@@ -450,7 +438,7 @@ def generate_images(pipeline, prompt="", guidance_scale=7.5, num_inference_steps
        num_images_per_prompt=num_images_per_prompt,
    ).images
    _rows = int(math.sqrt(num_images_per_prompt))
-    grid = image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows)
+    grid = make_image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows)
    return grid
...
@@ -35,7 +35,6 @@ from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
-from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
@@ -45,7 +44,7 @@ import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version, deprecate, is_wandb_available
+from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
from diffusers.utils.import_utils import is_xformers_available
@@ -63,17 +62,6 @@ DATASET_NAME_MAPPING = {
}
-def make_image_grid(imgs, rows, cols):
-    assert len(imgs) == rows * cols
-    w, h = imgs[0].size
-    grid = Image.new("RGB", size=(cols * w, rows * h))
-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
-    return grid
def save_model_card(
    args,
    repo_id: str,
...
@@ -51,19 +51,11 @@ EXAMPLE_DOC_STRING = """
>>> import jax.numpy as jnp
>>> from flax.jax_utils import replicate
>>> from flax.training.common_utils import shard
->>> from diffusers.utils import load_image
+>>> from diffusers.utils import load_image, make_image_grid
>>> from PIL import Image
>>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
->>> def image_grid(imgs, rows, cols):
-...     w, h = imgs[0].size
-...     grid = Image.new("RGB", size=(cols * w, rows * h))
-...     for i, img in enumerate(imgs):
-...         grid.paste(img, box=(i % cols * w, i // cols * h))
-...     return grid
>>> def create_key(seed=0):
...     return jax.random.PRNGKey(seed)
@@ -110,7 +102,7 @@ EXAMPLE_DOC_STRING = """
... ).images
>>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
->>> output_images = image_grid(output_images, num_samples // 4, 4)
+>>> output_images = make_image_grid(output_images, num_samples // 4, 4)
>>> output_images.save("generated_image.png")
```
"""
...
@@ -79,7 +79,7 @@ from .import_utils import (
)
from .logging import get_logger
from .outputs import BaseOutput
-from .pil_utils import PIL_INTERPOLATION, numpy_to_pil, pt_to_pil
+from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
from .torch_utils import is_compiled_module, randn_tensor
...
+from typing import List
+
import PIL.Image
import PIL.ImageOps
from packaging import version
@@ -46,3 +48,20 @@ def numpy_to_pil(images):
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images
+
+
+def make_image_grid(images: List[PIL.Image.Image], rows: int, cols: int, resize: int = None) -> PIL.Image.Image:
+    """
+    Prepares a single grid of images. Useful for visualization purposes.
+    """
+    assert len(images) == rows * cols
+
+    if resize is not None:
+        images = [img.resize((resize, resize)) for img in images]
+
+    w, h = images[0].size
+    grid = Image.new("RGB", size=(cols * w, rows * h))
+
+    for i, img in enumerate(images):
+        grid.paste(img, box=(i % cols * w, i // cols * h))
+    return grid
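For context, a minimal usage sketch of the new utility; the checkpoint, prompt, and grid dimensions below are illustrative only and not part of the commit:

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import make_image_grid

# Illustrative example: any pipeline that returns PIL images will do.
pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

images = pipeline("a photo of an astronaut riding a horse", num_images_per_prompt=4).images

# Arrange the 4 images into a 2x2 grid; the optional `resize` argument makes
# every tile square with the given side length before pasting.
grid = make_image_grid(images, rows=2, cols=2, resize=256)
grid.save("grid.png")
```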