Unverified commit 6766a811, authored by Aki Sakurai, committed by GitHub

Support non square image generation for StableDiffusionSAGPipeline (#2629)

* Support non square image generation for StableDiffusionSAGPipeline

* Fix style
parent bbab8553
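
With this change the SAG pipeline honors the standard width/height arguments for non-square output. As a quick orientation before the diff, a minimal usage sketch; the model id and parameter values are illustrative, mirroring the integration test added below:

import torch
from diffusers import StableDiffusionSAGPipeline

# Load the pipeline; this checkpoint mirrors the integration test below.
pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
pipe = pipe.to("cuda")

# Landscape (non-square) generation, which previously failed inside
# sag_masking because the attention map was assumed to be square.
image = pipe(
    "a photo of an astronaut riding a horse",
    width=768,
    height=512,
    sag_scale=1.0,
    guidance_scale=7.5,
    num_inference_steps=20,
    generator=torch.manual_seed(0),
).images[0]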
@@ -13,7 +13,6 @@
 # limitations under the License.

 import inspect
-import math
 from typing import Any, Callable, Dict, List, Optional, Union

 import torch
@@ -606,64 +605,73 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
         store_processor = CrossAttnStoreProcessor()
         self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                # predict the noise residual
-                noise_pred = self.unet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
-
-                # perform guidance
-                if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                # perform self-attention guidance with the stored self-attention map
-                if do_self_attention_guidance:
-                    # classifier-free guidance produces two chunks of attention map
-                    # and we only use unconditional one according to equation (24)
-                    # in https://arxiv.org/pdf/2210.00939.pdf
-                    if do_classifier_free_guidance:
-                        # DDIM-like prediction of x0
-                        pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
-                        # get the stored attention maps
-                        uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
-                        # self-attention-based degrading of latents
-                        degraded_latents = self.sag_masking(
-                            pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t)
-                        )
-                        uncond_emb, _ = prompt_embeds.chunk(2)
-                        # forward and give guidance
-                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
-                        noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
-                    else:
-                        # DDIM-like prediction of x0
-                        pred_x0 = self.pred_x0(latents, noise_pred, t)
-                        # get the stored attention maps
-                        cond_attn = store_processor.attention_probs
-                        # self-attention-based degrading of latents
-                        degraded_latents = self.sag_masking(
-                            pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t)
-                        )
-                        # forward and give guidance
-                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
-                        noise_pred += sag_scale * (noise_pred - degraded_pred)
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                    if callback is not None and i % callback_steps == 0:
-                        callback(i, t, latents)
+
+        map_size = None
+
+        def get_map_size(module, input, output):
+            nonlocal map_size
+            map_size = output.sample.shape[-2:]
+
+        with self.unet.mid_block.attentions[0].register_forward_hook(get_map_size):
+            with self.progress_bar(total=num_inference_steps) as progress_bar:
+                for i, t in enumerate(timesteps):
+                    # expand the latents if we are doing classifier free guidance
+                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # predict the noise residual
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        cross_attention_kwargs=cross_attention_kwargs,
+                    ).sample
+
+                    # perform guidance
+                    if do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    # perform self-attention guidance with the stored self-attention map
+                    if do_self_attention_guidance:
+                        # classifier-free guidance produces two chunks of attention map
+                        # and we only use unconditional one according to equation (24)
+                        # in https://arxiv.org/pdf/2210.00939.pdf
+                        if do_classifier_free_guidance:
+                            # DDIM-like prediction of x0
+                            pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
+                            # get the stored attention maps
+                            uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
+                            # self-attention-based degrading of latents
+                            degraded_latents = self.sag_masking(
+                                pred_x0, uncond_attn, map_size, t, self.pred_epsilon(latents, noise_pred_uncond, t)
+                            )
+                            uncond_emb, _ = prompt_embeds.chunk(2)
+                            # forward and give guidance
+                            degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
+                            noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
+                        else:
+                            # DDIM-like prediction of x0
+                            pred_x0 = self.pred_x0(latents, noise_pred, t)
+                            # get the stored attention maps
+                            cond_attn = store_processor.attention_probs
+                            # self-attention-based degrading of latents
+                            degraded_latents = self.sag_masking(
+                                pred_x0, cond_attn, map_size, t, self.pred_epsilon(latents, noise_pred, t)
+                            )
+                            # forward and give guidance
+                            degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
+                            noise_pred += sag_scale * (noise_pred - degraded_pred)
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                    # call the callback, if provided
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+                        if callback is not None and i % callback_steps == 0:
+                            callback(i, t, latents)

         # 8. Post-processing
         image = self.decode_latents(latents)
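
The handle returned by PyTorch's register_forward_hook is a context manager, so the hook above is removed again as soon as the denoising loop finishes. A self-contained sketch of the same trick on a plain Conv2d; the module and shapes here are illustrative:

import torch
from torch import nn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
map_size = None

def get_map_size(module, input, output):
    # (module, input, output) is the signature PyTorch passes to every
    # forward hook; for the UNet mid-block the tensor lives in output.sample
    global map_size
    map_size = output.shape[-2:]

# RemovableHandle implements __enter__/__exit__, so the hook is
# unregistered automatically when the block exits
with conv.register_forward_hook(get_map_size):
    conv(torch.randn(1, 3, 64, 96))

print(map_size)  # torch.Size([64, 96]) -- (height, width), non-square preserved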
@@ -680,20 +688,22 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

-    def sag_masking(self, original_latents, attn_map, t, eps):
+    def sag_masking(self, original_latents, attn_map, map_size, t, eps):
         # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
         bh, hw1, hw2 = attn_map.shape
         b, latent_channel, latent_h, latent_w = original_latents.shape
         h = self.unet.attention_head_dim
         if isinstance(h, list):
             h = h[-1]
-        map_size = math.isqrt(hw1)

         # Produce attention mask
         attn_map = attn_map.reshape(b, h, hw1, hw2)
         attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
-        attn_mask = (
-            attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype)
-        )
+        attn_mask = (
+            attn_mask.reshape(b, map_size[0], map_size[1])
+            .unsqueeze(1)
+            .repeat(1, latent_channel, 1, 1)
+            .type(attn_map.dtype)
+        )
         attn_mask = F.interpolate(attn_mask, (latent_h, latent_w))
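
The removed math.isqrt heuristic is exactly what broke non-square generation: a 512x768 request gives 64x96 latents, the UNet mid-block runs at 1/8 of that, so the self-attention map has 8 * 12 = 96 tokens and (isqrt(96), isqrt(96)) = (9, 9) no longer matches. A small sketch of the failure and the fix; the shapes are illustrative:

import math
import torch

# 512x768 image -> 64x96 latents -> 8x12 feature map at the UNet mid-block
attn_mask = torch.ones(1, 8 * 12)  # 96 attention-map entries per batch item

side = math.isqrt(96)  # 9: the old square assumption
try:
    attn_mask.reshape(1, side, side)  # needs 81 elements, but 96 are given
except RuntimeError as err:
    print("old code fails:", err)

map_size = (8, 12)  # now captured by the forward hook instead
print(attn_mask.reshape(1, map_size[0], map_size[1]).shape)  # torch.Size([1, 8, 12])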
@@ -160,3 +160,25 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         expected_slice = np.array([0.3459, 0.2876, 0.2537, 0.3002, 0.2671, 0.2160, 0.3026, 0.2262, 0.2371])

         assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
+
+    def test_stable_diffusion_2_non_square(self):
+        sag_pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
+        sag_pipe = sag_pipe.to(torch_device)
+        sag_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "."
+        generator = torch.manual_seed(0)
+        output = sag_pipe(
+            [prompt],
+            width=768,
+            height=512,
+            generator=generator,
+            guidance_scale=7.5,
+            sag_scale=1.0,
+            num_inference_steps=20,
+            output_type="np",
+        )
+
+        image = output.images
+
+        assert image.shape == (1, 512, 768, 3)
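
Note that with output_type="np" the pipeline returns images in (batch, height, width, channels) layout, so the (1, 512, 768, 3) assertion confirms that the requested height=512 and width=768 survive the round trip through the latent space.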