Unverified Commit 3d707773 authored by Junsong Chen's avatar Junsong Chen Committed by GitHub
Browse files

[Sana-4K] (#10537)



* [Sana 4K]
add 4K support for Sana

* [Sana-4K] fix SanaPAGPipeline

* add VAE automatically tiling function;

* set clean_caption to False;

* add warnings for VAE OOM.

* style

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
parent 6b727842
...@@ -16,6 +16,7 @@ import html ...@@ -16,6 +16,7 @@ import html
import inspect import inspect
import re import re
import urllib.parse as ul import urllib.parse as ul
import warnings
from typing import Callable, Dict, List, Optional, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Union
import torch import torch
...@@ -41,6 +42,7 @@ from ..pixart_alpha.pipeline_pixart_alpha import ( ...@@ -41,6 +42,7 @@ from ..pixart_alpha.pipeline_pixart_alpha import (
ASPECT_RATIO_1024_BIN, ASPECT_RATIO_1024_BIN,
) )
from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
from ..sana.pipeline_sana import ASPECT_RATIO_4096_BIN
from .pag_utils import PAGMixin from .pag_utils import PAGMixin
...@@ -639,7 +641,7 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin): ...@@ -639,7 +641,7 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
negative_prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
clean_caption: bool = True, clean_caption: bool = False,
use_resolution_binning: bool = True, use_resolution_binning: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"], callback_on_step_end_tensor_inputs: List[str] = ["latents"],
...@@ -755,7 +757,9 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin): ...@@ -755,7 +757,9 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if use_resolution_binning: if use_resolution_binning:
if self.transformer.config.sample_size == 64: if self.transformer.config.sample_size == 128:
aspect_ratio_bin = ASPECT_RATIO_4096_BIN
elif self.transformer.config.sample_size == 64:
aspect_ratio_bin = ASPECT_RATIO_2048_BIN aspect_ratio_bin = ASPECT_RATIO_2048_BIN
elif self.transformer.config.sample_size == 32: elif self.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_1024_BIN aspect_ratio_bin = ASPECT_RATIO_1024_BIN
...@@ -912,7 +916,14 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin): ...@@ -912,7 +916,14 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
image = latents image = latents
else: else:
latents = latents.to(self.vae.dtype) latents = latents.to(self.vae.dtype)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
)
if use_resolution_binning: if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
......
...@@ -16,6 +16,7 @@ import html ...@@ -16,6 +16,7 @@ import html
import inspect import inspect
import re import re
import urllib.parse as ul import urllib.parse as ul
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch import torch
...@@ -953,7 +954,14 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin): ...@@ -953,7 +954,14 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
image = latents image = latents
else: else:
latents = latents.to(self.vae.dtype) latents = latents.to(self.vae.dtype)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
)
if use_resolution_binning: if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment