Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
22f29d66
Commit
22f29d66
authored
Jul 22, 2023
by
comfyanonymous
Browse files
Try to fix memory issue with lora.
parent
67be7eb8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
5 deletions
+12
-5
comfy/model_management.py
comfy/model_management.py
+6
-2
comfy/sd.py
comfy/sd.py
+6
-3
No files found.
comfy/model_management.py
View file @
22f29d66
...
@@ -281,19 +281,23 @@ def load_model_gpu(model):
...
@@ -281,19 +281,23 @@ def load_model_gpu(model):
vram_set_state
=
VRAMState
.
LOW_VRAM
vram_set_state
=
VRAMState
.
LOW_VRAM
real_model
=
model
.
model
real_model
=
model
.
model
patch_model_to
=
None
if
vram_set_state
==
VRAMState
.
DISABLED
:
if
vram_set_state
==
VRAMState
.
DISABLED
:
pass
pass
elif
vram_set_state
==
VRAMState
.
NORMAL_VRAM
or
vram_set_state
==
VRAMState
.
HIGH_VRAM
or
vram_set_state
==
VRAMState
.
SHARED
:
elif
vram_set_state
==
VRAMState
.
NORMAL_VRAM
or
vram_set_state
==
VRAMState
.
HIGH_VRAM
or
vram_set_state
==
VRAMState
.
SHARED
:
model_accelerated
=
False
model_accelerated
=
False
real
_model
.
to
(
torch_dev
)
patch
_model
_
to
=
torch_dev
try
:
try
:
real_model
=
model
.
patch_model
()
real_model
=
model
.
patch_model
(
device_to
=
patch_model_to
)
except
Exception
as
e
:
except
Exception
as
e
:
model
.
unpatch_model
()
model
.
unpatch_model
()
unload_model
()
unload_model
()
raise
e
raise
e
if
patch_model_to
is
not
None
:
real_model
.
to
(
torch_dev
)
if
vram_set_state
==
VRAMState
.
NO_VRAM
:
if
vram_set_state
==
VRAMState
.
NO_VRAM
:
device_map
=
accelerate
.
infer_auto_device_map
(
real_model
,
max_memory
=
{
0
:
"256MiB"
,
"cpu"
:
"16GiB"
})
device_map
=
accelerate
.
infer_auto_device_map
(
real_model
,
max_memory
=
{
0
:
"256MiB"
,
"cpu"
:
"16GiB"
})
accelerate
.
dispatch_model
(
real_model
,
device_map
=
device_map
,
main_device
=
torch_dev
)
accelerate
.
dispatch_model
(
real_model
,
device_map
=
device_map
,
main_device
=
torch_dev
)
...
...
comfy/sd.py
View file @
22f29d66
...
@@ -338,7 +338,7 @@ class ModelPatcher:
...
@@ -338,7 +338,7 @@ class ModelPatcher:
sd
.
pop
(
k
)
sd
.
pop
(
k
)
return
sd
return
sd
def
patch_model
(
self
):
def
patch_model
(
self
,
device_to
=
None
):
model_sd
=
self
.
model_state_dict
()
model_sd
=
self
.
model_state_dict
()
for
key
in
self
.
patches
:
for
key
in
self
.
patches
:
if
key
not
in
model_sd
:
if
key
not
in
model_sd
:
...
@@ -350,10 +350,13 @@ class ModelPatcher:
...
@@ -350,10 +350,13 @@ class ModelPatcher:
if
key
not
in
self
.
backup
:
if
key
not
in
self
.
backup
:
self
.
backup
[
key
]
=
weight
.
to
(
self
.
offload_device
)
self
.
backup
[
key
]
=
weight
.
to
(
self
.
offload_device
)
temp_weight
=
weight
.
to
(
torch
.
float32
,
copy
=
True
)
if
device_to
is
not
None
:
temp_weight
=
weight
.
float
().
to
(
device_to
,
copy
=
True
)
else
:
temp_weight
=
weight
.
to
(
torch
.
float32
,
copy
=
True
)
out_weight
=
self
.
calculate_weight
(
self
.
patches
[
key
],
temp_weight
,
key
).
to
(
weight
.
dtype
)
out_weight
=
self
.
calculate_weight
(
self
.
patches
[
key
],
temp_weight
,
key
).
to
(
weight
.
dtype
)
set_attr
(
self
.
model
,
key
,
out_weight
)
set_attr
(
self
.
model
,
key
,
out_weight
)
del
weight
del
temp_
weight
return
self
.
model
return
self
.
model
def
calculate_weight
(
self
,
patches
,
weight
,
key
):
def
calculate_weight
(
self
,
patches
,
weight
,
key
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment