chenpangpang / ComfyUI / Commits / 028a583b

Commit 028a583b authored Jun 19, 2024 by comfyanonymous

Fix issue with full diffusers SD3 loras.

parent 0d6a5793
Showing 1 changed file with 16 additions and 10 deletions

comfy/model_patcher.py (+16, -10)
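At a glance: the commit threads an optional transform callable through ModelPatcher's patch bookkeeping. add_patches now stores a fifth tuple element, function, and every branch of the weight-merging code applies it to the computed delta (defaulting to an identity). The shape change, lifted straight from the hunks below:

    # Before: patch entries were 4-tuples
    (strength_patch, patches[k], strength_model, offset)
    # After: an optional transform travels with each patch
    (strength_patch, patches[k], strength_model, offset, function)

The commit message ties this to full diffusers-format SD3 loras, which presumably need to remap their deltas before merging; the diff itself only shows the plumbing.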
@@ -210,16 +210,19 @@ class ModelPatcher:
         model_sd = self.model.state_dict()
         for k in patches:
             offset = None
+            function = None
             if isinstance(k, str):
                 key = k
             else:
                 offset = k[1]
                 key = k[0]
+                if len(k) > 2:
+                    function = k[2]
 
             if key in model_sd:
                 p.add(k)
                 current_patches = self.patches.get(key, [])
-                current_patches.append((strength_patch, patches[k], strength_model, offset))
+                current_patches.append((strength_patch, patches[k], strength_model, offset, function))
                 self.patches[key] = current_patches
 
         self.patches_uuid = uuid.uuid4()
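For orientation, a hedged sketch of how a caller could hand in a transform: patch keys may be tuples, and a third element, when present, becomes function. The key names, shapes, and half_delta helper here are hypothetical; only the (key, offset, function) convention comes from the hunk above.

    import torch

    def half_delta(delta):
        # hypothetical transform: attenuate the merged delta
        return 0.5 * delta

    patches = {
        # plain string key: behaves exactly as before
        "diffusion_model.a.weight": (torch.zeros(4, 4),),
        # tuple key: (key, offset, function) attaches the transform
        ("diffusion_model.b.weight", None, half_delta): (torch.zeros(4, 4),),
    }
    # patcher.add_patches(patches, strength_patch=1.0)  # patcher: a ModelPatcher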
@@ -347,6 +350,9 @@ class ModelPatcher:
             v = p[1]
             strength_model = p[2]
             offset = p[3]
+            function = p[4]
+            if function is None:
+                function = lambda a: a
 
             old_weight = None
             if offset is not None:
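Note the backward-compatible default: a patch stored without a transform unpacks function as None, which is immediately replaced by an identity, so every weight += function(...) below reduces to the pre-commit expression. In isolation:

    p = (1.0, None, 1.0, None, None)  # (strength, v, strength_model, offset, function)
    function = p[4]
    if function is None:
        function = lambda a: a  # identity: preserves the old behavior
    assert function(42) == 42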
@@ -371,7 +377,7 @@ class ModelPatcher:
                 if w1.shape != weight.shape:
                     logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                 else:
-                    weight += strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
+                    weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
             elif patch_type == "lora": #lora/locon
                 mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
                 mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
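A standalone sketch of the "diff" branch after this change, with Tensor.to standing in for comfy.model_management.cast_to_device (an assumption; the real helper also handles device management):

    import torch

    weight = torch.zeros(4, 4)
    w1 = torch.ones(4, 4)     # the raw diff patch
    strength = 0.5
    function = lambda a: a    # identity unless the patch supplied one
    weight += function(strength * w1.to(weight.device, weight.dtype))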
@@ -389,9 +395,9 @@ class ModelPatcher:
                 try:
                     lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                     else:
-                        weight += ((strength * alpha) * lora_diff).type(weight.dtype)
+                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "lokr":
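The lora branch in miniature: the low-rank product is reshaped to the weight's shape, and the optional transform now wraps the scaled delta. A self-contained sketch with made-up shapes:

    import torch

    weight = torch.zeros(8, 8)
    mat1 = torch.randn(8, 2)   # down-projection, rank 2
    mat2 = torch.randn(2, 8)   # up-projection
    strength, alpha, function = 1.0, 1.0, (lambda a: a)
    lora_diff = torch.mm(mat1.flatten(start_dim=1),
                         mat2.flatten(start_dim=1)).reshape(weight.shape)
    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))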
@@ -435,9 +441,9 @@ class ModelPatcher:
                 try:
                     lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                     else:
-                        weight += ((strength * alpha) * lora_diff).type(weight.dtype)
+                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "loha":
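For the lokr branch, the delta is a Kronecker product: kron of an (a, b) matrix with a (c, d) matrix has shape (a*c, b*d), which must reshape to the weight. A minimal sketch with made-up factor shapes:

    import torch

    weight = torch.zeros(8, 8)
    w1 = torch.randn(2, 2)
    w2 = torch.randn(4, 4)
    strength, alpha, function = 1.0, 1.0, (lambda a: a)
    lora_diff = torch.kron(w1, w2).reshape(weight.shape)  # (2*4, 2*4) = (8, 8)
    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))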
@@ -472,9 +478,9 @@ class ModelPatcher:
                 try:
                     lora_diff = (m1 * m2).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                     else:
-                        weight += ((strength * alpha) * lora_diff).type(weight.dtype)
+                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "glora":
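loha's delta is an elementwise (Hadamard) product of two factors that already match the weight's shape; the transform wraps the scaled result exactly as in the other branches. Sketch:

    import torch

    weight = torch.zeros(8, 8)
    m1, m2 = torch.randn(8, 8), torch.randn(8, 8)
    strength, alpha, function = 1.0, 1.0, (lambda a: a)
    lora_diff = (m1 * m2).reshape(weight.shape)
    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))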
@@ -493,9 +499,9 @@ class ModelPatcher:
                 try:
                     lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
                     else:
-                        weight += ((strength * alpha) * lora_diff).type(weight.dtype)
+                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             else:
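glora combines a plain low-rank term (b2 @ b1) with a term that routes the existing weight through a second adapter pair (a2, a1). A self-contained sketch with a made-up rank:

    import torch

    weight = torch.zeros(8, 8)
    r = 2
    b1, b2 = torch.randn(r, 8), torch.randn(8, r)
    a1, a2 = torch.randn(r, 8), torch.randn(8, r)
    strength, alpha, function = 1.0, 1.0, (lambda a: a)
    lora_diff = (torch.mm(b2, b1)
                 + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)
                ).reshape(weight.shape)
    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))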