Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
b92bf819
"...ldm/git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "a9c57849b7ffd0f3e2c7627af4775b26dacd2cdf"
Commit
b92bf819
authored
Sep 18, 2023
by
comfyanonymous
Browse files
Do lora cast on GPU instead of CPU for higher performance.
parent
01094316
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
12 deletions
+12
-12
comfy/model_patcher.py
comfy/model_patcher.py
+12
-12
No files found.
comfy/model_patcher.py
View file @
b92bf819
...
@@ -187,13 +187,13 @@ class ModelPatcher:
...
@@ -187,13 +187,13 @@ class ModelPatcher:
else
:
else
:
weight
+=
alpha
*
w1
.
type
(
weight
.
dtype
).
to
(
weight
.
device
)
weight
+=
alpha
*
w1
.
type
(
weight
.
dtype
).
to
(
weight
.
device
)
elif
len
(
v
)
==
4
:
#lora/locon
elif
len
(
v
)
==
4
:
#lora/locon
mat1
=
v
[
0
].
float
().
to
(
weight
.
device
)
mat1
=
v
[
0
].
to
(
weight
.
device
)
.
float
()
mat2
=
v
[
1
].
float
().
to
(
weight
.
device
)
mat2
=
v
[
1
].
to
(
weight
.
device
)
.
float
()
if
v
[
2
]
is
not
None
:
if
v
[
2
]
is
not
None
:
alpha
*=
v
[
2
]
/
mat2
.
shape
[
0
]
alpha
*=
v
[
2
]
/
mat2
.
shape
[
0
]
if
v
[
3
]
is
not
None
:
if
v
[
3
]
is
not
None
:
#locon mid weights, hopefully the math is fine because I didn't properly test it
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3
=
v
[
3
].
float
().
to
(
weight
.
device
)
mat3
=
v
[
3
].
to
(
weight
.
device
)
.
float
()
final_shape
=
[
mat2
.
shape
[
1
],
mat2
.
shape
[
0
],
mat3
.
shape
[
2
],
mat3
.
shape
[
3
]]
final_shape
=
[
mat2
.
shape
[
1
],
mat2
.
shape
[
0
],
mat3
.
shape
[
2
],
mat3
.
shape
[
3
]]
mat2
=
torch
.
mm
(
mat2
.
transpose
(
0
,
1
).
flatten
(
start_dim
=
1
),
mat3
.
transpose
(
0
,
1
).
flatten
(
start_dim
=
1
)).
reshape
(
final_shape
).
transpose
(
0
,
1
)
mat2
=
torch
.
mm
(
mat2
.
transpose
(
0
,
1
).
flatten
(
start_dim
=
1
),
mat3
.
transpose
(
0
,
1
).
flatten
(
start_dim
=
1
)).
reshape
(
final_shape
).
transpose
(
0
,
1
)
try
:
try
:
...
@@ -212,18 +212,18 @@ class ModelPatcher:
...
@@ -212,18 +212,18 @@ class ModelPatcher:
if
w1
is
None
:
if
w1
is
None
:
dim
=
w1_b
.
shape
[
0
]
dim
=
w1_b
.
shape
[
0
]
w1
=
torch
.
mm
(
w1_a
.
float
(),
w1_b
.
float
())
w1
=
torch
.
mm
(
w1_a
.
to
(
weight
.
device
).
float
(),
w1_b
.
to
(
weight
.
device
)
.
float
())
else
:
else
:
w1
=
w1
.
float
().
to
(
weight
.
device
)
w1
=
w1
.
to
(
weight
.
device
)
.
float
()
if
w2
is
None
:
if
w2
is
None
:
dim
=
w2_b
.
shape
[
0
]
dim
=
w2_b
.
shape
[
0
]
if
t2
is
None
:
if
t2
is
None
:
w2
=
torch
.
mm
(
w2_a
.
float
().
to
(
weight
.
device
)
,
w2_b
.
float
().
to
(
weight
.
device
))
w2
=
torch
.
mm
(
w2_a
.
to
(
weight
.
device
).
float
()
,
w2_b
.
to
(
weight
.
device
)
.
float
()
)
else
:
else
:
w2
=
torch
.
einsum
(
'i j k l, j r, i p -> p r k l'
,
t2
.
float
().
to
(
weight
.
device
)
,
w2_b
.
float
().
to
(
weight
.
device
)
,
w2_a
.
float
().
to
(
weight
.
device
))
w2
=
torch
.
einsum
(
'i j k l, j r, i p -> p r k l'
,
t2
.
to
(
weight
.
device
).
float
()
,
w2_b
.
to
(
weight
.
device
).
float
()
,
w2_a
.
to
(
weight
.
device
)
.
float
()
)
else
:
else
:
w2
=
w2
.
float
().
to
(
weight
.
device
)
w2
=
w2
.
to
(
weight
.
device
)
.
float
()
if
len
(
w2
.
shape
)
==
4
:
if
len
(
w2
.
shape
)
==
4
:
w1
=
w1
.
unsqueeze
(
2
).
unsqueeze
(
2
)
w1
=
w1
.
unsqueeze
(
2
).
unsqueeze
(
2
)
...
@@ -244,11 +244,11 @@ class ModelPatcher:
...
@@ -244,11 +244,11 @@ class ModelPatcher:
if
v
[
5
]
is
not
None
:
#cp decomposition
if
v
[
5
]
is
not
None
:
#cp decomposition
t1
=
v
[
5
]
t1
=
v
[
5
]
t2
=
v
[
6
]
t2
=
v
[
6
]
m1
=
torch
.
einsum
(
'i j k l, j r, i p -> p r k l'
,
t1
.
float
().
to
(
weight
.
device
)
,
w1b
.
float
().
to
(
weight
.
device
)
,
w1a
.
float
().
to
(
weight
.
device
))
m1
=
torch
.
einsum
(
'i j k l, j r, i p -> p r k l'
,
t1
.
to
(
weight
.
device
).
float
()
,
w1b
.
to
(
weight
.
device
).
float
()
,
w1a
.
to
(
weight
.
device
)
.
float
()
)
m2
=
torch
.
einsum
(
'i j k l, j r, i p -> p r k l'
,
t2
.
float
().
to
(
weight
.
device
)
,
w2b
.
float
().
to
(
weight
.
device
)
,
w2a
.
float
().
to
(
weight
.
device
))
m2
=
torch
.
einsum
(
'i j k l, j r, i p -> p r k l'
,
t2
.
to
(
weight
.
device
).
float
()
,
w2b
.
to
(
weight
.
device
).
float
()
,
w2a
.
to
(
weight
.
device
)
.
float
()
)
else
:
else
:
m1
=
torch
.
mm
(
w1a
.
float
().
to
(
weight
.
device
)
,
w1b
.
float
().
to
(
weight
.
device
))
m1
=
torch
.
mm
(
w1a
.
to
(
weight
.
device
).
float
()
,
w1b
.
to
(
weight
.
device
)
.
float
()
)
m2
=
torch
.
mm
(
w2a
.
float
().
to
(
weight
.
device
)
,
w2b
.
float
().
to
(
weight
.
device
))
m2
=
torch
.
mm
(
w2a
.
to
(
weight
.
device
).
float
()
,
w2b
.
to
(
weight
.
device
)
.
float
()
)
try
:
try
:
weight
+=
(
alpha
*
m1
*
m2
).
reshape
(
weight
.
shape
).
type
(
weight
.
dtype
)
weight
+=
(
alpha
*
m1
*
m2
).
reshape
(
weight
.
shape
).
type
(
weight
.
dtype
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment