Unverified commit 0028c344, authored by YiYi Xu and committed by GitHub
Browse files

fix SEGA pipeline (#8467)



* fix

* style

---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent d457beed
......@@ -376,6 +376,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
if editing_prompt:
enable_edit_guidance = True
......@@ -405,7 +406,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
......@@ -433,9 +434,9 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length]
edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0]
edit_concepts = self.text_encoder(edit_concepts_input_ids.to(device))[0]
else:
edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1)
edit_concepts = editing_prompt_embeddings.to(device).repeat(batch_size, 1, 1)
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed_edit, seq_len_edit, _ = edit_concepts.shape
......@@ -476,7 +477,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
......@@ -493,7 +494,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
# get the initial random noise unless the user supplied it
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
......@@ -504,7 +505,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
height,
width,
text_embeddings.dtype,
self.device,
device,
generator,
latents,
)
......@@ -562,12 +563,12 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
if enable_edit_guidance:
concept_weights = torch.zeros(
(len(noise_pred_edit_concepts), noise_guidance.shape[0]),
device=self.device,
device=device,
dtype=noise_guidance.dtype,
)
noise_guidance_edit = torch.zeros(
(len(noise_pred_edit_concepts), *noise_guidance.shape),
device=self.device,
device=device,
dtype=noise_guidance.dtype,
)
# noise_guidance_edit = torch.zeros_like(noise_guidance)
......@@ -644,21 +645,19 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
# noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp
warmup_inds = torch.tensor(warmup_inds).to(self.device)
warmup_inds = torch.tensor(warmup_inds).to(device)
if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0:
concept_weights = concept_weights.to("cpu") # Offload to cpu
noise_guidance_edit = noise_guidance_edit.to("cpu")
concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds)
concept_weights_tmp = torch.index_select(concept_weights.to(device), 0, warmup_inds)
concept_weights_tmp = torch.where(
concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp
)
concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
# concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)
noise_guidance_edit_tmp = torch.index_select(
noise_guidance_edit.to(self.device), 0, warmup_inds
)
noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds)
noise_guidance_edit_tmp = torch.einsum(
"cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp
)
......@@ -669,8 +668,8 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
del noise_guidance_edit_tmp
del concept_weights_tmp
concept_weights = concept_weights.to(self.device)
noise_guidance_edit = noise_guidance_edit.to(self.device)
concept_weights = concept_weights.to(device)
noise_guidance_edit = noise_guidance_edit.to(device)
concept_weights = torch.where(
concept_weights < 0, torch.zeros_like(concept_weights), concept_weights
......@@ -679,6 +678,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
concept_weights = torch.nan_to_num(concept_weights)
noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit)
noise_guidance_edit = noise_guidance_edit.to(edit_momentum.device)
noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum
......@@ -689,7 +689,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
self.sem_guidance[i] = noise_guidance_edit.detach().cpu()
if sem_guidance is not None:
edit_guidance = sem_guidance[i].to(self.device)
edit_guidance = sem_guidance[i].to(device)
noise_guidance = noise_guidance + edit_guidance
noise_pred = noise_pred_uncond + noise_guidance
......@@ -705,7 +705,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
# 8. Post-processing
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
else:
image = latents
has_nsfw_concept = None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment