Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
719fb2c8
Commit
719fb2c8
authored
Apr 14, 2024
by
comfyanonymous
Browse files
Add basic PAG node.
parent
258dbc06
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
78 additions
and
10 deletions
+78
-10
comfy/model_patcher.py
comfy/model_patcher.py
+21
-10
comfy_extras/nodes_pag.py
comfy_extras/nodes_pag.py
+56
-0
nodes.py
nodes.py
+1
-0
No files found.
comfy/model_patcher.py
View file @
719fb2c8
...
...
@@ -18,6 +18,26 @@ def apply_weight_decompose(dora_scale, weight):
return
weight
*
(
dora_scale
/
weight_norm
)
def
set_model_options_patch_replace
(
model_options
,
patch
,
name
,
block_name
,
number
,
transformer_index
=
None
):
to
=
model_options
[
"transformer_options"
].
copy
()
if
"patches_replace"
not
in
to
:
to
[
"patches_replace"
]
=
{}
else
:
to
[
"patches_replace"
]
=
to
[
"patches_replace"
].
copy
()
if
name
not
in
to
[
"patches_replace"
]:
to
[
"patches_replace"
][
name
]
=
{}
else
:
to
[
"patches_replace"
][
name
]
=
to
[
"patches_replace"
][
name
].
copy
()
if
transformer_index
is
not
None
:
block
=
(
block_name
,
number
,
transformer_index
)
else
:
block
=
(
block_name
,
number
)
to
[
"patches_replace"
][
name
][
block
]
=
patch
model_options
[
"transformer_options"
]
=
to
return
model_options
class
ModelPatcher
:
def
__init__
(
self
,
model
,
load_device
,
offload_device
,
size
=
0
,
current_device
=
None
,
weight_inplace_update
=
False
):
...
...
@@ -109,16 +129,7 @@ class ModelPatcher:
to
[
"patches"
][
name
]
=
to
[
"patches"
].
get
(
name
,
[])
+
[
patch
]
def
set_model_patch_replace
(
self
,
patch
,
name
,
block_name
,
number
,
transformer_index
=
None
):
to
=
self
.
model_options
[
"transformer_options"
]
if
"patches_replace"
not
in
to
:
to
[
"patches_replace"
]
=
{}
if
name
not
in
to
[
"patches_replace"
]:
to
[
"patches_replace"
][
name
]
=
{}
if
transformer_index
is
not
None
:
block
=
(
block_name
,
number
,
transformer_index
)
else
:
block
=
(
block_name
,
number
)
to
[
"patches_replace"
][
name
][
block
]
=
patch
self
.
model_options
=
set_model_options_patch_replace
(
self
.
model_options
,
patch
,
name
,
block_name
,
number
,
transformer_index
=
transformer_index
)
def
set_model_attn1_patch
(
self
,
patch
):
self
.
set_model_patch
(
patch
,
"attn1_patch"
)
...
...
comfy_extras/nodes_pag.py
0 → 100644
View file @
719fb2c8
#Modified/simplified version of the node from: https://github.com/pamparamm/sd-perturbed-attention
#If you want the one with more options see the above repo.
#My modified one here is more basic but has less chances of breaking with ComfyUI updates.
import
comfy.model_patcher
import
comfy.samplers
class
PerturbedAttentionGuidance
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"model"
:
(
"MODEL"
,),
"scale"
:
(
"FLOAT"
,
{
"default"
:
3.0
,
"min"
:
0.0
,
"max"
:
100.0
,
"step"
:
0.1
,
"round"
:
0.01
}),
}
}
RETURN_TYPES
=
(
"MODEL"
,)
FUNCTION
=
"patch"
CATEGORY
=
"_for_testing"
def
patch
(
self
,
model
,
scale
):
unet_block
=
"middle"
unet_block_id
=
0
m
=
model
.
clone
()
def
perturbed_attention
(
q
,
k
,
v
,
extra_options
,
mask
=
None
):
return
v
def
post_cfg_function
(
args
):
model
=
args
[
"model"
]
cond_pred
=
args
[
"cond_denoised"
]
cond
=
args
[
"cond"
]
cfg_result
=
args
[
"denoised"
]
sigma
=
args
[
"sigma"
]
model_options
=
args
[
"model_options"
].
copy
()
x
=
args
[
"input"
]
if
scale
==
0
:
return
cfg_result
# Replace Self-attention with PAG
model_options
=
comfy
.
model_patcher
.
set_model_options_patch_replace
(
model_options
,
perturbed_attention
,
"attn1"
,
unet_block
,
unet_block_id
)
(
pag
,)
=
comfy
.
samplers
.
calc_cond_batch
(
model
,
[
cond
],
x
,
sigma
,
model_options
)
return
cfg_result
+
(
cond_pred
-
pag
)
*
scale
m
.
set_model_sampler_post_cfg_function
(
post_cfg_function
,
disable_cfg1_optimization
=
True
)
return
(
m
,)
NODE_CLASS_MAPPINGS
=
{
"PerturbedAttentionGuidance"
:
PerturbedAttentionGuidance
,
}
nodes.py
View file @
719fb2c8
...
...
@@ -1942,6 +1942,7 @@ def init_custom_nodes():
"nodes_differential_diffusion.py"
,
"nodes_ip2p.py"
,
"nodes_model_merging_model_specific.py"
,
"nodes_pag.py"
,
]
import_failed
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment