Unverified Commit 7186bb45 authored by hlky's avatar hlky Committed by GitHub
Browse files

Add enable_vae_tiling to AllegroPipeline, fix example (#10212)

parent 438bd605
...@@ -59,6 +59,7 @@ EXAMPLE_DOC_STRING = """ ...@@ -59,6 +59,7 @@ EXAMPLE_DOC_STRING = """
>>> vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32) >>> vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32)
>>> pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", vae=vae, torch_dtype=torch.bfloat16).to("cuda") >>> pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", vae=vae, torch_dtype=torch.bfloat16).to("cuda")
>>> pipe.enable_vae_tiling()
>>> prompt = ( >>> prompt = (
... "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, " ... "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, "
...@@ -636,6 +637,35 @@ class AllegroPipeline(DiffusionPipeline): ...@@ -636,6 +637,35 @@ class AllegroPipeline(DiffusionPipeline):
return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w) return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w)
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
@property @property
def guidance_scale(self): def guidance_scale(self):
return self._guidance_scale return self._guidance_scale
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment