"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7ecfe2916044e4d1f2e5b7d8f2b519d6c4677e77"
Unverified Commit 0028c344 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

fix SEGA pipeline (#8467)



* fix

* style

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent d457beed
...@@ -376,6 +376,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -376,6 +376,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
# 2. Define call parameters # 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt) batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
if editing_prompt: if editing_prompt:
enable_edit_guidance = True enable_edit_guidance = True
...@@ -405,7 +406,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -405,7 +406,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
f" {self.tokenizer.model_max_length} tokens: {removed_text}" f" {self.tokenizer.model_max_length} tokens: {removed_text}"
) )
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method # duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape bs_embed, seq_len, _ = text_embeddings.shape
...@@ -433,9 +434,9 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -433,9 +434,9 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
f" {self.tokenizer.model_max_length} tokens: {removed_text}" f" {self.tokenizer.model_max_length} tokens: {removed_text}"
) )
edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length] edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length]
edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0] edit_concepts = self.text_encoder(edit_concepts_input_ids.to(device))[0]
else: else:
edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1) edit_concepts = editing_prompt_embeddings.to(device).repeat(batch_size, 1, 1)
# duplicate text embeddings for each generation per prompt, using mps friendly method # duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed_edit, seq_len_edit, _ = edit_concepts.shape bs_embed_edit, seq_len_edit, _ = edit_concepts.shape
...@@ -476,7 +477,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -476,7 +477,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
truncation=True, truncation=True,
return_tensors="pt", return_tensors="pt",
) )
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1] seq_len = uncond_embeddings.shape[1]
...@@ -493,7 +494,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -493,7 +494,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
# get the initial random noise unless the user supplied it # get the initial random noise unless the user supplied it
# 4. Prepare timesteps # 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=self.device) self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps timesteps = self.scheduler.timesteps
# 5. Prepare latent variables # 5. Prepare latent variables
...@@ -504,7 +505,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -504,7 +505,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
height, height,
width, width,
text_embeddings.dtype, text_embeddings.dtype,
self.device, device,
generator, generator,
latents, latents,
) )
...@@ -562,12 +563,12 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -562,12 +563,12 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
if enable_edit_guidance: if enable_edit_guidance:
concept_weights = torch.zeros( concept_weights = torch.zeros(
(len(noise_pred_edit_concepts), noise_guidance.shape[0]), (len(noise_pred_edit_concepts), noise_guidance.shape[0]),
device=self.device, device=device,
dtype=noise_guidance.dtype, dtype=noise_guidance.dtype,
) )
noise_guidance_edit = torch.zeros( noise_guidance_edit = torch.zeros(
(len(noise_pred_edit_concepts), *noise_guidance.shape), (len(noise_pred_edit_concepts), *noise_guidance.shape),
device=self.device, device=device,
dtype=noise_guidance.dtype, dtype=noise_guidance.dtype,
) )
# noise_guidance_edit = torch.zeros_like(noise_guidance) # noise_guidance_edit = torch.zeros_like(noise_guidance)
...@@ -644,21 +645,19 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -644,21 +645,19 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
# noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp
warmup_inds = torch.tensor(warmup_inds).to(self.device) warmup_inds = torch.tensor(warmup_inds).to(device)
if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0: if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0:
concept_weights = concept_weights.to("cpu") # Offload to cpu concept_weights = concept_weights.to("cpu") # Offload to cpu
noise_guidance_edit = noise_guidance_edit.to("cpu") noise_guidance_edit = noise_guidance_edit.to("cpu")
concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds) concept_weights_tmp = torch.index_select(concept_weights.to(device), 0, warmup_inds)
concept_weights_tmp = torch.where( concept_weights_tmp = torch.where(
concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp
) )
concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0) concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
# concept_weights_tmp = torch.nan_to_num(concept_weights_tmp) # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)
noise_guidance_edit_tmp = torch.index_select( noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds)
noise_guidance_edit.to(self.device), 0, warmup_inds
)
noise_guidance_edit_tmp = torch.einsum( noise_guidance_edit_tmp = torch.einsum(
"cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp
) )
...@@ -669,8 +668,8 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -669,8 +668,8 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
del noise_guidance_edit_tmp del noise_guidance_edit_tmp
del concept_weights_tmp del concept_weights_tmp
concept_weights = concept_weights.to(self.device) concept_weights = concept_weights.to(device)
noise_guidance_edit = noise_guidance_edit.to(self.device) noise_guidance_edit = noise_guidance_edit.to(device)
concept_weights = torch.where( concept_weights = torch.where(
concept_weights < 0, torch.zeros_like(concept_weights), concept_weights concept_weights < 0, torch.zeros_like(concept_weights), concept_weights
...@@ -679,6 +678,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -679,6 +678,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
concept_weights = torch.nan_to_num(concept_weights) concept_weights = torch.nan_to_num(concept_weights)
noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit) noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit)
noise_guidance_edit = noise_guidance_edit.to(edit_momentum.device)
noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum
...@@ -689,7 +689,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -689,7 +689,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
self.sem_guidance[i] = noise_guidance_edit.detach().cpu() self.sem_guidance[i] = noise_guidance_edit.detach().cpu()
if sem_guidance is not None: if sem_guidance is not None:
edit_guidance = sem_guidance[i].to(self.device) edit_guidance = sem_guidance[i].to(device)
noise_guidance = noise_guidance + edit_guidance noise_guidance = noise_guidance + edit_guidance
noise_pred = noise_pred_uncond + noise_guidance noise_pred = noise_pred_uncond + noise_guidance
...@@ -705,7 +705,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin): ...@@ -705,7 +705,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
# 8. Post-processing # 8. Post-processing
if not output_type == "latent": if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype) image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
else: else:
image = latents image = latents
has_nsfw_concept = None has_nsfw_concept = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment