You need to sign in or sign up before continuing.
Commit ba04a87d authored by comfyanonymous's avatar comfyanonymous
Browse files

Refactor and improve the sag node.

Moved all the sag related code to comfy_extras/nodes_sag.py
parent 6761233e
...@@ -61,6 +61,9 @@ class ModelPatcher: ...@@ -61,6 +61,9 @@ class ModelPatcher:
else: else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function self.model_options["sampler_cfg_function"] = sampler_cfg_function
def set_model_sampler_post_cfg_function(self, post_cfg_function):
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
def set_model_unet_function_wrapper(self, unet_wrapper_function): def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function self.model_options["model_function_wrapper"] = unet_wrapper_function
...@@ -70,13 +73,17 @@ class ModelPatcher: ...@@ -70,13 +73,17 @@ class ModelPatcher:
to["patches"] = {} to["patches"] = {}
to["patches"][name] = to["patches"].get(name, []) + [patch] to["patches"][name] = to["patches"].get(name, []) + [patch]
def set_model_patch_replace(self, patch, name, block_name, number): def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
to = self.model_options["transformer_options"] to = self.model_options["transformer_options"]
if "patches_replace" not in to: if "patches_replace" not in to:
to["patches_replace"] = {} to["patches_replace"] = {}
if name not in to["patches_replace"]: if name not in to["patches_replace"]:
to["patches_replace"][name] = {} to["patches_replace"][name] = {}
to["patches_replace"][name][(block_name, number)] = patch if transformer_index is not None:
block = (block_name, number, transformer_index)
else:
block = (block_name, number)
to["patches_replace"][name][block] = patch
def set_model_attn1_patch(self, patch): def set_model_attn1_patch(self, patch):
self.set_model_patch(patch, "attn1_patch") self.set_model_patch(patch, "attn1_patch")
...@@ -84,11 +91,11 @@ class ModelPatcher: ...@@ -84,11 +91,11 @@ class ModelPatcher:
def set_model_attn2_patch(self, patch): def set_model_attn2_patch(self, patch):
self.set_model_patch(patch, "attn2_patch") self.set_model_patch(patch, "attn2_patch")
def set_model_attn1_replace(self, patch, block_name, number): def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
self.set_model_patch_replace(patch, "attn1", block_name, number) self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
def set_model_attn2_replace(self, patch, block_name, number): def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
self.set_model_patch_replace(patch, "attn2", block_name, number) self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
def set_model_attn1_output_patch(self, patch): def set_model_attn1_output_patch(self, patch):
self.set_model_patch(patch, "attn1_output_patch") self.set_model_patch(patch, "attn1_output_patch")
......
from .k_diffusion import sampling as k_diffusion_sampling from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc from .extra_samplers import uni_pc
import torch import torch
import torch.nn.functional as F
import enum import enum
from comfy import model_management from comfy import model_management
import math import math
...@@ -9,11 +8,7 @@ from comfy import model_base ...@@ -9,11 +8,7 @@ from comfy import model_base
import comfy.utils import comfy.utils
import comfy.conds import comfy.conds
def get_area_and_mult(conds, x_in, timestep_in):
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0) area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0 strength = 1.0
...@@ -85,7 +80,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option ...@@ -85,7 +80,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
return (input_x, mult, conditioning, area, control, patches) return (input_x, mult, conditioning, area, control, patches)
def cond_equal_size(c1, c2): def cond_equal_size(c1, c2):
if c1 is c2: if c1 is c2:
return True return True
if c1.keys() != c2.keys(): if c1.keys() != c2.keys():
...@@ -95,7 +90,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option ...@@ -95,7 +90,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
return False return False
return True return True
def can_concat_cond(c1, c2): def can_concat_cond(c1, c2):
if c1[0].shape != c2[0].shape: if c1[0].shape != c2[0].shape:
return False return False
...@@ -115,7 +110,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option ...@@ -115,7 +110,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
return cond_equal_size(c1[2], c2[2]) return cond_equal_size(c1[2], c2[2])
def cond_cat(c_list): def cond_cat(c_list):
c_crossattn = [] c_crossattn = []
c_concat = [] c_concat = []
c_adm = [] c_adm = []
...@@ -135,7 +130,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option ...@@ -135,7 +130,7 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
return out return out
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
out_cond = torch.zeros_like(x_in) out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in) * 1e-37 out_count = torch.ones_like(x_in) * 1e-37
...@@ -246,72 +241,26 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option ...@@ -246,72 +241,26 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
del out_uncond_count del out_uncond_count
return out_cond, out_uncond return out_cond, out_uncond
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
if math.isclose(cond_scale, 1.0):
uncond_ = None
else:
uncond_ = uncond
# if we're doing SAG, we still need to do uncond guidance, even though the cond and uncond will cancel out. cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
if math.isclose(cond_scale, 1.0) and "sag" not in model_options:
uncond = None
cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options)
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
if "sampler_cfg_function" in model_options: if "sampler_cfg_function" in model_options:
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
cfg_result = x - model_options["sampler_cfg_function"](args) cfg_result = x - model_options["sampler_cfg_function"](args)
if "sag" in model_options: for fn in model_options.get("sampler_post_cfg_function", []):
assert uncond is not None, "SAG requires uncond guidance" args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
sag_scale = model_options["sag_scale"] "sigma": timestep, "model_options": model_options, "input": x}
sag_sigma = model_options["sag_sigma"] cfg_result = fn(args)
sag_threshold = model_options.get("sag_threshold", 1.0)
# these methods are added by the sag patcher
uncond_attn = model.get_attn_scores()
mid_shape = model.get_mid_block_shape()
# create the adversarially blurred image
degraded = create_blur_map(uncond_pred, uncond_attn, mid_shape, sag_sigma, sag_threshold)
degraded_noised = degraded + x - uncond_pred
# call into the UNet
(sag, _) = calc_cond_uncond_batch(model, uncond, None, degraded_noised, timestep, model_options)
cfg_result += (degraded - sag) * sag_scale
return cfg_result
def create_blur_map(x0, attn, mid_shape, sigma=3.0, threshold=1.0): return cfg_result
# reshape and GAP the attention map
_, hw1, hw2 = attn.shape
b, _, lh, lw = x0.shape
attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
# Reshape
mask = (
mask.reshape(b, *mid_shape)
.unsqueeze(1)
.type(attn.dtype)
)
# Upsample
mask = F.interpolate(mask, (lh, lw))
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
blurred = blurred * mask + x0 * (1 - mask)
return blurred
def gaussian_blur_2d(img, kernel_size, sigma):
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
x_kernel = pdf / pdf.sum()
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
img = F.pad(img, padding, mode="reflect")
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
return img
class CFGNoisePredictor(torch.nn.Module): class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model): def __init__(self, model):
......
import torch import torch
from torch import einsum from torch import einsum
import torch.nn.functional as F
import math
from einops import rearrange, repeat from einops import rearrange, repeat
import os import os
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
import comfy.samplers
# from comfy/ldm/modules/attention.py # from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output # but modified to return attention scores as well as output
...@@ -49,7 +53,49 @@ def attention_basic_with_sim(q, k, v, heads, mask=None): ...@@ -49,7 +53,49 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
) )
return (out, sim) return (out, sim)
class SagNode: def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
# reshape and GAP the attention map
_, hw1, hw2 = attn.shape
b, _, lh, lw = x0.shape
attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
ratio = round(math.sqrt(lh * lw / hw1))
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
# Reshape
mask = (
mask.reshape(b, *mid_shape)
.unsqueeze(1)
.type(attn.dtype)
)
# Upsample
mask = F.interpolate(mask, (lh, lw))
blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
blurred = blurred * mask + x0 * (1 - mask)
return blurred
def gaussian_blur_2d(img, kernel_size, sigma):
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
x_kernel = pdf / pdf.sum()
x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
img = F.pad(img, padding, mode="reflect")
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
return img
class SelfAttentionGuidance:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",), return {"required": { "model": ("MODEL",),
...@@ -63,15 +109,9 @@ class SagNode: ...@@ -63,15 +109,9 @@ class SagNode:
def patch(self, model, scale, blur_sigma): def patch(self, model, scale, blur_sigma):
m = model.clone() m = model.clone()
# set extra options on the model
m.model_options["sag"] = True
m.model_options["sag_scale"] = scale
m.model_options["sag_sigma"] = blur_sigma
attn_scores = None attn_scores = None
mid_block_shape = None mid_block_shape = None
m.model.get_attn_scores = lambda: attn_scores
m.model.get_mid_block_shape = lambda: mid_block_shape
# TODO: make this work properly with chunked batches # TODO: make this work properly with chunked batches
# currently, we can only save the attn from one UNet call # currently, we can only save the attn from one UNet call
...@@ -92,24 +132,41 @@ class SagNode: ...@@ -92,24 +132,41 @@ class SagNode:
else: else:
return optimized_attention(q, k, v, heads=heads) return optimized_attention(q, k, v, heads=heads)
def post_cfg_function(args):
nonlocal attn_scores
nonlocal mid_block_shape
uncond_attn = attn_scores
sag_scale = scale
sag_sigma = blur_sigma
sag_threshold = 1.0
model = args["model"]
uncond_pred = args["uncond_denoised"]
uncond = args["uncond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
model_options = args["model_options"]
x = args["input"]
# create the adversarially blurred image
degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
degraded_noised = degraded + x - uncond_pred
# call into the UNet
(sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
return cfg_result + (degraded - sag) * sag_scale
m.set_model_sampler_post_cfg_function(post_cfg_function)
# from diffusers: # from diffusers:
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
def set_model_patch_replace(patch, name, key): m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
to = m.model_options["transformer_options"]
if "patches_replace" not in to:
to["patches_replace"] = {}
if name not in to["patches_replace"]:
to["patches_replace"][name] = {}
to["patches_replace"][name][key] = patch
set_model_patch_replace(attn_and_record, "attn1", ("middle", 0, 0))
# from diffusers:
# unet.mid_block.attentions[0].register_forward_hook()
def forward_hook(m, inp, out):
nonlocal mid_block_shape
mid_block_shape = out[0].shape[-2:]
m.model.diffusion_model.middle_block[0].register_forward_hook(forward_hook)
return (m, ) return (m, )
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"Self-Attention Guidance": SagNode, "SelfAttentionGuidance": SelfAttentionGuidance,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SelfAttentionGuidance": "Self-Attention Guidance",
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment