Unverified Commit ed759f0a authored by Sayak Paul, committed by GitHub

[PixArt-Alpha] Introduce resolution binning (#5739)



* feat: add resolution binning
Co-authored-by: lawrence-cj <jschen@mail.dlut.edu.cn>

* rename

* debug

* add: test

* remove unused variable

* set resolution_binning to False.

---------
Co-authored-by: lawrence-cj <jschen@mail.dlut.edu.cn>
parent 5b231aa3
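In short, resolution binning lets the pipeline accept resolutions it was not trained on: the requested height/width are snapped to the nearest trained aspect-ratio bin, the image is generated at the binned size, and the decoded output is resized and center-cropped back to the request. A minimal usage sketch (hedged: the checkpoint id is the one commonly used with PixArt-Alpha; any compatible 1024px checkpoint works):

```py
import torch
from diffusers import PixArtAlphaPipeline

# Checkpoint id assumed from the PixArt-Alpha model card; swap in any
# compatible 1024px checkpoint.
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

# With use_resolution_binning=True (the new default), an 800x1280 request is
# generated at the closest trained bin (768x1280) and then resized and
# center-cropped back to exactly 800x1280.
image = pipe(
    "a small cactus with a happy face in the Sahara desert",
    height=800,
    width=1280,
    use_resolution_binning=True,
).images[0]
```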
@@ -19,6 +19,7 @@ import urllib.parse as ul
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import T5EncoderModel, T5Tokenizer

from ...image_processor import VaeImageProcessor
@@ -43,7 +44,6 @@ if is_bs4_available():
if is_ftfy_available():
    import ftfy

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
@@ -60,6 +60,42 @@ EXAMPLE_DOC_STRING = """
        ```
"""
ASPECT_RATIO_1024_BIN = {
    "0.25": [512.0, 2048.0],
    "0.28": [512.0, 1856.0],
    "0.32": [576.0, 1792.0],
    "0.33": [576.0, 1728.0],
    "0.35": [576.0, 1664.0],
    "0.4": [640.0, 1600.0],
    "0.42": [640.0, 1536.0],
    "0.48": [704.0, 1472.0],
    "0.5": [704.0, 1408.0],
    "0.52": [704.0, 1344.0],
    "0.57": [768.0, 1344.0],
    "0.6": [768.0, 1280.0],
    "0.68": [832.0, 1216.0],
    "0.72": [832.0, 1152.0],
    "0.78": [896.0, 1152.0],
    "0.82": [896.0, 1088.0],
    "0.88": [960.0, 1088.0],
    "0.94": [960.0, 1024.0],
    "1.0": [1024.0, 1024.0],
    "1.07": [1024.0, 960.0],
    "1.13": [1088.0, 960.0],
    "1.21": [1088.0, 896.0],
    "1.29": [1152.0, 896.0],
    "1.38": [1152.0, 832.0],
    "1.46": [1216.0, 832.0],
    "1.67": [1280.0, 768.0],
    "1.75": [1344.0, 768.0],
    "2.0": [1408.0, 704.0],
    "2.09": [1472.0, 704.0],
    "2.4": [1536.0, 640.0],
    "2.5": [1600.0, 640.0],
    "3.0": [1728.0, 576.0],
    "4.0": [2048.0, 512.0],
}
class PixArtAlphaPipeline(DiffusionPipeline):
    r"""
@@ -495,6 +531,38 @@ class PixArtAlphaPipeline(DiffusionPipeline):
        latents = latents * self.scheduler.init_noise_sigma
        return latents
    @staticmethod
    def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
        """Returns binned height and width."""
        ar = float(height / width)
        closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
        default_hw = ratios[closest_ratio]
        return int(default_hw[0]), int(default_hw[1])

    @staticmethod
    def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
        orig_height, orig_width = samples.shape[2], samples.shape[3]

        # Check if resizing is needed
        if orig_height != new_height or orig_width != new_width:
            ratio = max(new_height / orig_height, new_width / orig_width)
            resized_width = int(orig_width * ratio)
            resized_height = int(orig_height * ratio)

            # Resize
            samples = F.interpolate(
                samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
            )

            # Center Crop
            start_x = (resized_width - new_width) // 2
            end_x = start_x + new_width
            start_y = (resized_height - new_height) // 2
            end_y = start_y + new_height
            samples = samples[:, :, start_y:end_y, start_x:end_x]

        return samples
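Both helpers are static and side-effect free, so the snap-then-restore round trip can be checked in isolation. A small sketch against the bin table above (the 800x1280 numbers are worked out from `ASPECT_RATIO_1024_BIN`, not taken from the PR; the import path assumes this file's module layout):

```py
import torch

# Assumed import path for the names defined in this file.
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import (
    ASPECT_RATIO_1024_BIN,
    PixArtAlphaPipeline,
)

# Nearest-ratio lookup: 800/1280 = 0.625, and the closest key in the table
# is "0.6", so generation happens at 768x1280.
h, w = PixArtAlphaPipeline.classify_height_width_bin(800, 1280, ratios=ASPECT_RATIO_1024_BIN)
assert (h, w) == (768, 1280)

# Resize-back: a stand-in for the decoded 768x1280 image is upscaled by
# max(800/768, 1280/1280) = 25/24 and center-cropped to the requested size.
decoded = torch.randn(1, 3, 768, 1280)
restored = PixArtAlphaPipeline.resize_and_crop_tensor(decoded, new_width=1280, new_height=800)
assert restored.shape == (1, 3, 800, 1280)
```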
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -518,6 +586,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
        callback_steps: int = 1,
        clean_caption: bool = True,
        mask_feature: bool = True,
        use_resolution_binning: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        """
        Function invoked when calling the pipeline for generation.
@@ -580,6 +649,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
                be installed. If the dependencies are not installed, the embeddings will be created from the raw
                prompt.
            mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.
            use_resolution_binning (`bool` defaults to `True`):
                If set to `True`, the requested height and width are first mapped to the closest resolutions using
                `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back
                to the requested resolution. Useful for generating non-square images.

        Examples:
@@ -591,6 +664,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
        # 1. Check inputs. Raise error if not correct
        height = height or self.transformer.config.sample_size * self.vae_scale_factor
        width = width or self.transformer.config.sample_size * self.vae_scale_factor
        if use_resolution_binning:
            orig_height, orig_width = height, width
            height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN)

        self.check_inputs(
            prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds
        )
@@ -709,6 +786,8 @@ class PixArtAlphaPipeline(DiffusionPipeline):
        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            if use_resolution_binning:
                image = self.resize_and_crop_tensor(image, orig_width, orig_height)
        else:
            image = latents
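One consequence of the ordering in `__call__`: when binning is enabled, `check_inputs` and the whole denoising loop see the binned height/width, and the originally requested size only reappears at the final resize after VAE decoding. A compressed trace for the 800x1280 example (illustrative values derived from the bin table, not from a real run):

```py
# Trace of __call__ for height=800, width=1280 with binning enabled.
orig_height, orig_width = 800, 1280          # request stored up front
height, width = 768, 1280                    # snapped to the "0.6" bin
# ... check_inputs, latent prep, and denoising all run at 768x1280 ...
# After vae.decode, resize_and_crop_tensor(image, orig_width, orig_height)
# upscales 768x1280 -> 800x1333 and center-crops back to 800x1280.
```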
@@ -89,7 +89,8 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "use_resolution_binning": False,
            "output_type": "np",
        }
        return inputs
@@ -120,6 +121,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            "generator": generator,
            "num_inference_steps": num_inference_steps,
            "output_type": output_type,
            "use_resolution_binning": False,
        }

        # set all optional components to None
@@ -154,6 +156,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            "generator": generator,
            "num_inference_steps": num_inference_steps,
            "output_type": output_type,
            "use_resolution_binning": False,
        }

        output_loaded = pipe_loaded(**inputs)[0]
@@ -189,8 +192,8 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs, height=32, width=48).images
        image_slice = image[0, -3:, -3:, -1]

        self.assertEqual(image.shape, (1, 32, 48, 3))
        expected_slice = np.array([0.3859, 0.2987, 0.2333, 0.5243, 0.6721, 0.4436, 0.5292, 0.5373, 0.4416])
        max_diff = np.abs(image_slice.flatten() - expected_slice).max()
        self.assertLessEqual(max_diff, 1e-3)
@@ -219,6 +222,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            "num_inference_steps": num_inference_steps,
            "output_type": output_type,
            "num_images_per_prompt": 2,
            "use_resolution_binning": False,
        }

        # set all optional components to None
@@ -254,6 +258,7 @@ class PixArtAlphaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            "num_inference_steps": num_inference_steps,
            "output_type": output_type,
            "num_images_per_prompt": 2,
            "use_resolution_binning": False,
        }

        output_loaded = pipe_loaded(**inputs)[0]
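The fast tests disable the new flag because the dummy pipeline runs at tiny resolutions, and the bin table would inflate them drastically. Worked out from `ASPECT_RATIO_1024_BIN` (an illustration, not a claim from the PR itself):

```py
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import (  # assumed path
    ASPECT_RATIO_1024_BIN,
    PixArtAlphaPipeline,
)

# The non-square test requests 32x48 (ratio ~0.667). The closest bin key is
# "0.68", so with binning enabled the pipeline would generate at 832x1216,
# far beyond the tiny dummy transformer, and the (1, 32, 48, 3) shape
# assertion would fail.
h, w = PixArtAlphaPipeline.classify_height_width_bin(32, 48, ratios=ASPECT_RATIO_1024_BIN)
assert (h, w) == (832, 1216)
```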