Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
ba04a87d
Commit
ba04a87d
authored
Dec 13, 2023
by
comfyanonymous
Browse files
Refactor and improve the sag node.
Moved all the sag related code to comfy_extras/nodes_sag.py
parent
6761233e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
331 additions
and
318 deletions
+331
-318
comfy/model_patcher.py
comfy/model_patcher.py
+13
-6
comfy/samplers.py
comfy/samplers.py
+238
-289
comfy_extras/nodes_sag.py
comfy_extras/nodes_sag.py
+80
-23
No files found.
comfy/model_patcher.py
View file @
ba04a87d
...
@@ -61,6 +61,9 @@ class ModelPatcher:
...
@@ -61,6 +61,9 @@ class ModelPatcher:
else
:
else
:
self
.
model_options
[
"sampler_cfg_function"
]
=
sampler_cfg_function
self
.
model_options
[
"sampler_cfg_function"
]
=
sampler_cfg_function
def
set_model_sampler_post_cfg_function
(
self
,
post_cfg_function
):
self
.
model_options
[
"sampler_post_cfg_function"
]
=
self
.
model_options
.
get
(
"sampler_post_cfg_function"
,
[])
+
[
post_cfg_function
]
def
set_model_unet_function_wrapper
(
self
,
unet_wrapper_function
):
def
set_model_unet_function_wrapper
(
self
,
unet_wrapper_function
):
self
.
model_options
[
"model_function_wrapper"
]
=
unet_wrapper_function
self
.
model_options
[
"model_function_wrapper"
]
=
unet_wrapper_function
...
@@ -70,13 +73,17 @@ class ModelPatcher:
...
@@ -70,13 +73,17 @@ class ModelPatcher:
to
[
"patches"
]
=
{}
to
[
"patches"
]
=
{}
to
[
"patches"
][
name
]
=
to
[
"patches"
].
get
(
name
,
[])
+
[
patch
]
to
[
"patches"
][
name
]
=
to
[
"patches"
].
get
(
name
,
[])
+
[
patch
]
def
set_model_patch_replace
(
self
,
patch
,
name
,
block_name
,
number
):
def
set_model_patch_replace
(
self
,
patch
,
name
,
block_name
,
number
,
transformer_index
=
None
):
to
=
self
.
model_options
[
"transformer_options"
]
to
=
self
.
model_options
[
"transformer_options"
]
if
"patches_replace"
not
in
to
:
if
"patches_replace"
not
in
to
:
to
[
"patches_replace"
]
=
{}
to
[
"patches_replace"
]
=
{}
if
name
not
in
to
[
"patches_replace"
]:
if
name
not
in
to
[
"patches_replace"
]:
to
[
"patches_replace"
][
name
]
=
{}
to
[
"patches_replace"
][
name
]
=
{}
to
[
"patches_replace"
][
name
][(
block_name
,
number
)]
=
patch
if
transformer_index
is
not
None
:
block
=
(
block_name
,
number
,
transformer_index
)
else
:
block
=
(
block_name
,
number
)
to
[
"patches_replace"
][
name
][
block
]
=
patch
def
set_model_attn1_patch
(
self
,
patch
):
def
set_model_attn1_patch
(
self
,
patch
):
self
.
set_model_patch
(
patch
,
"attn1_patch"
)
self
.
set_model_patch
(
patch
,
"attn1_patch"
)
...
@@ -84,11 +91,11 @@ class ModelPatcher:
...
@@ -84,11 +91,11 @@ class ModelPatcher:
def
set_model_attn2_patch
(
self
,
patch
):
def
set_model_attn2_patch
(
self
,
patch
):
self
.
set_model_patch
(
patch
,
"attn2_patch"
)
self
.
set_model_patch
(
patch
,
"attn2_patch"
)
def
set_model_attn1_replace
(
self
,
patch
,
block_name
,
number
):
def
set_model_attn1_replace
(
self
,
patch
,
block_name
,
number
,
transformer_index
=
None
):
self
.
set_model_patch_replace
(
patch
,
"attn1"
,
block_name
,
number
)
self
.
set_model_patch_replace
(
patch
,
"attn1"
,
block_name
,
number
,
transformer_index
)
def
set_model_attn2_replace
(
self
,
patch
,
block_name
,
number
):
def
set_model_attn2_replace
(
self
,
patch
,
block_name
,
number
,
transformer_index
=
None
):
self
.
set_model_patch_replace
(
patch
,
"attn2"
,
block_name
,
number
)
self
.
set_model_patch_replace
(
patch
,
"attn2"
,
block_name
,
number
,
transformer_index
)
def
set_model_attn1_output_patch
(
self
,
patch
):
def
set_model_attn1_output_patch
(
self
,
patch
):
self
.
set_model_patch
(
patch
,
"attn1_output_patch"
)
self
.
set_model_patch
(
patch
,
"attn1_output_patch"
)
...
...
comfy/samplers.py
View file @
ba04a87d
from
.k_diffusion
import
sampling
as
k_diffusion_sampling
from
.k_diffusion
import
sampling
as
k_diffusion_sampling
from
.extra_samplers
import
uni_pc
from
.extra_samplers
import
uni_pc
import
torch
import
torch
import
torch.nn.functional
as
F
import
enum
import
enum
from
comfy
import
model_management
from
comfy
import
model_management
import
math
import
math
...
@@ -9,309 +8,259 @@ from comfy import model_base
...
@@ -9,309 +8,259 @@ from comfy import model_base
import
comfy.utils
import
comfy.utils
import
comfy.conds
import
comfy.conds
def
get_area_and_mult
(
conds
,
x_in
,
timestep_in
):
area
=
(
x_in
.
shape
[
2
],
x_in
.
shape
[
3
],
0
,
0
)
strength
=
1.0
if
'timestep_start'
in
conds
:
timestep_start
=
conds
[
'timestep_start'
]
if
timestep_in
[
0
]
>
timestep_start
:
return
None
if
'timestep_end'
in
conds
:
timestep_end
=
conds
[
'timestep_end'
]
if
timestep_in
[
0
]
<
timestep_end
:
return
None
if
'area'
in
conds
:
area
=
conds
[
'area'
]
if
'strength'
in
conds
:
strength
=
conds
[
'strength'
]
input_x
=
x_in
[:,:,
area
[
2
]:
area
[
0
]
+
area
[
2
],
area
[
3
]:
area
[
1
]
+
area
[
3
]]
if
'mask'
in
conds
:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask_strength
=
1.0
if
"mask_strength"
in
conds
:
mask_strength
=
conds
[
"mask_strength"
]
mask
=
conds
[
'mask'
]
assert
(
mask
.
shape
[
1
]
==
x_in
.
shape
[
2
])
assert
(
mask
.
shape
[
2
]
==
x_in
.
shape
[
3
])
mask
=
mask
[:,
area
[
2
]:
area
[
0
]
+
area
[
2
],
area
[
3
]:
area
[
1
]
+
area
[
3
]]
*
mask_strength
mask
=
mask
.
unsqueeze
(
1
).
repeat
(
input_x
.
shape
[
0
]
//
mask
.
shape
[
0
],
input_x
.
shape
[
1
],
1
,
1
)
else
:
mask
=
torch
.
ones_like
(
input_x
)
mult
=
mask
*
strength
if
'mask'
not
in
conds
:
rr
=
8
if
area
[
2
]
!=
0
:
for
t
in
range
(
rr
):
mult
[:,:,
t
:
1
+
t
,:]
*=
((
1.0
/
rr
)
*
(
t
+
1
))
if
(
area
[
0
]
+
area
[
2
])
<
x_in
.
shape
[
2
]:
for
t
in
range
(
rr
):
mult
[:,:,
area
[
0
]
-
1
-
t
:
area
[
0
]
-
t
,:]
*=
((
1.0
/
rr
)
*
(
t
+
1
))
if
area
[
3
]
!=
0
:
for
t
in
range
(
rr
):
mult
[:,:,:,
t
:
1
+
t
]
*=
((
1.0
/
rr
)
*
(
t
+
1
))
if
(
area
[
1
]
+
area
[
3
])
<
x_in
.
shape
[
3
]:
for
t
in
range
(
rr
):
mult
[:,:,:,
area
[
1
]
-
1
-
t
:
area
[
1
]
-
t
]
*=
((
1.0
/
rr
)
*
(
t
+
1
))
conditioning
=
{}
model_conds
=
conds
[
"model_conds"
]
for
c
in
model_conds
:
conditioning
[
c
]
=
model_conds
[
c
].
process_cond
(
batch_size
=
x_in
.
shape
[
0
],
device
=
x_in
.
device
,
area
=
area
)
control
=
None
if
'control'
in
conds
:
control
=
conds
[
'control'
]
patches
=
None
if
'gligen'
in
conds
:
gligen
=
conds
[
'gligen'
]
patches
=
{}
gligen_type
=
gligen
[
0
]
gligen_model
=
gligen
[
1
]
if
gligen_type
==
"position"
:
gligen_patch
=
gligen_model
.
model
.
set_position
(
input_x
.
shape
,
gligen
[
2
],
input_x
.
device
)
else
:
gligen_patch
=
gligen_model
.
model
.
set_empty
(
input_x
.
shape
,
input_x
.
device
)
patches
[
'middle_patch'
]
=
[
gligen_patch
]
return
(
input_x
,
mult
,
conditioning
,
area
,
control
,
patches
)
def
cond_equal_size
(
c1
,
c2
):
if
c1
is
c2
:
return
True
if
c1
.
keys
()
!=
c2
.
keys
():
return
False
for
k
in
c1
:
if
not
c1
[
k
].
can_concat
(
c2
[
k
]):
return
False
return
True
def
can_concat_cond
(
c1
,
c2
):
if
c1
[
0
].
shape
!=
c2
[
0
].
shape
:
return
False
#control
if
(
c1
[
4
]
is
None
)
!=
(
c2
[
4
]
is
None
):
return
False
if
c1
[
4
]
is
not
None
:
if
c1
[
4
]
is
not
c2
[
4
]:
return
False
#patches
if
(
c1
[
5
]
is
None
)
!=
(
c2
[
5
]
is
None
):
return
False
if
(
c1
[
5
]
is
not
None
):
if
c1
[
5
]
is
not
c2
[
5
]:
return
False
return
cond_equal_size
(
c1
[
2
],
c2
[
2
])
def
cond_cat
(
c_list
):
c_crossattn
=
[]
c_concat
=
[]
c_adm
=
[]
crossattn_max_len
=
0
temp
=
{}
for
x
in
c_list
:
for
k
in
x
:
cur
=
temp
.
get
(
k
,
[])
cur
.
append
(
x
[
k
])
temp
[
k
]
=
cur
out
=
{}
for
k
in
temp
:
conds
=
temp
[
k
]
out
[
k
]
=
conds
[
0
].
concat
(
conds
[
1
:])
return
out
def
calc_cond_uncond_batch
(
model
,
cond
,
uncond
,
x_in
,
timestep
,
model_options
):
out_cond
=
torch
.
zeros_like
(
x_in
)
out_count
=
torch
.
ones_like
(
x_in
)
*
1e-37
out_uncond
=
torch
.
zeros_like
(
x_in
)
out_uncond_count
=
torch
.
ones_like
(
x_in
)
*
1e-37
COND
=
0
UNCOND
=
1
to_run
=
[]
for
x
in
cond
:
p
=
get_area_and_mult
(
x
,
x_in
,
timestep
)
if
p
is
None
:
continue
#The main sampling function shared by all the samplers
to_run
+=
[(
p
,
COND
)]
#Returns denoised
if
uncond
is
not
None
:
def
sampling_function
(
model
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
model_options
=
{},
seed
=
None
):
for
x
in
uncond
:
def
get_area_and_mult
(
conds
,
x_in
,
timestep_in
):
p
=
get_area_and_mult
(
x
,
x_in
,
timestep
)
area
=
(
x_in
.
shape
[
2
],
x_in
.
shape
[
3
],
0
,
0
)
if
p
is
None
:
strength
=
1.0
continue
if
'timestep_start'
in
conds
:
to_run
+=
[(
p
,
UNCOND
)]
timestep_start
=
conds
[
'timestep_start'
]
if
timestep_in
[
0
]
>
timestep_start
:
while
len
(
to_run
)
>
0
:
return
None
first
=
to_run
[
0
]
if
'timestep_end'
in
conds
:
first_shape
=
first
[
0
][
0
].
shape
timestep_end
=
conds
[
'timestep_end'
]
to_batch_temp
=
[]
if
timestep_in
[
0
]
<
timestep_end
:
for
x
in
range
(
len
(
to_run
)):
return
None
if
can_concat_cond
(
to_run
[
x
][
0
],
first
[
0
]):
if
'area'
in
conds
:
to_batch_temp
+=
[
x
]
area
=
conds
[
'area'
]
if
'strength'
in
conds
:
to_batch_temp
.
reverse
()
strength
=
conds
[
'strength'
]
to_batch
=
to_batch_temp
[:
1
]
input_x
=
x_in
[:,:,
area
[
2
]:
area
[
0
]
+
area
[
2
],
area
[
3
]:
area
[
1
]
+
area
[
3
]]
free_memory
=
model_management
.
get_free_memory
(
x_in
.
device
)
if
'mask'
in
conds
:
for
i
in
range
(
1
,
len
(
to_batch_temp
)
+
1
):
# Scale the mask to the size of the input
batch_amount
=
to_batch_temp
[:
len
(
to_batch_temp
)
//
i
]
# The mask should have been resized as we began the sampling process
input_shape
=
[
len
(
batch_amount
)
*
first_shape
[
0
]]
+
list
(
first_shape
)[
1
:]
mask_strength
=
1.0
if
model
.
memory_required
(
input_shape
)
<
free_memory
:
if
"mask_strength"
in
conds
:
to_batch
=
batch_amount
mask_strength
=
conds
[
"mask_strength"
]
break
mask
=
conds
[
'mask'
]
assert
(
mask
.
shape
[
1
]
==
x_in
.
shape
[
2
])
input_x
=
[]
assert
(
mask
.
shape
[
2
]
==
x_in
.
shape
[
3
])
mult
=
[]
mask
=
mask
[:,
area
[
2
]:
area
[
0
]
+
area
[
2
],
area
[
3
]:
area
[
1
]
+
area
[
3
]]
*
mask_strength
c
=
[]
mask
=
mask
.
unsqueeze
(
1
).
repeat
(
input_x
.
shape
[
0
]
//
mask
.
shape
[
0
],
input_x
.
shape
[
1
],
1
,
1
)
cond_or_uncond
=
[]
else
:
area
=
[]
mask
=
torch
.
ones_like
(
input_x
)
control
=
None
mult
=
mask
*
strength
patches
=
None
for
x
in
to_batch
:
if
'mask'
not
in
conds
:
o
=
to_run
.
pop
(
x
)
rr
=
8
p
=
o
[
0
]
if
area
[
2
]
!=
0
:
input_x
+=
[
p
[
0
]]
for
t
in
range
(
rr
):
mult
+=
[
p
[
1
]]
mult
[:,:,
t
:
1
+
t
,:]
*=
((
1.0
/
rr
)
*
(
t
+
1
))
c
+=
[
p
[
2
]]
if
(
area
[
0
]
+
area
[
2
])
<
x_in
.
shape
[
2
]:
area
+=
[
p
[
3
]]
for
t
in
range
(
rr
):
cond_or_uncond
+=
[
o
[
1
]]
mult
[:,:,
area
[
0
]
-
1
-
t
:
area
[
0
]
-
t
,:]
*=
((
1.0
/
rr
)
*
(
t
+
1
))
control
=
p
[
4
]
if
area
[
3
]
!=
0
:
patches
=
p
[
5
]
for
t
in
range
(
rr
):
mult
[:,:,:,
t
:
1
+
t
]
*=
((
1.0
/
rr
)
*
(
t
+
1
))
batch_chunks
=
len
(
cond_or_uncond
)
if
(
area
[
1
]
+
area
[
3
])
<
x_in
.
shape
[
3
]:
input_x
=
torch
.
cat
(
input_x
)
for
t
in
range
(
rr
):
c
=
cond_cat
(
c
)
mult
[:,:,:,
area
[
1
]
-
1
-
t
:
area
[
1
]
-
t
]
*=
((
1.0
/
rr
)
*
(
t
+
1
))
timestep_
=
torch
.
cat
([
timestep
]
*
batch_chunks
)
conditioning
=
{}
if
control
is
not
None
:
model_conds
=
conds
[
"model_conds"
]
c
[
'control'
]
=
control
.
get_control
(
input_x
,
timestep_
,
c
,
len
(
cond_or_uncond
))
for
c
in
model_conds
:
conditioning
[
c
]
=
model_conds
[
c
].
process_cond
(
batch_size
=
x_in
.
shape
[
0
],
device
=
x_in
.
device
,
area
=
area
)
transformer_options
=
{}
if
'transformer_options'
in
model_options
:
control
=
None
transformer_options
=
model_options
[
'transformer_options'
].
copy
()
if
'control'
in
conds
:
control
=
conds
[
'control'
]
if
patches
is
not
None
:
if
"patches"
in
transformer_options
:
patches
=
None
cur_patches
=
transformer_options
[
"patches"
].
copy
()
if
'gligen'
in
conds
:
for
p
in
patches
:
gligen
=
conds
[
'gligen'
]
if
p
in
cur_patches
:
patches
=
{}
cur_patches
[
p
]
=
cur_patches
[
p
]
+
patches
[
p
]
gligen_type
=
gligen
[
0
]
gligen_model
=
gligen
[
1
]
if
gligen_type
==
"position"
:
gligen_patch
=
gligen_model
.
model
.
set_position
(
input_x
.
shape
,
gligen
[
2
],
input_x
.
device
)
else
:
gligen_patch
=
gligen_model
.
model
.
set_empty
(
input_x
.
shape
,
input_x
.
device
)
patches
[
'middle_patch'
]
=
[
gligen_patch
]
return
(
input_x
,
mult
,
conditioning
,
area
,
control
,
patches
)
def
cond_equal_size
(
c1
,
c2
):
if
c1
is
c2
:
return
True
if
c1
.
keys
()
!=
c2
.
keys
():
return
False
for
k
in
c1
:
if
not
c1
[
k
].
can_concat
(
c2
[
k
]):
return
False
return
True
def
can_concat_cond
(
c1
,
c2
):
if
c1
[
0
].
shape
!=
c2
[
0
].
shape
:
return
False
#control
if
(
c1
[
4
]
is
None
)
!=
(
c2
[
4
]
is
None
):
return
False
if
c1
[
4
]
is
not
None
:
if
c1
[
4
]
is
not
c2
[
4
]:
return
False
#patches
if
(
c1
[
5
]
is
None
)
!=
(
c2
[
5
]
is
None
):
return
False
if
(
c1
[
5
]
is
not
None
):
if
c1
[
5
]
is
not
c2
[
5
]:
return
False
return
cond_equal_size
(
c1
[
2
],
c2
[
2
])
def
cond_cat
(
c_list
):
c_crossattn
=
[]
c_concat
=
[]
c_adm
=
[]
crossattn_max_len
=
0
temp
=
{}
for
x
in
c_list
:
for
k
in
x
:
cur
=
temp
.
get
(
k
,
[])
cur
.
append
(
x
[
k
])
temp
[
k
]
=
cur
out
=
{}
for
k
in
temp
:
conds
=
temp
[
k
]
out
[
k
]
=
conds
[
0
].
concat
(
conds
[
1
:])
return
out
def
calc_cond_uncond_batch
(
model
,
cond
,
uncond
,
x_in
,
timestep
,
model_options
):
out_cond
=
torch
.
zeros_like
(
x_in
)
out_count
=
torch
.
ones_like
(
x_in
)
*
1e-37
out_uncond
=
torch
.
zeros_like
(
x_in
)
out_uncond_count
=
torch
.
ones_like
(
x_in
)
*
1e-37
COND
=
0
UNCOND
=
1
to_run
=
[]
for
x
in
cond
:
p
=
get_area_and_mult
(
x
,
x_in
,
timestep
)
if
p
is
None
:
continue
to_run
+=
[(
p
,
COND
)]
if
uncond
is
not
None
:
for
x
in
uncond
:
p
=
get_area_and_mult
(
x
,
x_in
,
timestep
)
if
p
is
None
:
continue
to_run
+=
[(
p
,
UNCOND
)]
while
len
(
to_run
)
>
0
:
first
=
to_run
[
0
]
first_shape
=
first
[
0
][
0
].
shape
to_batch_temp
=
[]
for
x
in
range
(
len
(
to_run
)):
if
can_concat_cond
(
to_run
[
x
][
0
],
first
[
0
]):
to_batch_temp
+=
[
x
]
to_batch_temp
.
reverse
()
to_batch
=
to_batch_temp
[:
1
]
free_memory
=
model_management
.
get_free_memory
(
x_in
.
device
)
for
i
in
range
(
1
,
len
(
to_batch_temp
)
+
1
):
batch_amount
=
to_batch_temp
[:
len
(
to_batch_temp
)
//
i
]
input_shape
=
[
len
(
batch_amount
)
*
first_shape
[
0
]]
+
list
(
first_shape
)[
1
:]
if
model
.
memory_required
(
input_shape
)
<
free_memory
:
to_batch
=
batch_amount
break
input_x
=
[]
mult
=
[]
c
=
[]
cond_or_uncond
=
[]
area
=
[]
control
=
None
patches
=
None
for
x
in
to_batch
:
o
=
to_run
.
pop
(
x
)
p
=
o
[
0
]
input_x
+=
[
p
[
0
]]
mult
+=
[
p
[
1
]]
c
+=
[
p
[
2
]]
area
+=
[
p
[
3
]]
cond_or_uncond
+=
[
o
[
1
]]
control
=
p
[
4
]
patches
=
p
[
5
]
batch_chunks
=
len
(
cond_or_uncond
)
input_x
=
torch
.
cat
(
input_x
)
c
=
cond_cat
(
c
)
timestep_
=
torch
.
cat
([
timestep
]
*
batch_chunks
)
if
control
is
not
None
:
c
[
'control'
]
=
control
.
get_control
(
input_x
,
timestep_
,
c
,
len
(
cond_or_uncond
))
transformer_options
=
{}
if
'transformer_options'
in
model_options
:
transformer_options
=
model_options
[
'transformer_options'
].
copy
()
if
patches
is
not
None
:
if
"patches"
in
transformer_options
:
cur_patches
=
transformer_options
[
"patches"
].
copy
()
for
p
in
patches
:
if
p
in
cur_patches
:
cur_patches
[
p
]
=
cur_patches
[
p
]
+
patches
[
p
]
else
:
cur_patches
[
p
]
=
patches
[
p
]
else
:
else
:
transformer_options
[
"patches"
]
=
patches
cur_patches
[
p
]
=
patches
[
p
]
else
:
transformer_options
[
"cond_or_uncond"
]
=
cond_or_uncond
[:]
transformer_options
[
"patches"
]
=
patches
transformer_options
[
"sigmas"
]
=
timestep
c
[
'transformer_options'
]
=
transformer_options
transformer_options
[
"cond_or_uncond"
]
=
cond_or_uncond
[:]
transformer_options
[
"sigmas"
]
=
timestep
if
'model_function_wrapper'
in
model_options
:
c
[
'transformer_options'
]
=
transformer_options
output
=
model_options
[
'model_function_wrapper'
](
model
.
apply_model
,
{
"input"
:
input_x
,
"timestep"
:
timestep_
,
"c"
:
c
,
"cond_or_uncond"
:
cond_or_uncond
}).
chunk
(
batch_chunks
)
else
:
output
=
model
.
apply_model
(
input_x
,
timestep_
,
**
c
).
chunk
(
batch_chunks
)
del
input_x
for
o
in
range
(
batch_chunks
):
if
'model_function_wrapper'
in
model_options
:
if
cond_or_uncond
[
o
]
==
COND
:
output
=
model_options
[
'model_function_wrapper'
](
model
.
apply_model
,
{
"input"
:
input_x
,
"timestep"
:
timestep_
,
"c"
:
c
,
"cond_or_uncond"
:
cond_or_uncond
}).
chunk
(
batch_chunks
)
out_cond
[:,:,
area
[
o
][
2
]:
area
[
o
][
0
]
+
area
[
o
][
2
],
area
[
o
][
3
]:
area
[
o
][
1
]
+
area
[
o
][
3
]]
+=
output
[
o
]
*
mult
[
o
]
else
:
out_count
[:,:,
area
[
o
][
2
]:
area
[
o
][
0
]
+
area
[
o
][
2
],
area
[
o
][
3
]:
area
[
o
][
1
]
+
area
[
o
][
3
]]
+=
mult
[
o
]
output
=
model
.
apply_model
(
input_x
,
timestep_
,
**
c
).
chunk
(
batch_chunks
)
else
:
del
input_x
out_uncond
[:,:,
area
[
o
][
2
]:
area
[
o
][
0
]
+
area
[
o
][
2
],
area
[
o
][
3
]:
area
[
o
][
1
]
+
area
[
o
][
3
]]
+=
output
[
o
]
*
mult
[
o
]
out_uncond_count
[:,:,
area
[
o
][
2
]:
area
[
o
][
0
]
+
area
[
o
][
2
],
area
[
o
][
3
]:
area
[
o
][
1
]
+
area
[
o
][
3
]]
+=
mult
[
o
]
del
mult
out_cond
/=
out_count
for
o
in
range
(
batch_chunks
):
del
out_count
if
cond_or_uncond
[
o
]
==
COND
:
out_uncond
/=
out_uncond_count
out_cond
[:,:,
area
[
o
][
2
]:
area
[
o
][
0
]
+
area
[
o
][
2
],
area
[
o
][
3
]:
area
[
o
][
1
]
+
area
[
o
][
3
]]
+=
output
[
o
]
*
mult
[
o
]
del
out_uncond_count
out_count
[:,:,
area
[
o
][
2
]:
area
[
o
][
0
]
+
area
[
o
][
2
],
area
[
o
][
3
]:
area
[
o
][
1
]
+
area
[
o
][
3
]]
+=
mult
[
o
]
return
out_cond
,
out_uncond
else
:
out_uncond
[:,:,
area
[
o
][
2
]:
area
[
o
][
0
]
+
area
[
o
][
2
],
area
[
o
][
3
]:
area
[
o
][
1
]
+
area
[
o
][
3
]]
+=
output
[
o
]
*
mult
[
o
]
out_uncond_count
[:,:,
area
[
o
][
2
]:
area
[
o
][
0
]
+
area
[
o
][
2
],
area
[
o
][
3
]:
area
[
o
][
1
]
+
area
[
o
][
3
]]
+=
mult
[
o
]
del
mult
out_cond
/=
out_count
del
out_count
out_uncond
/=
out_uncond_count
del
out_uncond_count
return
out_cond
,
out_uncond
# if we're doing SAG, we still need to do uncond guidance, even though the cond and uncond will cancel out.
#The main sampling function shared by all the samplers
if
math
.
isclose
(
cond_scale
,
1.0
)
and
"sag"
not
in
model_options
:
#Returns denoised
uncond
=
None
def
sampling_function
(
model
,
x
,
timestep
,
uncond
,
cond
,
cond_scale
,
model_options
=
{},
seed
=
None
):
if
math
.
isclose
(
cond_scale
,
1.0
):
uncond_
=
None
else
:
uncond_
=
uncond
cond_pred
,
uncond_pred
=
calc_cond_uncond_batch
(
model
,
cond
,
uncond
,
x
,
timestep
,
model_options
)
cond_pred
,
uncond_pred
=
calc_cond_uncond_batch
(
model
,
cond
,
uncond
_
,
x
,
timestep
,
model_options
)
cfg_result
=
uncond_pred
+
(
cond_pred
-
uncond_pred
)
*
cond_scale
cfg_result
=
uncond_pred
+
(
cond_pred
-
uncond_pred
)
*
cond_scale
if
"sampler_cfg_function"
in
model_options
:
if
"sampler_cfg_function"
in
model_options
:
args
=
{
"cond"
:
x
-
cond_pred
,
"uncond"
:
x
-
uncond_pred
,
"cond_scale"
:
cond_scale
,
"timestep"
:
timestep
,
"input"
:
x
,
"sigma"
:
timestep
}
args
=
{
"cond"
:
x
-
cond_pred
,
"uncond"
:
x
-
uncond_pred
,
"cond_scale"
:
cond_scale
,
"timestep"
:
timestep
,
"input"
:
x
,
"sigma"
:
timestep
}
cfg_result
=
x
-
model_options
[
"sampler_cfg_function"
](
args
)
cfg_result
=
x
-
model_options
[
"sampler_cfg_function"
](
args
)
if
"sag"
in
model_options
:
for
fn
in
model_options
.
get
(
"sampler_post_cfg_function"
,
[]):
assert
uncond
is
not
None
,
"SAG requires uncond guidance"
args
=
{
"denoised"
:
cfg_result
,
"cond"
:
cond
,
"uncond"
:
uncond
,
"model"
:
model
,
"uncond_denoised"
:
uncond_pred
,
"cond_denoised"
:
cond_pred
,
sag_scale
=
model_options
[
"sag_scale"
]
"sigma"
:
timestep
,
"model_options"
:
model_options
,
"input"
:
x
}
sag_sigma
=
model_options
[
"sag_sigma"
]
cfg_result
=
fn
(
args
)
sag_threshold
=
model_options
.
get
(
"sag_threshold"
,
1.0
)
# these methods are added by the sag patcher
uncond_attn
=
model
.
get_attn_scores
()
mid_shape
=
model
.
get_mid_block_shape
()
# create the adversarially blurred image
degraded
=
create_blur_map
(
uncond_pred
,
uncond_attn
,
mid_shape
,
sag_sigma
,
sag_threshold
)
degraded_noised
=
degraded
+
x
-
uncond_pred
# call into the UNet
(
sag
,
_
)
=
calc_cond_uncond_batch
(
model
,
uncond
,
None
,
degraded_noised
,
timestep
,
model_options
)
cfg_result
+=
(
degraded
-
sag
)
*
sag_scale
return
cfg_result
def
create_blur_map
(
x0
,
attn
,
mid_shape
,
sigma
=
3.0
,
threshold
=
1.0
):
# reshape and GAP the attention map
_
,
hw1
,
hw2
=
attn
.
shape
b
,
_
,
lh
,
lw
=
x0
.
shape
attn
=
attn
.
reshape
(
b
,
-
1
,
hw1
,
hw2
)
# Global Average Pool
mask
=
attn
.
mean
(
1
,
keepdim
=
False
).
sum
(
1
,
keepdim
=
False
)
>
threshold
# Reshape
mask
=
(
mask
.
reshape
(
b
,
*
mid_shape
)
.
unsqueeze
(
1
)
.
type
(
attn
.
dtype
)
)
# Upsample
mask
=
F
.
interpolate
(
mask
,
(
lh
,
lw
))
blurred
=
gaussian_blur_2d
(
x0
,
kernel_size
=
9
,
sigma
=
sigma
)
blurred
=
blurred
*
mask
+
x0
*
(
1
-
mask
)
return
blurred
def
gaussian_blur_2d
(
img
,
kernel_size
,
sigma
):
return
cfg_result
ksize_half
=
(
kernel_size
-
1
)
*
0.5
x
=
torch
.
linspace
(
-
ksize_half
,
ksize_half
,
steps
=
kernel_size
)
pdf
=
torch
.
exp
(
-
0.5
*
(
x
/
sigma
).
pow
(
2
))
x_kernel
=
pdf
/
pdf
.
sum
()
x_kernel
=
x_kernel
.
to
(
device
=
img
.
device
,
dtype
=
img
.
dtype
)
kernel2d
=
torch
.
mm
(
x_kernel
[:,
None
],
x_kernel
[
None
,
:])
kernel2d
=
kernel2d
.
expand
(
img
.
shape
[
-
3
],
1
,
kernel2d
.
shape
[
0
],
kernel2d
.
shape
[
1
])
padding
=
[
kernel_size
//
2
,
kernel_size
//
2
,
kernel_size
//
2
,
kernel_size
//
2
]
img
=
F
.
pad
(
img
,
padding
,
mode
=
"reflect"
)
img
=
F
.
conv2d
(
img
,
kernel2d
,
groups
=
img
.
shape
[
-
3
])
return
img
class
CFGNoisePredictor
(
torch
.
nn
.
Module
):
class
CFGNoisePredictor
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
model
):
def
__init__
(
self
,
model
):
...
...
comfy_extras/nodes_sag.py
View file @
ba04a87d
import
torch
import
torch
from
torch
import
einsum
from
torch
import
einsum
import
torch.nn.functional
as
F
import
math
from
einops
import
rearrange
,
repeat
from
einops
import
rearrange
,
repeat
import
os
import
os
from
comfy.ldm.modules.attention
import
optimized_attention
,
_ATTN_PRECISION
from
comfy.ldm.modules.attention
import
optimized_attention
,
_ATTN_PRECISION
import
comfy.samplers
# from comfy/ldm/modules/attention.py
# from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output
# but modified to return attention scores as well as output
...
@@ -49,7 +53,49 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
...
@@ -49,7 +53,49 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
)
)
return
(
out
,
sim
)
return
(
out
,
sim
)
class
SagNode
:
def
create_blur_map
(
x0
,
attn
,
sigma
=
3.0
,
threshold
=
1.0
):
# reshape and GAP the attention map
_
,
hw1
,
hw2
=
attn
.
shape
b
,
_
,
lh
,
lw
=
x0
.
shape
attn
=
attn
.
reshape
(
b
,
-
1
,
hw1
,
hw2
)
# Global Average Pool
mask
=
attn
.
mean
(
1
,
keepdim
=
False
).
sum
(
1
,
keepdim
=
False
)
>
threshold
ratio
=
round
(
math
.
sqrt
(
lh
*
lw
/
hw1
))
mid_shape
=
[
math
.
ceil
(
lh
/
ratio
),
math
.
ceil
(
lw
/
ratio
)]
# Reshape
mask
=
(
mask
.
reshape
(
b
,
*
mid_shape
)
.
unsqueeze
(
1
)
.
type
(
attn
.
dtype
)
)
# Upsample
mask
=
F
.
interpolate
(
mask
,
(
lh
,
lw
))
blurred
=
gaussian_blur_2d
(
x0
,
kernel_size
=
9
,
sigma
=
sigma
)
blurred
=
blurred
*
mask
+
x0
*
(
1
-
mask
)
return
blurred
def
gaussian_blur_2d
(
img
,
kernel_size
,
sigma
):
ksize_half
=
(
kernel_size
-
1
)
*
0.5
x
=
torch
.
linspace
(
-
ksize_half
,
ksize_half
,
steps
=
kernel_size
)
pdf
=
torch
.
exp
(
-
0.5
*
(
x
/
sigma
).
pow
(
2
))
x_kernel
=
pdf
/
pdf
.
sum
()
x_kernel
=
x_kernel
.
to
(
device
=
img
.
device
,
dtype
=
img
.
dtype
)
kernel2d
=
torch
.
mm
(
x_kernel
[:,
None
],
x_kernel
[
None
,
:])
kernel2d
=
kernel2d
.
expand
(
img
.
shape
[
-
3
],
1
,
kernel2d
.
shape
[
0
],
kernel2d
.
shape
[
1
])
padding
=
[
kernel_size
//
2
,
kernel_size
//
2
,
kernel_size
//
2
,
kernel_size
//
2
]
img
=
F
.
pad
(
img
,
padding
,
mode
=
"reflect"
)
img
=
F
.
conv2d
(
img
,
kernel2d
,
groups
=
img
.
shape
[
-
3
])
return
img
class
SelfAttentionGuidance
:
@
classmethod
@
classmethod
def
INPUT_TYPES
(
s
):
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"model"
:
(
"MODEL"
,),
return
{
"required"
:
{
"model"
:
(
"MODEL"
,),
...
@@ -63,15 +109,9 @@ class SagNode:
...
@@ -63,15 +109,9 @@ class SagNode:
def
patch
(
self
,
model
,
scale
,
blur_sigma
):
def
patch
(
self
,
model
,
scale
,
blur_sigma
):
m
=
model
.
clone
()
m
=
model
.
clone
()
# set extra options on the model
m
.
model_options
[
"sag"
]
=
True
m
.
model_options
[
"sag_scale"
]
=
scale
m
.
model_options
[
"sag_sigma"
]
=
blur_sigma
attn_scores
=
None
attn_scores
=
None
mid_block_shape
=
None
mid_block_shape
=
None
m
.
model
.
get_attn_scores
=
lambda
:
attn_scores
m
.
model
.
get_mid_block_shape
=
lambda
:
mid_block_shape
# TODO: make this work properly with chunked batches
# TODO: make this work properly with chunked batches
# currently, we can only save the attn from one UNet call
# currently, we can only save the attn from one UNet call
...
@@ -92,24 +132,41 @@ class SagNode:
...
@@ -92,24 +132,41 @@ class SagNode:
else
:
else
:
return
optimized_attention
(
q
,
k
,
v
,
heads
=
heads
)
return
optimized_attention
(
q
,
k
,
v
,
heads
=
heads
)
def
post_cfg_function
(
args
):
nonlocal
attn_scores
nonlocal
mid_block_shape
uncond_attn
=
attn_scores
sag_scale
=
scale
sag_sigma
=
blur_sigma
sag_threshold
=
1.0
model
=
args
[
"model"
]
uncond_pred
=
args
[
"uncond_denoised"
]
uncond
=
args
[
"uncond"
]
cfg_result
=
args
[
"denoised"
]
sigma
=
args
[
"sigma"
]
model_options
=
args
[
"model_options"
]
x
=
args
[
"input"
]
# create the adversarially blurred image
degraded
=
create_blur_map
(
uncond_pred
,
uncond_attn
,
sag_sigma
,
sag_threshold
)
degraded_noised
=
degraded
+
x
-
uncond_pred
# call into the UNet
(
sag
,
_
)
=
comfy
.
samplers
.
calc_cond_uncond_batch
(
model
,
uncond
,
None
,
degraded_noised
,
sigma
,
model_options
)
return
cfg_result
+
(
degraded
-
sag
)
*
sag_scale
m
.
set_model_sampler_post_cfg_function
(
post_cfg_function
)
# from diffusers:
# from diffusers:
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
def
set_model_patch_replace
(
patch
,
name
,
key
):
m
.
set_model_attn1_replace
(
attn_and_record
,
"middle"
,
0
,
0
)
to
=
m
.
model_options
[
"transformer_options"
]
if
"patches_replace"
not
in
to
:
to
[
"patches_replace"
]
=
{}
if
name
not
in
to
[
"patches_replace"
]:
to
[
"patches_replace"
][
name
]
=
{}
to
[
"patches_replace"
][
name
][
key
]
=
patch
set_model_patch_replace
(
attn_and_record
,
"attn1"
,
(
"middle"
,
0
,
0
))
# from diffusers:
# unet.mid_block.attentions[0].register_forward_hook()
def
forward_hook
(
m
,
inp
,
out
):
nonlocal
mid_block_shape
mid_block_shape
=
out
[
0
].
shape
[
-
2
:]
m
.
model
.
diffusion_model
.
middle_block
[
0
].
register_forward_hook
(
forward_hook
)
return
(
m
,
)
return
(
m
,
)
NODE_CLASS_MAPPINGS
=
{
NODE_CLASS_MAPPINGS
=
{
"Self-Attention Guidance"
:
SagNode
,
"SelfAttentionGuidance"
:
SelfAttentionGuidance
,
}
NODE_DISPLAY_NAME_MAPPINGS
=
{
"SelfAttentionGuidance"
:
"Self-Attention Guidance"
,
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment