Unverified Commit 7d8b4f7f authored by Patrick von Platen, committed by GitHub

[Paint by example] Fix cpu offload for paint by example (#2062)

* [Paint by example] Fix paint by example

* fix more

* final fix
parent a66f2bae
@@ -36,12 +36,15 @@ class PaintByExampleImageEncoder(CLIPPreTrainedModel):
         # uncondition for scaling
         self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))
 
-    def forward(self, pixel_values):
+    def forward(self, pixel_values, return_uncond_vector=False):
         clip_output = self.model(pixel_values=pixel_values)
         latent_states = clip_output.pooler_output
         latent_states = self.mapper(latent_states[:, None])
         latent_states = self.final_layer_norm(latent_states)
         latent_states = self.proj_out(latent_states)
+        if return_uncond_vector:
+            return latent_states, self.uncond_vector
+
         return latent_states
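The new return_uncond_vector flag is what makes CPU offload work: once the image encoder is wrapped by accelerate's offload hooks, its parameters are only placed on the execution device while forward runs, so reading self.image_encoder.uncond_vector as an attribute from the pipeline can hand back the offloaded copy on the wrong device. Returning the vector from inside forward gives the caller a tensor that the hook has already moved. Below is a minimal sketch of that pattern, not part of the diff; TinyEncoder is a hypothetical stand-in for PaintByExampleImageEncoder, and it assumes accelerate's cpu_offload and a CUDA device are available.

import torch
from torch import nn
from accelerate import cpu_offload

class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.uncond_vector = nn.Parameter(torch.randn(1, 1, 4))

    def forward(self, x, return_uncond_vector=False):
        out = self.proj(x)
        if return_uncond_vector:
            # inside forward, accelerate's hook has placed the parameters on the
            # execution device, so the caller receives a usable on-device tensor
            return out, self.uncond_vector
        return out

if torch.cuda.is_available():
    encoder = TinyEncoder()
    cpu_offload(encoder, execution_device=torch.device("cuda:0"))
    out, uncond = encoder(torch.randn(2, 4, device="cuda:0"), return_uncond_vector=True)
    # reading encoder.uncond_vector outside forward would instead see the
    # offloaded copy, which is what caused the original device mismatch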
@@ -201,14 +201,11 @@ class PaintByExamplePipeline(DiffusionPipeline):
         device = torch.device(f"cuda:{gpu_id}")
 
-        for cpu_offloaded_model in [self.unet, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+        for cpu_offloaded_model in [self.unet, self.vae, self.image_encoder]:
+            cpu_offload(cpu_offloaded_model, execution_device=device)
 
         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
 
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
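With the image encoder added to the offload list and the whole safety checker offloaded (with offload_buffers=True), the offload path can be exercised end to end. A hedged usage sketch; the checkpoint id and the enable_sequential_cpu_offload entry point follow the diffusers docs of this period and are not shown in the diff itself.

import torch
from diffusers import PaintByExamplePipeline

pipe = PaintByExamplePipeline.from_pretrained(
    "Fantasy-Studio/Paint-by-Example", torch_dtype=torch.float16
)
# moves each sub-module (unet, vae, image_encoder, safety_checker) to the GPU
# only for the duration of its forward pass, as wired up in the method above
pipe.enable_sequential_cpu_offload()

# init_image, mask_image and example_image are PIL.Image inputs prepared by the caller
# result = pipe(image=init_image, mask_image=mask_image, example_image=example_image).images[0]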
@@ -367,7 +364,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
             image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
 
         image = image.to(device=device, dtype=dtype)
-        image_embeddings = self.image_encoder(image)
+        image_embeddings, uncond_embeddings = self.image_encoder(image, return_uncond_vector=True)
 
         # duplicate image embeddings for each generation per prompt, using mps friendly method
         bs_embed, seq_len, _ = image_embeddings.shape
@@ -375,7 +372,6 @@ class PaintByExamplePipeline(DiffusionPipeline):
         image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
         if do_classifier_free_guidance:
-            uncond_embeddings = self.image_encoder.uncond_vector
             uncond_embeddings = uncond_embeddings.repeat(1, image_embeddings.shape[0], 1)
             uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, 1, -1)
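For reference, the repeat/view calls above implement the usual classifier-free guidance batching. A self-contained sketch with toy tensors; the shapes are illustrative and not taken from the model.

import torch

bs_embed, seq_len, dim = 1, 1, 768      # illustrative shapes
num_images_per_prompt = 2

image_embeddings = torch.randn(bs_embed, seq_len, dim)    # conditional branch
uncond_vector = torch.randn(1, 1, dim)                    # learned unconditional embedding

# duplicate the conditional embeddings for each requested image
image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

# expand the unconditional vector to the same batch size, mirroring the diff
uncond_embeddings = uncond_vector.repeat(1, image_embeddings.shape[0], 1)
uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, 1, -1)

# the pipeline then stacks [uncond, cond] so both branches share one UNet pass
cfg_embeddings = torch.cat([uncond_embeddings, image_embeddings])
print(cfg_embeddings.shape)  # torch.Size([4, 1, 768])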