Unverified Commit 6766a811 authored by Aki Sakurai, committed by GitHub

Support non square image generation for StableDiffusionSAGPipeline (#2629)

* Support non square image generation for StableDiffusionSAGPipeline

* Fix style
parent bbab8553
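
In short, the patch lets StableDiffusionSAGPipeline produce images whose width and height differ. A minimal usage sketch of what this enables (prompt and device are illustrative; the model id and sizes mirror the integration test added below):

    from diffusers import StableDiffusionSAGPipeline

    pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
    pipe = pipe.to("cuda")

    image = pipe(
        "a photo of an astronaut riding a horse",
        width=768,   # before this patch, only width == height worked
        height=512,
        sag_scale=0.75,
        num_inference_steps=20,
    ).images[0]
    image.save("astronaut_768x512.png")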
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import inspect
-import math
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
@@ -606,6 +605,14 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
         store_processor = CrossAttnStoreProcessor()
         self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+        map_size = None
+
+        def get_map_size(module, input, output):
+            nonlocal map_size
+            map_size = output.sample.shape[-2:]
+
+        with self.unet.mid_block.attentions[0].register_forward_hook(get_map_size):
             with self.progress_bar(total=num_inference_steps) as progress_bar:
                 for i, t in enumerate(timesteps):
                     # expand the latents if we are doing classifier free guidance
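
The hook added above records the spatial size of the mid-block output on every forward pass, so the pipeline no longer has to assume a square attention map. A self-contained sketch of the same PyTorch mechanism (the Conv2d module and sizes are stand-ins, not from the diff); note that the handle returned by register_forward_hook is also a context manager, so the hook is removed when the with block exits:

    import torch
    import torch.nn as nn

    def main():
        feat_size = None

        def record_size(module, inputs, output):
            # keep only the spatial size (H, W) of the module's output
            nonlocal feat_size
            feat_size = output.shape[-2:]

        net = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        with net.register_forward_hook(record_size):
            net(torch.randn(1, 3, 64, 96))
        print(feat_size)  # torch.Size([64, 96])

    main()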
@@ -613,6 +620,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
                     latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                     # predict the noise residual
                     noise_pred = self.unet(
                         latent_model_input,
                         t,
@@ -637,7 +645,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
                             uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
                             # self-attention-based degrading of latents
                             degraded_latents = self.sag_masking(
-                                pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t)
+                                pred_x0, uncond_attn, map_size, t, self.pred_epsilon(latents, noise_pred_uncond, t)
                             )
                             uncond_emb, _ = prompt_embeds.chunk(2)
                             # forward and give guidance
@@ -650,7 +658,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
                             cond_attn = store_processor.attention_probs
                             # self-attention-based degrading of latents
                             degraded_latents = self.sag_masking(
-                                pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t)
+                                pred_x0, cond_attn, map_size, t, self.pred_epsilon(latents, noise_pred, t)
                             )
                             # forward and give guidance
                             degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
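
For context, the degraded prediction feeds back into the noise estimate roughly as follows; this is a paraphrase of the pipeline's guidance step under assumed default scales, not part of this diff:

    import torch

    def sag_guidance(noise_pred_uncond, noise_pred_text, degraded_pred,
                     guidance_scale=7.5, sag_scale=0.75):
        # standard classifier-free guidance
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        # self-attention guidance: push the estimate away from the
        # degraded (selectively blurred) branch
        return noise_pred + sag_scale * (noise_pred_uncond - degraded_pred)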
@@ -680,20 +688,22 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
 
-    def sag_masking(self, original_latents, attn_map, t, eps):
+    def sag_masking(self, original_latents, attn_map, map_size, t, eps):
         # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
         bh, hw1, hw2 = attn_map.shape
         b, latent_channel, latent_h, latent_w = original_latents.shape
         h = self.unet.attention_head_dim
         if isinstance(h, list):
             h = h[-1]
-        map_size = math.isqrt(hw1)
 
         # Produce attention mask
         attn_map = attn_map.reshape(b, h, hw1, hw2)
         attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
         attn_mask = (
-            attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype)
+            attn_mask.reshape(b, map_size[0], map_size[1])
+            .unsqueeze(1)
+            .repeat(1, latent_channel, 1, 1)
+            .type(attn_map.dtype)
         )
         attn_mask = F.interpolate(attn_mask, (latent_h, latent_w))
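
To see why the old math.isqrt reconstruction could not handle non-square inputs, consider a 768x512 request on SD 2.1-base (the shapes below are my worked example, not from the diff): latents are 1/8 scale, so 64x96 (H x W), and the mid-block attention map is a further 1/8 of that, 8x12, giving hw1 = 8 * 12 = 96 attention positions:

    import math
    import torch

    map_size = (8, 12)                    # (H, W) recorded by the forward hook
    attn_mask = torch.rand(1, 96) > 0.5   # one mask value per attention position

    # old reconstruction: math.isqrt(96) == 9, and 9 * 9 != 96
    try:
        attn_mask.reshape(1, math.isqrt(96), math.isqrt(96))
    except RuntimeError as e:
        print("square reshape fails:", e)

    # new reconstruction: use the recorded (H, W) pair instead of a square guess
    mask = attn_mask.reshape(1, map_size[0], map_size[1])  # works: 8 * 12 == 96
    print(mask.shape)  # torch.Size([1, 8, 12])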
@@ -160,3 +160,25 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
         expected_slice = np.array([0.3459, 0.2876, 0.2537, 0.3002, 0.2671, 0.2160, 0.3026, 0.2262, 0.2371])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
+
+    def test_stable_diffusion_2_non_square(self):
+        sag_pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
+        sag_pipe = sag_pipe.to(torch_device)
+        sag_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "."
+        generator = torch.manual_seed(0)
+        output = sag_pipe(
+            [prompt],
+            width=768,
+            height=512,
+            generator=generator,
+            guidance_scale=7.5,
+            sag_scale=1.0,
+            num_inference_steps=20,
+            output_type="np",
+        )
+
+        image = output.images
+
+        assert image.shape == (1, 512, 768, 3)