chenpangpang / ComfyUI, commit 5282f564
Authored Apr 23, 2023 by comfyanonymous
Parent: 6908f9c9

Implement Linear hypernetworks.
Add a HypernetworkLoader node to use hypernetworks.

Showing 9 changed files with 185 additions and 16 deletions (+185 -16)
Files changed:
  comfy/ldm/modules/attention.py               +56 -13
  comfy/model_management.py                     +3  -0
  comfy/samplers.py                             +9  -1
  comfy/sd.py                                  +23  -0
  comfy/utils.py                                +5  -2
  comfy_extras/nodes_hypernetwork.py           +87  -0
  folder_paths.py                               +1  -0
  models/hypernetworks/put_hypernetworks_here   +0  -0
  nodes.py                                      +1  -0
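Taken together, the diff adds a per-attention patch hook: BasicTransformerBlock calls every entry registered under transformer_options["patches"]["attn1_patch"] / ["attn2_patch"], and the CrossAttention variants accept the value tensor those patches return. A minimal sketch of that hook contract follows (the patch function itself is hypothetical, not part of this commit):

# Hypothetical no-op patch illustrating the hook contract introduced here.
# BasicTransformerBlock calls each registered patch as
#     q, k, v = patch(current_index, q, k, v)
# where q is the normed hidden state and k/v are the context/value tensors that
# will be fed through to_k()/to_v(). The hypernetwork patch below offsets k and v.
def example_attn_patch(current_index, q, k, v):
    return q, k, v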
comfy/ldm/modules/attention.py
@@ -163,13 +163,17 @@ class CrossAttentionBirchSan(nn.Module):
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         h = self.heads
 
         query = self.to_q(x)
         context = default(context, x)
         key = self.to_k(context)
-        value = self.to_v(context)
+        if value is not None:
+            value = self.to_v(value)
+        else:
+            value = self.to_v(context)
         del context, x
 
         query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
@@ -256,13 +260,17 @@ class CrossAttentionDoggettx(nn.Module):
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         h = self.heads
 
         q_in = self.to_q(x)
         context = default(context, x)
         k_in = self.to_k(context)
-        v_in = self.to_v(context)
+        if value is not None:
+            v_in = self.to_v(value)
+            del value
+        else:
+            v_in = self.to_v(context)
         del context, x
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
@@ -350,13 +358,17 @@ class CrossAttention(nn.Module):
             nn.Dropout(dropout)
         )
 
-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         h = self.heads
 
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
-        v = self.to_v(context)
+        if value is not None:
+            v = self.to_v(value)
+        else:
+            v = self.to_v(context)
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
@@ -402,11 +414,15 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
 
-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
-        v = self.to_v(context)
+        if value is not None:
+            v = self.to_v(value)
+            del value
+        else:
+            v = self.to_v(context)
 
         b, _, _ = q.shape
         q, k, v = map(
@@ -447,11 +463,15 @@ class CrossAttentionPytorch(nn.Module):
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
         self.attention_op: Optional[Any] = None
 
-    def forward(self, x, context=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
-        v = self.to_v(context)
+        if value is not None:
+            v = self.to_v(value)
+            del value
+        else:
+            v = self.to_v(context)
 
         b, _, _ = q.shape
         q, k, v = map(
@@ -512,11 +532,25 @@ class BasicTransformerBlock(nn.Module):
             transformer_patches = {}
 
         n = self.norm1(x)
+        if self.disable_self_attn:
+            context_attn1 = context
+        else:
+            context_attn1 = None
+        value_attn1 = None
+
+        if "attn1_patch" in transformer_patches:
+            patch = transformer_patches["attn1_patch"]
+            if context_attn1 is None:
+                context_attn1 = n
+            value_attn1 = context_attn1
+            for p in patch:
+                n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1)
+
         if "tomesd" in transformer_options:
             m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
-            n = u(self.attn1(m(n), context=context if self.disable_self_attn else None))
+            n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
         else:
-            n = self.attn1(n, context=context if self.disable_self_attn else None)
+            n = self.attn1(n, context=context_attn1, value=value_attn1)
 
         x += n
         if "middle_patch" in transformer_patches:
@@ -525,7 +559,16 @@ class BasicTransformerBlock(nn.Module):
                 x = p(current_index, x)
 
         n = self.norm2(x)
-        n = self.attn2(n, context=context)
+
+        context_attn2 = context
+        value_attn2 = None
+        if "attn2_patch" in transformer_patches:
+            patch = transformer_patches["attn2_patch"]
+            value_attn2 = context_attn2
+            for p in patch:
+                n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2)
+
+        n = self.attn2(n, context=context_attn2, value=value_attn2)
 
         x += n
         x = self.ff(self.norm3(x)) + x
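Each of the five attention implementations gains the same optional value argument: keys always come from context, while values come from value when it is given and fall back to context otherwise. A minimal sketch of the new call pattern, assuming the standard LDM CrossAttention constructor (query_dim, context_dim, heads, dim_head); module, shapes and values are illustrative, not from this commit:

import torch
from comfy.ldm.modules.attention import CrossAttention

# Illustrative only: keys are built from ctx, values from the separate val tensor.
attn = CrossAttention(query_dim=320, context_dim=320, heads=8, dim_head=40)
x = torch.randn(1, 64, 320)               # hidden states
ctx = torch.randn(1, 77, 320)             # conditioning / context
val = ctx + 0.1 * torch.randn_like(ctx)   # e.g. a hypernetwork-offset copy of ctx

out = attn(x, context=ctx, value=val)     # k = to_k(ctx), v = to_v(val)
out_old = attn(x, context=ctx)            # previous behaviour: v = to_v(ctx)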
comfy/model_management.py
@@ -133,6 +133,7 @@ def unload_model():
         #never unload models from GPU on high vram
         if vram_state != VRAMState.HIGH_VRAM:
             current_loaded_model.model.cpu()
+        current_loaded_model.model_patches_to("cpu")
         current_loaded_model.unpatch_model()
         current_loaded_model = None
@@ -156,6 +157,8 @@ def load_model_gpu(model):
     except Exception as e:
         model.unpatch_model()
         raise e
+
+    model.model_patches_to(get_torch_device())
     current_loaded_model = model
     if vram_state == VRAMState.CPU:
         pass
comfy/samplers.py
@@ -197,7 +197,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             transformer_options = model_options['transformer_options'].copy()
             if patches is not None:
-                transformer_options["patches"] = patches
+                if "patches" in transformer_options:
+                    cur_patches = transformer_options["patches"].copy()
+                    for p in patches:
+                        if p in cur_patches:
+                            cur_patches[p] = cur_patches[p] + patches[p]
+                        else:
+                            cur_patches[p] = patches[p]
+                else:
+                    transformer_options["patches"] = patches
 
             c['transformer_options'] = transformer_options
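The merge logic above appends per-cond patches to any patch lists already present in transformer_options instead of overwriting them. A small worked sketch of that dict-of-lists behaviour (the patch names are hypothetical placeholders):

# Illustrative only: existing and incoming lists for the same key are concatenated,
# new keys are added as-is.
transformer_options = {"patches": {"attn1_patch": ["model_patch"]}}
patches = {"attn1_patch": ["cond_patch"], "attn2_patch": ["other_patch"]}

cur_patches = transformer_options["patches"].copy()
for p in patches:
    if p in cur_patches:
        cur_patches[p] = cur_patches[p] + patches[p]
    else:
        cur_patches[p] = patches[p]

# cur_patches == {"attn1_patch": ["model_patch", "cond_patch"],
#                 "attn2_patch": ["other_patch"]}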
comfy/sd.py
@@ -254,6 +254,29 @@ class ModelPatcher:
     def set_model_sampler_cfg_function(self, sampler_cfg_function):
         self.model_options["sampler_cfg_function"] = sampler_cfg_function
 
+    def set_model_patch(self, patch, name):
+        to = self.model_options["transformer_options"]
+        if "patches" not in to:
+            to["patches"] = {}
+        to["patches"][name] = to["patches"].get(name, []) + [patch]
+
+    def set_model_attn1_patch(self, patch):
+        self.set_model_patch(patch, "attn1_patch")
+
+    def set_model_attn2_patch(self, patch):
+        self.set_model_patch(patch, "attn2_patch")
+
+    def model_patches_to(self, device):
+        to = self.model_options["transformer_options"]
+        if "patches" in to:
+            patches = to["patches"]
+            for name in patches:
+                patch_list = patches[name]
+                for i in range(len(patch_list)):
+                    if hasattr(patch_list[i], "to"):
+                        patch_list[i] = patch_list[i].to(device)
+
     def model_dtype(self):
         return self.model.diffusion_model.dtype
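A hedged sketch of how these ModelPatcher helpers are meant to be used; the patch object my_patch is hypothetical here, while HypernetworkLoader below does exactly this with the real hypernetwork patch:

# Illustrative only: clone the model, register a patch for both self- and
# cross-attention, and move the patch's tensors alongside the model.
patched = model.clone()
patched.set_model_attn1_patch(my_patch)   # stored under transformer_options["patches"]["attn1_patch"]
patched.set_model_attn2_patch(my_patch)
patched.model_patches_to("cuda")          # calls my_patch.to("cuda") if it defines .to()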
comfy/utils.py
 import torch
 
-def load_torch_file(ckpt):
+def load_torch_file(ckpt, safe_load=False):
     if ckpt.lower().endswith(".safetensors"):
         import safetensors.torch
         sd = safetensors.torch.load_file(ckpt, device="cpu")
     else:
-        pl_sd = torch.load(ckpt, map_location="cpu")
+        if safe_load:
+            pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
+        else:
+            pl_sd = torch.load(ckpt, map_location="cpu")
         if "global_step" in pl_sd:
             print(f"Global Step: {pl_sd['global_step']}")
         if "state_dict" in pl_sd:
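The new safe_load flag makes plain torch checkpoints go through weights_only deserialization; the hypernetwork loader below relies on it. A short usage sketch (the file path is hypothetical):

# Illustrative only: non-safetensors files are loaded with
# torch.load(..., weights_only=True) when safe_load=True.
import comfy.utils
sd = comfy.utils.load_torch_file("models/hypernetworks/example.pt", safe_load=True)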
comfy_extras/nodes_hypernetwork.py (new file, mode 100644)
import comfy.utils
import folder_paths
import torch

def load_hypernetwork_patch(path, strength):
    sd = comfy.utils.load_torch_file(path, safe_load=True)
    activation_func = sd.get('activation_func', 'linear')
    is_layer_norm = sd.get('is_layer_norm', False)
    use_dropout = sd.get('use_dropout', False)
    activate_output = sd.get('activate_output', False)
    last_layer_dropout = sd.get('last_layer_dropout', False)

    if activation_func != 'linear' or is_layer_norm != False or use_dropout != False or activate_output != False or last_layer_dropout != False:
        print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
        return None

    out = {}

    for d in sd:
        try:
            dim = int(d)
        except:
            continue

        output = []
        for index in [0, 1]:
            attn_weights = sd[dim][index]
            keys = attn_weights.keys()

            linears = filter(lambda a: a.endswith(".weight"), keys)
            linears = sorted(list(map(lambda a: a[:-len(".weight")], linears)))
            layers = []
            for lin_name in linears:
                lin_weight = attn_weights['{}.weight'.format(lin_name)]
                lin_bias = attn_weights['{}.bias'.format(lin_name)]
                layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
                layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
                layers += [layer]
            output.append(torch.nn.Sequential(*layers))
        out[dim] = torch.nn.ModuleList(output)

    class hypernetwork_patch:
        def __init__(self, hypernet, strength):
            self.hypernet = hypernet
            self.strength = strength

        def __call__(self, current_index, q, k, v):
            dim = k.shape[-1]
            if dim in self.hypernet:
                hn = self.hypernet[dim]
                k = k + hn[0](k) * self.strength
                v = v + hn[1](v) * self.strength

            return q, k, v

        def to(self, device):
            for d in self.hypernet.keys():
                self.hypernet[d] = self.hypernet[d].to(device)
            return self

    return hypernetwork_patch(out, strength)

class HypernetworkLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
                             "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                             }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_hypernetwork"

    CATEGORY = "_for_testing"

    def load_hypernetwork(self, model, hypernetwork_name, strength):
        hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
        model_hypernetwork = model.clone()
        patch = load_hypernetwork_patch(hypernetwork_path, strength)
        if patch is not None:
            model_hypernetwork.set_model_attn1_patch(patch)
            model_hypernetwork.set_model_attn2_patch(patch)
        return (model_hypernetwork,)

NODE_CLASS_MAPPINGS = {
    "HypernetworkLoader": HypernetworkLoader
}
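The patch offsets the attention keys and values as k' = k + strength * HN_k(k) and v' = v + strength * HN_v(v), where HN_k and HN_v are the two per-dimension linear stacks read from the file. A hedged usage sketch of the node class itself (the file name and model object are hypothetical):

# Illustrative only: apply a hypernetwork to an already-loaded MODEL.
loader = HypernetworkLoader()
(patched_model,) = loader.load_hypernetwork(model, "example_hypernetwork.pt", strength=1.0)
# patched_model is a clone of `model` with the patch registered on attn1 and attn2;
# an unsupported hypernetwork format leaves the clone unpatched.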
folder_paths.py
@@ -32,6 +32,7 @@ folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_m
 folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])
 
+folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
 
 output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
 temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
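With the new entry registered, the standard folder_paths helpers used by the node resolve hypernetwork files under models/hypernetworks. A short sketch:

# Grounded in the diff: list available hypernetworks and resolve one to a full path
# (assumes at least one file is present in models/hypernetworks).
import folder_paths
names = folder_paths.get_filename_list("hypernetworks")
full_path = folder_paths.get_full_path("hypernetworks", names[0])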
models/hypernetworks/put_hypernetworks_here (new empty file, mode 100644)
nodes.py
@@ -1268,6 +1268,7 @@ def load_custom_nodes():
 def init_custom_nodes():
     load_custom_nodes()
+    load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py"))
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py"))
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))
     load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py"))