"src/vscode:/vscode.git/clone" did not exist on "111a4aa754a344cb07a19a9b836a9add2b18a117"
Unverified Commit dbcb15c2, authored by Patrick von Platen and committed by GitHub

[Stable UnCLIP] Finish Stable UnCLIP (#2814)

* up

* fix more 7

* up

* finish
parent c4892f18
@@ -22,7 +22,7 @@ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
 from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring
+from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
@@ -178,6 +178,31 @@ class StableUnCLIPPipeline(DiffusionPipeline):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)
 
+    def enable_model_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than
+        with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+        """
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate import cpu_offload_with_hook
+        else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        hook = None
+        for cpu_offloaded_model in [self.text_encoder, self.prior_text_encoder, self.unet, self.vae]:
+            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+        # We'll offload the last model manually.
+        self.final_offload_hook = hook
+
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
     def _execution_device(self):
@@ -581,6 +606,7 @@ class StableUnCLIPPipeline(DiffusionPipeline):
         noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
 
+        self.image_normalizer.to(image_embeds.device)
         image_embeds = self.image_normalizer.scale(image_embeds)
 
         image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
@@ -884,6 +910,10 @@ class StableUnCLIPPipeline(DiffusionPipeline):
         # 14. Post-processing
         image = self.decode_latents(latents)
 
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
         # 15. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)
...
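For context, here is a minimal usage sketch of the offloading path added above. It is an illustration, not part of the commit; the repo id is a placeholder, since any Stable unCLIP text-to-image checkpoint containing the components listed in the offload loop would do.

```python
import torch

from diffusers import StableUnCLIPPipeline

# "path/to/stable-unclip-checkpoint" is a hypothetical placeholder repo id,
# not a checkpoint referenced by this commit.
pipe = StableUnCLIPPipeline.from_pretrained(
    "path/to/stable-unclip-checkpoint", torch_dtype=torch.float16
)

# Instead of pipe.to("cuda"): each sub-model (prior text encoder, text encoder,
# unet, vae) is moved to the GPU only while its forward pass runs, and the last
# one is offloaded via final_offload_hook right after the latents are decoded.
pipe.enable_model_cpu_offload()

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
image.save("astronaut.png")
```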
@@ -24,7 +24,7 @@ from diffusers.utils.import_utils import is_accelerate_available
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import logging, randn_tensor, replace_example_docstring
+from ...utils import is_accelerate_version, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
@@ -180,6 +180,31 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
             if cpu_offloaded_model is not None:
                 cpu_offload(cpu_offloaded_model, device)
 
+    def enable_model_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than
+        with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+        """
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate import cpu_offload_with_hook
+        else:
+            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        hook = None
+        for cpu_offloaded_model in [self.text_encoder, self.image_encoder, self.unet, self.vae]:
+            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+        # We'll offload the last model manually.
+        self.final_offload_hook = hook
+
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
     def _execution_device(self):
@@ -548,6 +573,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
         noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)
 
+        self.image_normalizer.to(image_embeds.device)
         image_embeds = self.image_normalizer.scale(image_embeds)
 
         image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)
@@ -571,8 +597,8 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Union[str, List[str]] = None,
         image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        prompt: Union[str, List[str]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 20,
@@ -597,8 +623,8 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
         Args:
             prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
-                instead.
+                The prompt or prompts to guide the image generation. If not defined, either `prompt_embeds` will be
+                used or the prompt is initialized to `""`.
             image (`torch.FloatTensor` or `PIL.Image.Image`):
                 `Image`, or tensor representing an image batch. The image will be encoded to its CLIP embedding which
                 the unet will be conditioned on. Note that the image is _not_ encoded by the vae and then used as the
@@ -674,6 +700,9 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
 
+        if prompt is None and prompt_embeds is None:
+            prompt = len(image) * [""] if isinstance(image, list) else ""
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt=prompt,
@@ -777,6 +806,10 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
         # 9. Post-processing
         image = self.decode_latents(latents)
 
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
         # 10. Convert to PIL
         if output_type == "pil":
             image = self.numpy_to_pil(image)
...
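Taken together, the img2img changes mean `image` can now be passed as the first positional argument and `prompt` can be omitted entirely (it defaults to `""`, or a list of `""` for a batch of images). A hedged sketch, assuming the `stabilityai/stable-diffusion-2-1-unclip` checkpoint and a hypothetical local input file:

```python
import torch
from PIL import Image

from diffusers import StableUnCLIPImg2ImgPipeline

pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

init_image = Image.open("input.png").convert("RGB")  # hypothetical local file

# `image` first, no `prompt`: the pipeline now initializes the prompt to "".
images = pipe(init_image).images
images[0].save("variation.png")
```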
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional, Union
+
 import torch
 from torch import nn
@@ -37,6 +39,15 @@ class StableUnCLIPImageNormalizer(ModelMixin, ConfigMixin):
         self.mean = nn.Parameter(torch.zeros(1, embedding_dim))
         self.std = nn.Parameter(torch.ones(1, embedding_dim))
 
+    def to(
+        self,
+        torch_device: Optional[Union[str, torch.device]] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+    ):
+        self.mean = nn.Parameter(self.mean.to(torch_device).to(torch_dtype))
+        self.std = nn.Parameter(self.std.to(torch_device).to(torch_dtype))
+        return self
+
     def scale(self, embeds):
         embeds = (embeds - self.mean) * 1.0 / self.std
         return embeds
...
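The `to` override above is needed because `mean` and `std` are `nn.Parameter`s: a plain tensor `.to()` would return new tensors without re-registering them on the module, so the override re-wraps them on the target device and dtype. A small self-contained sketch (the import path follows the file's location in the repo; `embedding_dim=768` matches the class default):

```python
import torch

from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import (
    StableUnCLIPImageNormalizer,
)

normalizer = StableUnCLIPImageNormalizer(embedding_dim=768)

# Move the statistics to the device/dtype of the image embeddings before scaling;
# passing None for either argument leaves that property unchanged.
normalizer.to(torch_device="cpu", torch_dtype=torch.float32)

embeds = torch.randn(2, 768)
scaled = normalizer.scale(embeds)  # (embeds - mean) / std
print(scaled.shape)  # torch.Size([2, 768])
```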