Unverified Commit 3d707773 authored by Junsong Chen's avatar Junsong Chen Committed by GitHub
Browse files

[Sana-4K] (#10537)



* [Sana 4K]
add 4K support for Sana

* [Sana-4K] fix SanaPAGPipeline

* add automatic VAE tiling function;

* set clean_caption to False;

* add warnings for VAE OOM.

* style

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent 6b727842
......@@ -16,6 +16,7 @@ import html
import inspect
import re
import urllib.parse as ul
import warnings
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
......@@ -41,6 +42,7 @@ from ..pixart_alpha.pipeline_pixart_alpha import (
ASPECT_RATIO_1024_BIN,
)
from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
from ..sana.pipeline_sana import ASPECT_RATIO_4096_BIN
from .pag_utils import PAGMixin
......@@ -639,7 +641,7 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
clean_caption: bool = True,
clean_caption: bool = False,
use_resolution_binning: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
......@@ -755,7 +757,9 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if use_resolution_binning:
if self.transformer.config.sample_size == 64:
if self.transformer.config.sample_size == 128:
aspect_ratio_bin = ASPECT_RATIO_4096_BIN
elif self.transformer.config.sample_size == 64:
aspect_ratio_bin = ASPECT_RATIO_2048_BIN
elif self.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
......@@ -912,7 +916,14 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
image = latents
else:
latents = latents.to(self.vae.dtype)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
)
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
......
......@@ -16,6 +16,7 @@ import html
import inspect
import re
import urllib.parse as ul
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
......@@ -953,7 +954,14 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
image = latents
else:
latents = latents.to(self.vae.dtype)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
)
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment