Unverified Commit 57f7d259 authored by Patrick von Platen, committed by GitHub

[CPU offload] correct cpu offload (#1968)



* [CPU offload] correct cpu offload

* [CPU offload] correct cpu offload

* finish

* finish

* Update docs/source/en/optimization/fp16.mdx

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 50b65135
docs/source/en/optimization/fp16.mdx
@@ -149,7 +149,7 @@ You may see a small performance boost in VAE decode on multi-image batches. Ther…
 ## Offloading to CPU with accelerate for memory savings

-For additional memory savings, you can offload the weights to CPU and load them to GPU when performing the forward pass.
+For additional memory savings, you can offload the weights to CPU and only load them to GPU when performing the forward pass.

 To perform CPU offloading, all you have to do is invoke [`~StableDiffusionPipeline.enable_sequential_cpu_offload`]:
@@ -162,16 +162,15 @@ pipe = StableDiffusionPipeline.from_pretrained(
     torch_dtype=torch.float16,
 )
-pipe = pipe.to("cuda")

 prompt = "a photo of an astronaut riding a horse on mars"
 pipe.enable_sequential_cpu_offload()
 image = pipe(prompt).images[0]
 ```
-And you can get the memory consumption to < 2GB.
+And you can get the memory consumption to < 3GB.

-It is also possible to chain it with attention slicing for minimal memory consumption, running it in as little as < 800mb of GPU vRAM:
+It is also possible to chain it with attention slicing for minimal memory consumption (< 2GB).

 ```Python
 import torch
@@ -182,7 +181,6 @@ pipe = StableDiffusionPipeline.from_pretrained(
     torch_dtype=torch.float16,
 )
-pipe = pipe.to("cuda")

 prompt = "a photo of an astronaut riding a horse on mars"
 pipe.enable_sequential_cpu_offload()
@@ -191,6 +189,8 @@ pipe.enable_attention_slicing(1)
 image = pipe(prompt).images[0]
 ```
+
+**Note**: When using `enable_sequential_cpu_offload()`, it is important to **not** move the pipeline to CUDA beforehand or else the gain in memory consumption will only be minimal. See [this issue](https://github.com/huggingface/diffusers/issues/1934) for more information.

 ## Using Channels Last memory format

 Channels last memory format is an alternative way of ordering NCHW tensors in memory, preserving the dimension ordering. Channels last tensors are ordered in such a way that channels become the densest dimension (i.e. storing images pixel-by-pixel). Since not all operators currently support the channels last format, it may result in worse performance, so it's better to try it and see if it works for your model.
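Taken together, the documented recipe now looks like the sketch below. The model id and the final memory check are illustrative additions (they are not part of the diff), and a CUDA device is assumed:

```Python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # model id assumed for illustration
    torch_dtype=torch.float16,
)

# Deliberately no `pipe = pipe.to("cuda")` here: moving the pipeline to CUDA
# up front keeps every weight resident on the GPU and negates most of the
# savings (see the note added above).
pipe.enable_sequential_cpu_offload()
pipe.enable_attention_slicing(1)

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]

# Rough check of the claimed footprint (requires a CUDA device).
print(f"peak GPU memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```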
src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -211,13 +211,10 @@ class AltDiffusionPipeline(DiffusionPipeline):
         device = torch.device(f"cuda:{gpu_id}")

         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            cpu_offload(cpu_offloaded_model, device)

         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, device, offload_buffers=True)

     @property
     def _execution_device(self):
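For orientation, the corrected `enable_sequential_cpu_offload` assembled from the hunk above reads roughly as follows. This is a sketch, shown as a standalone function for brevity; the real method also guards the `accelerate` import behind an availability check:

```Python
import torch
from accelerate import cpu_offload  # guarded by an availability check in the real method


def enable_sequential_cpu_offload(self, gpu_id=0):
    # Weights stay on CPU; `cpu_offload` installs hooks that move each
    # submodule onto `device` only for its forward pass, then back off.
    device = torch.device(f"cuda:{gpu_id}")

    # The `if cpu_offloaded_model is not None` guard was dropped: these
    # three components are assumed to always be present on the pipeline.
    for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
        cpu_offload(cpu_offloaded_model, device)

    if self.safety_checker is not None:
        # The whole safety checker is offloaded now, not just vision_model;
        # offload_buffers=True lets accelerate manage its registered buffers
        # alongside its parameters.
        cpu_offload(self.safety_checker, device, offload_buffers=True)
```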
src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -233,13 +233,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         device = torch.device(f"cuda:{gpu_id}")

         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            cpu_offload(cpu_offloaded_model, device)

         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

     @property
     def _execution_device(self):
src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
@@ -236,13 +236,10 @@ class CycleDiffusionPipeline(DiffusionPipeline):
         device = torch.device(f"cuda:{gpu_id}")

         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            cpu_offload(cpu_offloaded_model, device)

         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -208,13 +208,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
         device = torch.device(f"cuda:{gpu_id}")

         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            cpu_offload(cpu_offloaded_model, device)

         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

     @property
     def _execution_device(self):
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -238,13 +238,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         device = torch.device(f"cuda:{gpu_id}")

         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            cpu_offload(cpu_offloaded_model, device)

         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -272,13 +272,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         device = torch.device(f"cuda:{gpu_id}")

         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            cpu_offload(cpu_offloaded_model, device)

         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -205,13 +205,10 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
         device = torch.device(f"cuda:{gpu_id}")

         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            cpu_offload(cpu_offloaded_model, device)

         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -137,13 +137,10 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
         device = torch.device(f"cuda:{gpu_id}")

         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+            cpu_offload(cpu_offloaded_model, device)

         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
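The `offload_buffers=True` flag is the substantive part of the safety-checker change: per the removed TODO, offloading the full module used to hit an accelerate bug, so only `self.safety_checker.vision_model` was offloaded as a workaround. Below is a toy sketch of what the flag does, using a hypothetical `ToyChecker` module (assumes `accelerate` is installed and a CUDA device is available):

```Python
import torch
import torch.nn as nn
from accelerate import cpu_offload


class ToyChecker(nn.Module):
    """Hypothetical stand-in: holds parameters (Linear) plus a registered buffer."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.register_buffer("scale", torch.ones(4))

    def forward(self, x):
        return self.proj(x) * self.scale


model = ToyChecker()

# offload_buffers=True asks accelerate to manage registered buffers the same
# way it manages parameters, rather than pinning them to the execution device.
cpu_offload(model, execution_device=torch.device("cuda:0"), offload_buffers=True)

# The attached hook moves inputs and the needed tensors to cuda:0 per call.
out = model(torch.randn(2, 4))
print(out.shape)
```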