ComfyUI commit ba04a87d
authored Dec 13, 2023 by comfyanonymous

Refactor and improve the sag node.

Moved all the sag related code to comfy_extras/nodes_sag.py

parent 6761233e

Showing 3 changed files with 331 additions and 318 deletions:

    comfy/model_patcher.py      +13  -6
    comfy/samplers.py           +238 -289
    comfy_extras/nodes_sag.py   +80  -23
comfy/model_patcher.py

@@ -61,6 +61,9 @@ class ModelPatcher:
         else:
             self.model_options["sampler_cfg_function"] = sampler_cfg_function
 
+    def set_model_sampler_post_cfg_function(self, post_cfg_function):
+        self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
+
     def set_model_unet_function_wrapper(self, unet_wrapper_function):
         self.model_options["model_function_wrapper"] = unet_wrapper_function

@@ -70,13 +73,17 @@ class ModelPatcher:
             to["patches"] = {}
         to["patches"][name] = to["patches"].get(name, []) + [patch]
 
-    def set_model_patch_replace(self, patch, name, block_name, number):
+    def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
         to = self.model_options["transformer_options"]
         if "patches_replace" not in to:
             to["patches_replace"] = {}
         if name not in to["patches_replace"]:
             to["patches_replace"][name] = {}
-        to["patches_replace"][name][(block_name, number)] = patch
+        if transformer_index is not None:
+            block = (block_name, number, transformer_index)
+        else:
+            block = (block_name, number)
+        to["patches_replace"][name][block] = patch
 
     def set_model_attn1_patch(self, patch):
         self.set_model_patch(patch, "attn1_patch")

@@ -84,11 +91,11 @@ class ModelPatcher:
     def set_model_attn2_patch(self, patch):
         self.set_model_patch(patch, "attn2_patch")
 
-    def set_model_attn1_replace(self, patch, block_name, number):
-        self.set_model_patch_replace(patch, "attn1", block_name, number)
+    def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
+        self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
 
-    def set_model_attn2_replace(self, patch, block_name, number):
-        self.set_model_patch_replace(patch, "attn2", block_name, number)
+    def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
+        self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
 
     def set_model_attn1_output_patch(self, patch):
         self.set_model_patch(patch, "attn1_output_patch")
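Because the new hook appends to a list rather than overwriting a single key, several patches can register post-CFG callbacks side by side. A minimal sketch of how a custom node might use it, assuming `model` is a loaded ComfyUI MODEL object; the helper name and callback body are illustrative, and the argument keys are the ones read by the SAG node further down:

def add_my_post_cfg(model):
    # `model` is assumed to be a loaded ComfyUI MODEL (a ModelPatcher instance).
    m = model.clone()

    def my_post_cfg(args):
        # Keys taken from the SAG node's post_cfg_function below.
        denoised = args["denoised"]                # combined CFG result
        uncond_denoised = args["uncond_denoised"]  # unconditional prediction
        # Return the (possibly adjusted) denoised latent for this step.
        return denoised

    # Appends to model_options["sampler_post_cfg_function"], so callbacks
    # registered by other patches are preserved.
    m.set_model_sampler_post_cfg_function(my_post_cfg)
    return (m,)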
comfy/samplers.py
(diff collapsed in this view, contents not shown)
comfy_extras/nodes_sag.py
import torch
from torch import einsum
import torch.nn.functional as F
import math

from einops import rearrange, repeat
import os
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
import comfy.samplers

# from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output
@@ -49,7 +53,49 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
     )
     return (out, sim)
 
-class SagNode:
+def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
+    # reshape and GAP the attention map
+    _, hw1, hw2 = attn.shape
+    b, _, lh, lw = x0.shape
+    attn = attn.reshape(b, -1, hw1, hw2)
+    # Global Average Pool
+    mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
+    ratio = round(math.sqrt(lh * lw / hw1))
+    mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
+
+    # Reshape
+    mask = (
+        mask.reshape(b, *mid_shape)
+        .unsqueeze(1)
+        .type(attn.dtype)
+    )
+    # Upsample
+    mask = F.interpolate(mask, (lh, lw))
+
+    blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
+    blurred = blurred * mask + x0 * (1 - mask)
+    return blurred
+
+def gaussian_blur_2d(img, kernel_size, sigma):
+    ksize_half = (kernel_size - 1) * 0.5
+
+    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
+
+    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+
+    x_kernel = pdf / pdf.sum()
+    x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
+
+    kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
+    kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
+
+    padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
+
+    img = F.pad(img, padding, mode="reflect")
+    img = F.conv2d(img, kernel2d, groups=img.shape[-3])
+    return img
+
+class SelfAttentionGuidance:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "model": ("MODEL",),
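As a quick sanity check of the tensor shapes involved, here is a small sketch using the two functions defined above. The sizes are illustrative assumptions (a 64x64 latent whose mid-block self-attention map has 8x8 = 64 tokens), not values from this commit:

import torch

# Assumes create_blur_map and gaussian_blur_2d from the hunk above are in scope.
b, c, lh, lw = 2, 4, 64, 64                    # hypothetical latent batch
heads = 8
tokens = (lh // 8) * (lw // 8)                 # 8x8 mid-block attention map -> 64 tokens

x0 = torch.randn(b, c, lh, lw)                 # stand-in for the uncond prediction
attn = torch.randn(b * heads, tokens, tokens)  # stand-in for recorded attention scores

blurred = create_blur_map(x0, attn, sigma=3.0, threshold=1.0)
print(blurred.shape)                           # torch.Size([2, 4, 64, 64])

The blur is applied as a depthwise convolution (groups equal to the channel count), and the thresholded, upsampled attention mask decides where the blurred prediction replaces the original one.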
@@ -63,15 +109,9 @@ class SagNode:
     def patch(self, model, scale, blur_sigma):
         m = model.clone()
-        # set extra options on the model
-        m.model_options["sag"] = True
-        m.model_options["sag_scale"] = scale
-        m.model_options["sag_sigma"] = blur_sigma
 
         attn_scores = None
         mid_block_shape = None
-        m.model.get_attn_scores = lambda: attn_scores
-        m.model.get_mid_block_shape = lambda: mid_block_shape
 
         # TODO: make this work properly with chunked batches
         #  currently, we can only save the attn from one UNet call
@@ -92,24 +132,41 @@ class SagNode:
             else:
                 return optimized_attention(q, k, v, heads=heads)
 
+        def post_cfg_function(args):
+            nonlocal attn_scores
+            nonlocal mid_block_shape
+            uncond_attn = attn_scores
+
+            sag_scale = scale
+            sag_sigma = blur_sigma
+            sag_threshold = 1.0
+            model = args["model"]
+            uncond_pred = args["uncond_denoised"]
+            uncond = args["uncond"]
+            cfg_result = args["denoised"]
+            sigma = args["sigma"]
+            model_options = args["model_options"]
+            x = args["input"]
+
+            # create the adversarially blurred image
+            degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
+            degraded_noised = degraded + x - uncond_pred
+            # call into the UNet
+            (sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
+            return cfg_result + (degraded - sag) * sag_scale
+
+        m.set_model_sampler_post_cfg_function(post_cfg_function)
+
         # from diffusers:
         # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
-        def set_model_patch_replace(patch, name, key):
-            to = m.model_options["transformer_options"]
-            if "patches_replace" not in to:
-                to["patches_replace"] = {}
-            if name not in to["patches_replace"]:
-                to["patches_replace"][name] = {}
-            to["patches_replace"][name][key] = patch
-        set_model_patch_replace(attn_and_record, "attn1", ("middle", 0, 0))
-
-        # from diffusers:
-        # unet.mid_block.attentions[0].register_forward_hook()
-        def forward_hook(m, inp, out):
-            nonlocal mid_block_shape
-            mid_block_shape = out[0].shape[-2:]
-        m.model.diffusion_model.middle_block[0].register_forward_hook(forward_hook)
+        m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
 
         return (m, )
 
 NODE_CLASS_MAPPINGS = {
-    "Self-Attention Guidance": SagNode,
+    "SelfAttentionGuidance": SelfAttentionGuidance,
 }
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "SelfAttentionGuidance": "Self-Attention Guidance",
+}
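Putting the pieces together, a hedged usage sketch of the renamed node. It assumes SelfAttentionGuidance (defined above) is in scope and `model` is an already-loaded ComfyUI MODEL object; the scale and blur_sigma values are illustrative, not the node's documented defaults:

# Assumes SelfAttentionGuidance from the file above is importable/in scope and
# `model` is a loaded ComfyUI MODEL object; 0.5 / 2.0 are illustrative values.
sag_node = SelfAttentionGuidance()
(patched_model,) = sag_node.patch(model, scale=0.5, blur_sigma=2.0)

# During sampling, patched_model now:
#  1. records the unconditional self-attention of the middle block via the
#     attn1 replace registered with set_model_attn1_replace, and
#  2. in the post-CFG callback, blurs the high-attention regions of the
#     uncond prediction, re-runs the UNet on that degraded latent, and adds
#     (degraded - sag) * scale on top of the regular CFG result.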