"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "74468513bd4299197bbaaafcee7c8dfb48c3ddd7"
Unverified Commit 3eaead0c authored by Joseph Coffland, committed by GitHub

Allow SD attend and excite pipeline to work with any size output images (#2835)

Allow stable diffusion attend and excite pipeline to work with any size output image. Re: #2476, #2603
parent 3bf5ce21
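
With this commit the pipeline no longer assumes a square 16x16 attention map, so non-square outputs work out of the box. A minimal usage sketch follows; the checkpoint id, prompt, and token_indices are illustrative placeholders, not part of this change:

import torch
from diffusers import StableDiffusionAttendAndExcitePipeline

pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# Non-square output: attn_res now defaults to (ceil(width / 32), ceil(height / 32)),
# e.g. (24, 16) for 768x512, instead of the previously hard-coded 16x16.
image = pipe(
    prompt="a cat and a frog",
    token_indices=[2, 5],  # prompt token positions to attend to and excite
    width=768,
    height=512,
    num_inference_steps=50,
).images[0]
image.save("cat_and_frog.png")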
@@ -14,7 +14,7 @@
 import inspect
 import math
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
@@ -76,7 +76,7 @@ class AttentionStore:
     def __call__(self, attn, is_cross: bool, place_in_unet: str):
         if self.cur_att_layer >= 0 and is_cross:
-            if attn.shape[1] == self.attn_res**2:
+            if attn.shape[1] == np.prod(self.attn_res):
                 self.step_store[place_in_unet].append(attn)
         self.cur_att_layer += 1
@@ -98,7 +98,7 @@ class AttentionStore:
         attention_maps = self.get_average_attention()
         for location in from_where:
             for item in attention_maps[location]:
-                cross_maps = item.reshape(-1, self.attn_res, self.attn_res, item.shape[-1])
+                cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
                 out.append(cross_maps)
         out = torch.cat(out, dim=0)
         out = out.sum(0) / out.shape[0]
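
The two AttentionStore changes above go together: cross-attention maps arrive flattened as (batch * heads, H * W, tokens), so the resolution filter compares against np.prod(attn_res) rather than attn_res ** 2, and aggregation reshapes the two spatial axes separately. A toy shape check with dummy tensors (illustrative only, not the real UNet outputs):

import numpy as np
import torch

attn_res = (24, 16)  # e.g. ceil(768 / 32), ceil(512 / 32) for a 768x512 image
num_tokens = 77
heads = 8

# Dummy flattened cross-attention map, shaped like what the UNet attention layers emit.
attn = torch.rand(heads, int(np.prod(attn_res)), num_tokens)

# The old filter `attn.shape[1] == attn_res ** 2` only matched square maps;
# the product also matches rectangular ones.
assert attn.shape[1] == np.prod(attn_res)

# Aggregation now restores the two spatial axes independently instead of assuming a square grid.
cross_maps = attn.reshape(-1, attn_res[0], attn_res[1], num_tokens)
print(cross_maps.shape)  # torch.Size([8, 24, 16, 77])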
@@ -109,7 +109,7 @@
         self.step_store = self.get_empty_store()
         self.attention_store = {}

-    def __init__(self, attn_res=16):
+    def __init__(self, attn_res):
         """
         Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
         process
@@ -724,7 +724,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
         max_iter_to_alter: int = 25,
         thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
         scale_factor: int = 20,
-        attn_res: int = 16,
+        attn_res: Optional[Tuple[int]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -796,8 +796,8 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
                 Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in.
             scale_factor (`int`, *optional*, default to 20):
                 Scale factor that controls the step size of each Attend and Excite update.
-            attn_res (`int`, *optional*, default to 16):
-                The resolution of most semantic attention map.
+            attn_res (`tuple`, *optional*, default computed from width and height):
+                The 2D resolution of the semantic attention map.

         Examples:
@@ -870,7 +870,9 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        self.attention_store = AttentionStore(attn_res=attn_res)
+        if attn_res is None:
+            attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
+        self.attention_store = AttentionStore(attn_res)
         self.register_attention_control()

         # default config for step size from original repo
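
The new default in the hunk above keeps the old behaviour for the standard 512x512 case and scales with other sizes. A quick check of the computed values (worked out here, not part of the diff):

import numpy as np

for width, height in [(512, 512), (768, 512), (1000, 600)]:
    attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
    print(width, height, attn_res)

# 512  512 -> (16, 16)   matches the old hard-coded attn_res=16
# 768  512 -> (24, 16)
# 1000 600 -> (32, 19)   ceil covers sizes that are not multiples of 32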