chenpangpang / ComfyUI

Commit ffc4b7c3
Authored May 25, 2024 by comfyanonymous

Fix DORA strength.

This is a different version of #3298 with more correct behavior.

Parent: 5b873694
Changes: 1 changed file with 48 additions and 20 deletions (+48 -20)

comfy/model_patcher.py
@@ -9,16 +9,26 @@ import comfy.model_management
 from comfy.types import UnetWrapperFunction

-def weight_decompose_scale(dora_scale, weight):
+def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
+    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32)
+    lora_diff *= alpha
+    weight_calc = weight + lora_diff.type(weight.dtype)
     weight_norm = (
-        weight.transpose(0, 1)
-        .reshape(weight.shape[1], -1)
+        weight_calc.transpose(0, 1)
+        .reshape(weight_calc.shape[1], -1)
         .norm(dim=1, keepdim=True)
-        .reshape(weight.shape[1], *[1] * (weight.dim() - 1))
+        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
         .transpose(0, 1)
     )

-    return (dora_scale / weight_norm).type(weight.dtype)
+    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
+    if strength != 1.0:
+        weight_calc -= weight
+        weight += strength * (weight_calc)
+    else:
+        weight[:] = weight_calc
+    return weight

 def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
     to = model_options["transformer_options"].copy()
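
For reference, a minimal plain-torch restatement of the new weight_decompose logic (not the ComfyUI function itself: cast_to_device is replaced by .to(), the in-place update is returned as a new tensor, and the toy shapes are illustrative only). It shows that strength = 1.0 reproduces the fully DoRA-rescaled weight, while other values linearly blend between the original weight and that result.

import torch

def dora_blend(dora_scale, weight, lora_diff, alpha, strength):
    # Plain-torch sketch of weight_decompose above; an illustration, not the repo code.
    dora_scale = dora_scale.to(weight.device, torch.float32)
    lora_diff = lora_diff * alpha
    weight_calc = weight + lora_diff.type(weight.dtype)
    weight_norm = (
        weight_calc.transpose(0, 1)
        .reshape(weight_calc.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
        .transpose(0, 1)
    )
    weight_calc = weight_calc * (dora_scale / weight_norm).type(weight.dtype)
    # strength interpolates between the untouched weight and the full DoRA result.
    return weight + strength * (weight_calc - weight)

weight = torch.randn(8, 4)
lora_diff = torch.randn(8, 4) * 0.01
dora_scale = torch.ones(1, 4)

full = dora_blend(dora_scale, weight, lora_diff, 1.0, 1.0)
half = dora_blend(dora_scale, weight, lora_diff, 1.0, 0.5)
assert torch.allclose(half, weight + 0.5 * (full - weight), atol=1e-5)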
@@ -328,7 +338,7 @@ class ModelPatcher:
     def calculate_weight(self, patches, weight, key):
         for p in patches:
-            alpha = p[0]
+            strength = p[0]
             v = p[1]
             strength_model = p[2]
@@ -346,26 +356,31 @@ class ModelPatcher:
             if patch_type == "diff":
                 w1 = v[0]
-                if alpha != 0.0:
+                if strength != 0.0:
                     if w1.shape != weight.shape:
                         logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
                     else:
-                        weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
+                        weight += strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
             elif patch_type == "lora": #lora/locon
                 mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
                 mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
                 dora_scale = v[4]
                 if v[2] is not None:
-                    alpha *= v[2] / mat2.shape[0]
+                    alpha = v[2] / mat2.shape[0]
+                else:
+                    alpha = 1.0
                 if v[3] is not None:
                     #locon mid weights, hopefully the math is fine because I didn't properly test it
                     mat3 = comfy.model_management.cast_to_device(v[3], weight.device, torch.float32)
                     final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                     mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                 try:
-                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
+                    lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
+                        weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+                    else:
+                        weight += ((strength * alpha) * lora_diff).type(weight.dtype)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "lokr":
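
As a quick illustration of the non-DoRA branch in the hunk above, here is a toy application of the lora/locon path with made-up shapes (mat1/mat2 follow the hunk's names; the rank-2 factors, v2 stand-in, and strength value are illustrative only):

import torch

weight = torch.zeros(8, 4)
mat1 = torch.randn(8, 2)      # first low-rank factor (out x rank), toy values
mat2 = torch.randn(2, 4)      # second low-rank factor (rank x in), toy values
v2 = 1.0                      # stands in for v[2], the stored lora alpha
alpha = v2 / mat2.shape[0]    # as in `alpha = v[2] / mat2.shape[0]` above
strength = 0.8

lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
weight += ((strength * alpha) * lora_diff).type(weight.dtype)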
@@ -402,19 +417,26 @@ class ModelPatcher:
                 if len(w2.shape) == 4:
                     w1 = w1.unsqueeze(2).unsqueeze(2)
                 if v[2] is not None and dim is not None:
-                    alpha *= v[2] / dim
+                    alpha = v[2] / dim
+                else:
+                    alpha = 1.0
                 try:
-                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+                    lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
+                        weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+                    else:
+                        weight += ((strength * alpha) * lora_diff).type(weight.dtype)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "loha":
                 w1a = v[0]
                 w1b = v[1]
                 if v[2] is not None:
-                    alpha *= v[2] / w1b.shape[0]
+                    alpha = v[2] / w1b.shape[0]
+                else:
+                    alpha = 1.0
                 w2a = v[3]
                 w2b = v[4]
                 dora_scale = v[7]
@@ -437,14 +459,18 @@ class ModelPatcher:
                                       comfy.model_management.cast_to_device(w2b, weight.device, torch.float32))

                 try:
-                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+                    lora_diff = (m1 * m2).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
+                        weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+                    else:
+                        weight += ((strength * alpha) * lora_diff).type(weight.dtype)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "glora":
                 if v[4] is not None:
-                    alpha *= v[4] / v[0].shape[0]
+                    alpha = v[4] / v[0].shape[0]
+                else:
+                    alpha = 1.0
                 dora_scale = v[5]
@@ -454,9 +480,11 @@ class ModelPatcher:
                 b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)

                 try:
-                    weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
+                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
                     if dora_scale is not None:
-                        weight *= weight_decompose_scale(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
+                        weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength)
+                    else:
+                        weight += ((strength * alpha) * lora_diff).type(weight.dtype)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             else:
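
To see why the commit message calls this "more correct behavior" for partial strengths, here is a minimal plain-torch sketch contrasting the old and new DoRA paths (column_norm reimplements the norm from weight_decompose; the toy tensors and the strength value are illustrative only). At strength = 1.0 the two paths agree; below 1.0 the old code renormalized a partially merged weight, while the new code renormalizes the fully merged weight and then blends back toward the original.

import torch

def column_norm(w):
    # Same norm used by weight_decompose above, written out for a 2-D weight.
    return (
        w.transpose(0, 1)
        .reshape(w.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(w.shape[1], *[1] * (w.dim() - 1))
        .transpose(0, 1)
    )

torch.manual_seed(0)
weight = torch.randn(8, 4)
lora_diff = torch.randn(8, 4) * 0.05
dora_scale = torch.ones(1, 4)
strength, alpha = 0.5, 1.0

# Old behavior: strength was folded into alpha (`alpha = p[0]`, then `alpha *= ...`),
# and the DoRA rescale was applied to the partially merged weight.
w_old = weight + (strength * alpha) * lora_diff
w_old = w_old * (dora_scale / column_norm(w_old))

# New behavior: merge at full strength, rescale, then blend back by strength.
w_full = weight + alpha * lora_diff
w_full = w_full * (dora_scale / column_norm(w_full))
w_new = weight + strength * (w_full - weight)

print((w_old - w_new).abs().max())  # non-zero: the paths differ when strength != 1.0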