Unverified Commit dbcb15c2 authored by Patrick von Platen, committed by GitHub

[Stable UnCLIP] Finish Stable UnCLIP (#2814)

* up

* fix more 7

* up

* finish
parent c4892f18
@@ -22,7 +22,7 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
from ...models.embeddings import get_timestep_embedding
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring
+from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
@@ -178,6 +178,31 @@ class StableUnCLIPPipeline(DiffusionPipeline):
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
device = torch.device(f"cuda:{gpu_id}")
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
hook = None
for cpu_offloaded_model in [self.text_encoder, self.prior_text_encoder, self.unet, self.vae]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
# We'll offload the last model manually.
self.final_offload_hook = hook
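As a quick sanity check of the new offloading path, a hypothetical usage sketch (the checkpoint path is a placeholder, not taken from this commit):

```python
# Hypothetical usage of the new enable_model_cpu_offload; the checkpoint
# path below is a placeholder, not a real model id from this commit.
import torch
from diffusers import StableUnCLIPPipeline

pipe = StableUnCLIPPipeline.from_pretrained(
    "path/to/stable-unclip-checkpoint", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # each sub-model hops to cuda:0 only for its forward pass

image = pipe("a photo of an astronaut riding a horse").images[0]
```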
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
@@ -581,6 +606,7 @@ class StableUnCLIPPipeline(DiffusionPipeline):
noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
self.image_normalizer.to(image_embeds.device)
image_embeds = self.image_normalizer.scale(image_embeds)
image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
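The one added line here appears to move the image normalizer onto the embeddings' device before scaling, enabled by the new `to` override further down in this commit. For reference, a standalone sketch of the same scale-then-noise sequence (shapes, scheduler config, and noise level are illustrative assumptions; the real normalizer uses learned statistics, batch statistics below are only a stand-in):

```python
# Standalone sketch of the embedding noise augmentation; values are illustrative.
import torch
from diffusers import DDPMScheduler

image_noising_scheduler = DDPMScheduler(num_train_timesteps=1000)

image_embeds = torch.randn(2, 768)           # stand-in CLIP image embeddings
mean = image_embeds.mean(dim=0, keepdim=True)
std = image_embeds.std(dim=0, keepdim=True)
image_embeds = (image_embeds - mean) / std   # affine normalization, as in scale()

noise = torch.randn_like(image_embeds)
noise_level = torch.tensor([100, 100])       # one noising timestep per batch element
image_embeds = image_noising_scheduler.add_noise(
    image_embeds, noise=noise, timesteps=noise_level
)
```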
@@ -884,6 +910,10 @@ class StableUnCLIPPipeline(DiffusionPipeline):
# 14. Post-processing
image = self.decode_latents(latents)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
# 15. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
......
@@ -24,7 +24,7 @@ from diffusers.utils.import_utils import is_accelerate_available
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.embeddings import get_timestep_embedding
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import logging, randn_tensor, replace_example_docstring
+from ...utils import is_accelerate_version, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
@@ -180,6 +180,31 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
device = torch.device(f"cuda:{gpu_id}")
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
hook = None
for cpu_offloaded_model in [self.text_encoder, self.image_encoder, self.unet, self.vae]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
# We'll offload the last model manually.
self.final_offload_hook = hook
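The `prev_module_hook` chaining is what keeps only one sub-model on the GPU at a time: each hook offloads its predecessor when its own module's forward runs. A reduced sketch of the pattern (requires `accelerate >= 0.17.0` and a CUDA device; the two linear layers stand in for the pipeline's sub-models):

```python
# Reduced sketch of the hook-chaining pattern above; the linear layers
# stand in for text_encoder / image_encoder / unet / vae.
import torch
from accelerate import cpu_offload_with_hook

device = torch.device("cuda:0")
model_a = torch.nn.Linear(8, 8)
model_b = torch.nn.Linear(8, 8)

hook = None
for module in (model_a, model_b):
    _, hook = cpu_offload_with_hook(module, device, prev_module_hook=hook)

x = torch.randn(1, 8)
y = model_b(model_a(x))  # model_a returns to CPU when model_b starts
hook.offload()           # the last model is offloaded manually, as in the pipeline
```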
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
@@ -548,6 +573,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
self.image_normalizer.to(image_embeds.device)
image_embeds = self.image_normalizer.scale(image_embeds)
image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
@@ -571,8 +597,8 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
-prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+prompt: Union[str, List[str]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 20,
@@ -597,8 +623,8 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`, *optional*):
-The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
-instead.
+The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be
+used or prompt is initialized to `""`.
image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which
the unet will be conditioned on. Note that the image is _not_ encoded by the vae and then used as the
@@ -674,6 +700,9 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if prompt is None and prompt_embeds is None:
prompt = len(image) * [""] if isinstance(image, list) else ""
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
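With this fallback, an image-variation call no longer needs an explicit empty prompt; a hypothetical invocation (checkpoint id and image path are placeholders, not taken from this commit):

```python
# Hypothetical call relying on the new "" prompt fallback; the checkpoint
# id and image path are placeholders, not taken from this commit.
import torch
from diffusers import StableUnCLIPImg2ImgPipeline
from diffusers.utils import load_image

pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "path/to/stable-unclip-img2img-checkpoint", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

init_image = load_image("path/to/input.png")
images = pipe(image=init_image).images  # prompt defaults to "" per the check above
```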
@@ -777,6 +806,10 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
# 9. Post-processing
image = self.decode_latents(latents)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
# 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
......
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
import torch
from torch import nn
@@ -37,6 +39,15 @@ class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin):
self.mean = nn.Parameter(torch.zeros(1, embedding_dim))
self.std = nn.Parameter(torch.ones(1, embedding_dim))
def to(
self,
torch_device: Optional[Union[str, torch.device]] = None,
torch_dtype: Optional[torch.dtype] = None,
):
self.mean = nn.Parameter(self.mean.to(torch_device).to(torch_dtype))
self.std = nn.Parameter(self.std.to(torch_device).to(torch_dtype))
return self
def scale(self, embeds):
embeds = (embeds - self.mean) * 1.0 / self.std
return embeds
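The new `to` override re-wraps `mean` and `std` as `nn.Parameter` after moving/casting them, which is what lets the pipelines call `self.image_normalizer.to(image_embeds.device)` before scaling. Since `scale` is a plain affine map, its inverse is immediate; a self-contained round trip (`unscale` here is the assumed mirror of `scale`, not code from this diff):

```python
# Round trip through the normalizer's affine map; the unscale step is the
# assumed inverse of scale(), not copied from this commit.
import torch

embedding_dim = 768
mean = torch.zeros(1, embedding_dim)  # learned statistics in the real module
std = torch.ones(1, embedding_dim)

embeds = torch.randn(2, embedding_dim)
scaled = (embeds - mean) / std        # scale()
restored = scaled * std + mean        # unscale()
assert torch.allclose(restored, embeds)
```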
......