Unverified Commit b0c89738 authored by Leo Jiang's avatar Leo Jiang Committed by GitHub
Browse files

[Sana 4K] Add vae tiling option to avoid OOM (#10583)


Co-authored-by: default avatarJ石页 <jiangshuo9@h-partners.com>
parent c944f065
......@@ -158,6 +158,9 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
if args.enable_vae_tiling:
pipeline.vae.enable_tiling(tile_sample_min_height=1024, tile_sample_stride_width=1024)
pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
......@@ -597,6 +600,7 @@ def parse_args(input_args=None):
help="Whether to offload the VAE and the text encoder to CPU when they are not used.",
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_vae_tiling", action="store_true", help="Enabla vae tiling in log validation")
if input_args is not None:
args = parser.parse_args(input_args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment