Unverified Commit c28d3c82 authored by Ilmari Heikkinen, committed by GitHub

StableDiffusion: Decode latents separately to run larger batches (#1150)



* StableDiffusion: Decode latents separately to run larger batches

* Move VAE sliced decode under enable_vae_sliced_decode and vae.enable_sliced_decode

* Rename sliced_decode to slicing

* fix whitespace

* fix quality check and repository consistency

* VAE slicing tests and documentation

* API doc hooks for VAE slicing

* reformat vae slicing tests

* Skip VAE slicing for one-image batches

* Documentation tweaks for VAE slicing
Co-authored-by: Ilmari Heikkinen <ilmari@fhtr.org>
parent bcb6cc16
@@ -76,6 +76,8 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_vae_slicing
- disable_vae_slicing
## StableDiffusionImg2ImgPipeline
[[autodoc]] StableDiffusionImg2ImgPipeline
...
@@ -117,6 +117,34 @@ image = pipe(prompt).images[0]
There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
## Sliced VAE decode for larger batches
To decode large batches of images with limited VRAM, or to enable batches of 32 images or more, you can use sliced VAE decode, which decodes the batch latents one image at a time.
You likely want to couple this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.
To perform the VAE decode one image at a time, invoke [`~StableDiffusionPipeline.enable_vae_slicing`] in your pipeline before inference. For example:
```Python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_vae_slicing()
images = pipe([prompt] * 32).images
```
You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches.
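As a rough sketch of how sliced VAE decode combines with attention slicing (the model id and batch size below are only illustrative):
```Python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# slice the attention computation to reduce the UNet's peak memory
pipe.enable_attention_slicing()
# decode the VAE latents one image at a time
pipe.enable_vae_slicing()

prompt = "a photo of an astronaut riding a horse on mars"
images = pipe([prompt] * 32).images
```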
## Offloading to CPU with accelerate for memory savings
For additional memory savings, you can offload the weights to CPU and load them to GPU when performing the forward pass.
...
@@ -565,6 +565,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
        self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
        self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
        self.use_slicing = False

    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        h = self.encoder(x)
@@ -576,7 +577,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
        return AutoencoderKLOutput(latent_dist=posterior)
    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
@@ -585,6 +586,34 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
        return DecoderOutput(sample=dec)
    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)
    def forward(
        self,
        sample: torch.FloatTensor,
...
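A minimal sketch of how the new `AutoencoderKL` slicing toggle behaves when the VAE is used on its own; the checkpoint id and latent shape below are assumptions for illustration:
```Python
import torch
from diffusers import AutoencoderKL

# load just the VAE (an assumed example checkpoint)
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae.eval()

# a dummy batch of 4 latents at the 64x64 resolution used for 512x512 images
latents = torch.randn(4, 4, 64, 64)

with torch.no_grad():
    full = vae.decode(latents).sample  # one decoder pass over the whole batch
    vae.enable_slicing()
    sliced = vae.decode(latents).sample  # one decoder pass per latent, results concatenated
    vae.disable_slicing()

# both paths should agree up to small numerical differences
print(torch.max(torch.abs(full - sliced)))
```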
@@ -216,6 +216,22 @@ class AltDiffusionPipeline(DiffusionPipeline):
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
...
@@ -215,6 +215,22 @@ class StableDiffusionPipeline(DiffusionPipeline):
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
...
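A usage sketch for the pipeline-level toggles added above: they delegate to the VAE, so slicing can be switched on for a large batch and off again afterwards (the model id, prompt, and batch size are placeholders):
```Python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", revision="fp16", torch_dtype=torch.float16
).to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"

# large batch: decode the latents one image at a time to keep peak VRAM low
pipe.enable_vae_slicing()
batch_images = pipe([prompt] * 16).images

# slicing is skipped for one-image batches anyway, but it can also be switched
# off explicitly to restore the default single-pass decode
pipe.disable_vae_slicing()
single_image = pipe(prompt).images[0]
```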
@@ -557,6 +557,46 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4
    def test_stable_diffusion_vae_slicing(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        # compose the pipeline from the dummy components
        sd_pipe = StableDiffusionPipeline(
            unet=unet,
            scheduler=scheduler,
            vae=vae,
            text_encoder=bert,
            tokenizer=tokenizer,
            safety_checker=None,
            feature_extractor=self.dummy_extractor,
        )
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        prompt = "A painting of a squirrel eating a burger"
        image_count = 4

        generator = torch.Generator(device=device).manual_seed(0)
        output_1 = sd_pipe(
            [prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np"
        )

        # make sure sliced vae decode yields the same result
        sd_pipe.enable_vae_slicing()
        generator = torch.Generator(device=device).manual_seed(0)
        output_2 = sd_pipe(
            [prompt] * image_count, generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np"
        )

        # there is a small discrepancy at image borders vs. full batch decode
        assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 3e-3
    def test_stable_diffusion_negative_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        unet = self.dummy_cond_unet
@@ -886,6 +926,45 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
        assert mem_bytes > 3.75 * 10**9
        assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3
    def test_stable_diffusion_vae_slicing(self):
        torch.cuda.reset_peak_memory_stats()
        model_id = "CompVis/stable-diffusion-v1-4"
        pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        prompt = "a photograph of an astronaut riding a horse"

        # enable vae slicing
        pipe.enable_vae_slicing()
        generator = torch.Generator(device=torch_device).manual_seed(0)
        with torch.autocast(torch_device):
            output_chunked = pipe(
                [prompt] * 4, generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
            )
            image_chunked = output_chunked.images

        mem_bytes = torch.cuda.max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()
        # make sure that less than 4 GB is allocated
        assert mem_bytes < 4e9

        # disable vae slicing
        pipe.disable_vae_slicing()
        generator = torch.Generator(device=torch_device).manual_seed(0)
        with torch.autocast(torch_device):
            output = pipe(
                [prompt] * 4, generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
            )
            image = output.images

        # make sure that more than 4 GB is allocated
        mem_bytes = torch.cuda.max_memory_allocated()
        assert mem_bytes > 4e9

        # There is a small discrepancy at the image borders vs. a fully batched version.
        assert np.abs(image_chunked.flatten() - image.flatten()).max() < 3e-3
    def test_stable_diffusion_text2img_pipeline_fp16(self):
        torch.cuda.reset_peak_memory_stats()
        model_id = "CompVis/stable-diffusion-v1-4"
...