Unverified Commit 3eaead0c authored by Joseph Coffland, committed by GitHub

Allow SD attend and excite pipeline to work with any size output images (#2835)

Allow the Stable Diffusion Attend-and-Excite pipeline to work with output images of any size. Re: #2476, #2603
parent 3bf5ce21
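
Not part of the diff below, but a minimal usage sketch of the new behavior, assuming the standard diffusers API for this pipeline (the model id, prompt, and token indices are illustrative). `attn_res` is left at its new default of None, so it is derived from `width` and `height` instead of being fixed at 16x16:

import torch
from diffusers import StableDiffusionAttendAndExcitePipeline

# Illustrative checkpoint; any Stable Diffusion model usable with the
# Attend-and-Excite pipeline should behave the same way.
pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

prompt = "a cat and a frog"
# Indices of the prompt tokens to attend to and excite (illustrative values).
token_indices = [2, 5]

# Non-square output: previously the hard-coded 16x16 attention resolution only
# matched 512x512 outputs; attn_res is now computed as
# (ceil(width / 32), ceil(height / 32)) when left as None.
image = pipe(
    prompt=prompt,
    token_indices=token_indices,
    height=512,
    width=768,
    num_inference_steps=50,
    guidance_scale=7.5,
).images[0]
image.save("cat_and_frog.png")
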
@@ -14,7 +14,7 @@
 import inspect
 import math
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
@@ -76,7 +76,7 @@ class AttentionStore:
     def __call__(self, attn, is_cross: bool, place_in_unet: str):
         if self.cur_att_layer >= 0 and is_cross:
-            if attn.shape[1] == self.attn_res**2:
+            if attn.shape[1] == np.prod(self.attn_res):
                 self.step_store[place_in_unet].append(attn)
         self.cur_att_layer += 1
@@ -98,7 +98,7 @@ class AttentionStore:
         attention_maps = self.get_average_attention()
         for location in from_where:
             for item in attention_maps[location]:
-                cross_maps = item.reshape(-1, self.attn_res, self.attn_res, item.shape[-1])
+                cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
                 out.append(cross_maps)
         out = torch.cat(out, dim=0)
         out = out.sum(0) / out.shape[0]
@@ -109,7 +109,7 @@ class AttentionStore:
         self.step_store = self.get_empty_store()
         self.attention_store = {}
-    def __init__(self, attn_res=16):
+    def __init__(self, attn_res):
         """
         Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
         process
@@ -724,7 +724,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
         max_iter_to_alter: int = 25,
         thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
         scale_factor: int = 20,
-        attn_res: int = 16,
+        attn_res: Optional[Tuple[int]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -796,8 +796,8 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
                 Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in.
             scale_factor (`int`, *optional*, default to 20):
                 Scale factor that controls the step size of each Attend and Excite update.
-            attn_res (`int`, *optional*, default to 16):
-                The resolution of most semantic attention map.
+            attn_res (`tuple`, *optional*, default computed from width and height):
+                The 2D resolution of the semantic attention map.
         Examples:
@@ -870,7 +870,9 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-        self.attention_store = AttentionStore(attn_res=attn_res)
+        if attn_res is None:
+            attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
+        self.attention_store = AttentionStore(attn_res)
         self.register_attention_control()
         # default config for step size from original repo
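
For reference, the fallback added above maps the requested output size to the cross-attention resolution: Stable Diffusion's VAE downsamples by 8 and the stored UNet cross-attention maps sit at a further 4x reduction, so one attention cell covers 32 output pixels per dimension. A small standalone sketch of that arithmetic with a hypothetical helper (mirroring, not importing, the pipeline code):

import numpy as np

def default_attn_res(width: int, height: int) -> tuple:
    # Mirrors the fallback added in this commit: ceil(size / 32) per dimension.
    return int(np.ceil(width / 32)), int(np.ceil(height / 32))

# A 512x512 output reproduces the previously hard-coded value.
assert default_attn_res(512, 512) == (16, 16)

# A non-square 768x512 output gives (24, 16); the AttentionStore now keeps
# attention maps whose token dimension equals np.prod((24, 16)) == 384
# rather than the old 16**2 == 256.
print(default_attn_res(768, 512))           # (24, 16)
print(np.prod(default_attn_res(768, 512)))  # 384
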