Unverified Commit 6766a811 authored by Aki Sakurai's avatar Aki Sakurai Committed by GitHub
Browse files

Support non square image generation for StableDiffusionSAGPipeline (#2629)

* Support non square image generation for StableDiffusionSAGPipeline

* Fix style
parent bbab8553
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
import math
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
...@@ -606,6 +605,14 @@ class StableDiffusionSAGPipeline(DiffusionPipeline): ...@@ -606,6 +605,14 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
store_processor = CrossAttnStoreProcessor() store_processor = CrossAttnStoreProcessor()
self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
map_size = None
def get_map_size(module, input, output):
nonlocal map_size
map_size = output.sample.shape[-2:]
with self.unet.mid_block.attentions[0].register_forward_hook(get_map_size):
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance # expand the latents if we are doing classifier free guidance
...@@ -613,6 +620,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline): ...@@ -613,6 +620,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual # predict the noise residual
noise_pred = self.unet( noise_pred = self.unet(
latent_model_input, latent_model_input,
t, t,
...@@ -637,7 +645,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline): ...@@ -637,7 +645,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
# self-attention-based degrading of latents # self-attention-based degrading of latents
degraded_latents = self.sag_masking( degraded_latents = self.sag_masking(
pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t) pred_x0, uncond_attn, map_size, t, self.pred_epsilon(latents, noise_pred_uncond, t)
) )
uncond_emb, _ = prompt_embeds.chunk(2) uncond_emb, _ = prompt_embeds.chunk(2)
# forward and give guidance # forward and give guidance
...@@ -650,7 +658,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline): ...@@ -650,7 +658,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
cond_attn = store_processor.attention_probs cond_attn = store_processor.attention_probs
# self-attention-based degrading of latents # self-attention-based degrading of latents
degraded_latents = self.sag_masking( degraded_latents = self.sag_masking(
pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t) pred_x0, cond_attn, map_size, t, self.pred_epsilon(latents, noise_pred, t)
) )
# forward and give guidance # forward and give guidance
degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
...@@ -680,20 +688,22 @@ class StableDiffusionSAGPipeline(DiffusionPipeline): ...@@ -680,20 +688,22 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def sag_masking(self, original_latents, attn_map, t, eps): def sag_masking(self, original_latents, attn_map, map_size, t, eps):
# Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
bh, hw1, hw2 = attn_map.shape bh, hw1, hw2 = attn_map.shape
b, latent_channel, latent_h, latent_w = original_latents.shape b, latent_channel, latent_h, latent_w = original_latents.shape
h = self.unet.attention_head_dim h = self.unet.attention_head_dim
if isinstance(h, list): if isinstance(h, list):
h = h[-1] h = h[-1]
map_size = math.isqrt(hw1)
# Produce attention mask # Produce attention mask
attn_map = attn_map.reshape(b, h, hw1, hw2) attn_map = attn_map.reshape(b, h, hw1, hw2)
attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
attn_mask = ( attn_mask = (
attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype) attn_mask.reshape(b, map_size[0], map_size[1])
.unsqueeze(1)
.repeat(1, latent_channel, 1, 1)
.type(attn_map.dtype)
) )
attn_mask = F.interpolate(attn_mask, (latent_h, latent_w)) attn_mask = F.interpolate(attn_mask, (latent_h, latent_w))
......
...@@ -160,3 +160,25 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase): ...@@ -160,3 +160,25 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
expected_slice = np.array([0.3459, 0.2876, 0.2537, 0.3002, 0.2671, 0.2160, 0.3026, 0.2262, 0.2371]) expected_slice = np.array([0.3459, 0.2876, 0.2537, 0.3002, 0.2671, 0.2160, 0.3026, 0.2262, 0.2371])
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
def test_stable_diffusion_2_non_square(self):
    """Check that SAG generates non-square images with the expected output shape.

    Regression test for landscape-format generation (width != height): the
    self-attention map must be reshaped to the true (h, w) feature-map size
    rather than assuming a square latent grid.
    """
    # Build the pipeline once and move it onto the test device.
    pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base").to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    # Minimal prompt; the content is irrelevant, only the output shape matters.
    prompt = "."
    generator = torch.manual_seed(0)

    # Landscape 768x512 request exercises the non-square code path.
    output = pipe(
        [prompt],
        width=768,
        height=512,
        generator=generator,
        guidance_scale=7.5,
        sag_scale=1.0,
        num_inference_steps=20,
        output_type="np",
    )

    # numpy output is (batch, height, width, channels).
    assert output.images.shape == (1, 512, 768, 3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment