chenpangpang/ComfyUI · commit 490771b7

Commit 490771b7, authored Jul 15, 2023 by comfyanonymous
Parent: 50b1180d

    Speed up lora loading a bit.
Showing 3 changed files with 35 additions and 25 deletions:

    comfy/model_management.py   +10  -7
    comfy/sd.py                 +19  -14
    comfy/utils.py              +6   -4
comfy/model_management.py

@@ -258,15 +258,11 @@ def load_model_gpu(model):
     if model is current_loaded_model:
         return
     unload_model()

-    try:
-        real_model = model.patch_model()
-    except Exception as e:
-        model.unpatch_model()
-        raise e

     torch_dev = model.load_device
     model.model_patches_to(torch_dev)
     model.model_patches_to(model.model_dtype())
+    current_loaded_model = model
     if is_device_cpu(torch_dev):
         vram_set_state = VRAMState.DISABLED
@@ -280,8 +276,7 @@ def load_model_gpu(model):
         if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary
             vram_set_state = VRAMState.LOW_VRAM

-    current_loaded_model = model
-
+    real_model = model.model
     if vram_set_state == VRAMState.DISABLED:
         pass
     elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
@@ -295,6 +290,14 @@ def load_model_gpu(model):
             accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
         model_accelerated = True
+
+    try:
+        real_model = model.patch_model()
+    except Exception as e:
+        model.unpatch_model()
+        unload_model()
+        raise e
+
     return current_loaded_model

 def load_controlnet_gpu(control_models):
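Note on the model_management.py change: patch_model() (which applies the queued
lora patches to the model weights) now runs after the model has been moved to
its load device, rather than before, so the lora weight math executes on the
GPU instead of the CPU. A rough sketch of the idea, with illustrative names
only (merge_lora is not ComfyUI's API):

    import torch

    def merge_lora(weight, mat1, mat2, alpha):
        # W += alpha * (mat1 @ mat2), computed on whatever device `weight` lives on
        update = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))
        return weight + alpha * update.reshape(weight.shape).type(weight.dtype)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    w = torch.randn(4096, 4096, device=device)
    a = torch.randn(4096, 16, device=device)  # lora "up" factor
    b = torch.randn(16, 4096, device=device)  # lora "down" factor
    w = merge_lora(w, a, b, alpha=0.8)  # runs on the GPU when one is available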
comfy/sd.py

@@ -340,7 +340,7 @@ class ModelPatcher:
             weight = model_sd[key]
             if key not in self.backup:
-                self.backup[key] = weight.clone()
+                self.backup[key] = weight.to(self.offload_device, copy=True)

             temp_weight = weight.to(torch.float32, copy=True)
             weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
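Note: weight backups are now written with weight.to(self.offload_device,
copy=True) instead of weight.clone(), so the unpatched copy is kept on the
offload device (normally the CPU) rather than alongside the weight in VRAM.
A minimal illustration (device choice assumed):

    import torch

    offload_device = torch.device("cpu")
    weight = torch.randn(8, 8, device="cuda" if torch.cuda.is_available() else "cpu")
    backup = weight.to(offload_device, copy=True)  # backup lives in system RAM, not VRAM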
@@ -367,15 +367,16 @@ class ModelPatcher:
                 else:
                     weight += alpha * w1.type(weight.dtype).to(weight.device)
             elif len(v) == 4: #lora/locon
-                mat1 = v[0]
-                mat2 = v[1]
+                mat1 = v[0].float().to(weight.device)
+                mat2 = v[1].float().to(weight.device)
                 if v[2] is not None:
                     alpha *= v[2] / mat2.shape[0]
                 if v[3] is not None:
                     #locon mid weights, hopefully the math is fine because I didn't properly test it
-                    final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
-                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
-                weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
+                    mat3 = v[3].float().to(weight.device)
+                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
+                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
+                weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
             elif len(v) == 8: #lokr
                 w1 = v[0]
                 w2 = v[1]
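For reference, the lora/locon branch computes a low-rank update: with mat1 of
shape (out, r) and mat2 of shape (r, in), the delta is alpha * mat1 @ mat2,
where alpha is rescaled by v[2] / rank when an alpha value is stored. The
factors are now cast to float32 on the weight's device once, up front, instead
of per-matmul with a final .to(weight.device). A conceptual sketch (sizes and
names illustrative):

    import torch

    out_f, in_f, r = 1024, 768, 8
    weight = torch.zeros(out_f, in_f)
    mat1 = torch.randn(out_f, r)      # v[0], often called lora_up
    mat2 = torch.randn(r, in_f)       # v[1], often called lora_down
    alpha = 1.0 * (4.0 / r)           # strength * (v[2] / rank)
    weight += (alpha * torch.mm(mat1, mat2)).type(weight.dtype)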
@@ -389,20 +390,24 @@ class ModelPatcher:
                 if w1 is None:
                     dim = w1_b.shape[0]
                     w1 = torch.mm(w1_a.float(), w1_b.float())
+                else:
+                    w1 = w1.float().to(weight.device)

                 if w2 is None:
                     dim = w2_b.shape[0]
                     if t2 is None:
-                        w2 = torch.mm(w2_a.float(), w2_b.float())
+                        w2 = torch.mm(w2_a.float().to(weight.device), w2_b.float().to(weight.device))
                     else:
-                        w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
+                        w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2_b.float().to(weight.device), w2_a.float().to(weight.device))
+                else:
+                    w2 = w2.float().to(weight.device)

                 if len(w2.shape) == 4:
                     w1 = w1.unsqueeze(2).unsqueeze(2)
                 if v[2] is not None and dim is not None:
                     alpha *= v[2] / dim
-                weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
+                weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
             else: #loha
                 w1a = v[0]
                 w1b = v[1]
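For reference, the lokr branch builds its full-size delta as a Kronecker
product of two small factors: weight += alpha * torch.kron(w1, w2). With both
factors already in float32 on the weight's device, the trailing .float() calls
and the final .to(weight.device) become unnecessary. A conceptual sketch
(sizes illustrative):

    import torch

    w1 = torch.randn(16, 12)
    w2 = torch.randn(64, 64)
    weight = torch.zeros(16 * 64, 12 * 64)
    weight += 0.8 * torch.kron(w1, w2).reshape(weight.shape)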
@@ -413,13 +418,13 @@ class ModelPatcher:
                 if v[5] is not None: #cp decomposition
                     t1 = v[5]
                     t2 = v[6]
-                    m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float())
-                    m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float())
+                    m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float().to(weight.device), w1b.float().to(weight.device), w1a.float().to(weight.device))
+                    m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float().to(weight.device), w2b.float().to(weight.device), w2a.float().to(weight.device))
                 else:
-                    m1 = torch.mm(w1a.float(), w1b.float())
-                    m2 = torch.mm(w2a.float(), w2b.float())
+                    m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
+                    m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))

-                weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
+                weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)

         return weight

     def unpatch_model(self):
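For reference, the loha branch forms its delta as the elementwise (Hadamard)
product of two low-rank products: weight += alpha * (w1a @ w1b) * (w2a @ w2b),
now computed entirely on the weight's device. A conceptual sketch (sizes
illustrative):

    import torch

    out_f, in_f, r = 1024, 768, 8
    w1a, w1b = torch.randn(out_f, r), torch.randn(r, in_f)
    w2a, w2b = torch.randn(out_f, r), torch.randn(r, in_f)
    weight = torch.zeros(out_f, in_f)
    weight += 0.8 * (torch.mm(w1a, w1b) * torch.mm(w2a, w2b))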
comfy/utils.py

@@ -4,18 +4,20 @@ import struct
 import comfy.checkpoint_pickle
 import safetensors.torch

-def load_torch_file(ckpt, safe_load=False):
+def load_torch_file(ckpt, safe_load=False, device=None):
+    if device is None:
+        device = torch.device("cpu")
     if ckpt.lower().endswith(".safetensors"):
-        sd = safetensors.torch.load_file(ckpt, device="cpu")
+        sd = safetensors.torch.load_file(ckpt, device=device.type)
     else:
         if safe_load:
             if not 'weights_only' in torch.load.__code__.co_varnames:
                 print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
                 safe_load = False
         if safe_load:
-            pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
+            pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
         else:
-            pl_sd = torch.load(ckpt, map_location="cpu", pickle_module=comfy.checkpoint_pickle)
+            pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
         if "global_step" in pl_sd:
             print(f"Global Step: {pl_sd['global_step']}")
         if "state_dict" in pl_sd:
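Note: load_torch_file() gains an optional device argument (defaulting to CPU),
so callers can load a checkpoint or lora file directly onto the target device.
Usage sketch (the file path is hypothetical):

    import torch
    from comfy.utils import load_torch_file

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # loads the tensors straight onto `device` instead of CPU-then-copy
    sd = load_torch_file("loras/example_lora.safetensors", safe_load=True, device=device)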