chenpangpang / ComfyUI

Commit a57b0c79
Authored Aug 26, 2023 by comfyanonymous
Parent: f72780a7

    Fix lowvram model merging.
Showing 3 changed files with 15 additions and 7 deletions:

  comfy/controlnet.py        +1  -6
  comfy/model_base.py        +6  -1
  comfy/model_management.py  +8  -0
comfy/controlnet.py

@@ -257,12 +257,7 @@ class ControlLora(ControlNet):
         cm = self.control_model.state_dict()
 
         for k in sd:
-            weight = sd[k]
-            if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
-                key_split = k.split('.')              # I have no idea why they don't just leave the weight there instead of using the meta device.
-                op = comfy.utils.get_attr(diffusion_model, '.'.join(key_split[:-1]))
-                weight = op._hf_hook.weights_map[key_split[-1]]
-
+            weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
             try:
                 comfy.utils.set_attr(self.control_model, k, weight)
             except:
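Note: the inline block removed above walked a parameter's dotted key to find the owning module (via comfy.utils.get_attr) and then read the real tensor out of accelerate's _hf_hook.weights_map; the commit moves that logic into comfy.model_management.resolve_lowvram_weight (see the third file below), so the accelerate-internals caveat lives in one place. As a rough illustration only, the dotted-path lookup that get_attr performs likely amounts to the following sketch; the actual body of the helper in comfy/utils.py is assumed here, not copied from it.

    # Minimal sketch of a dotted-path attribute lookup for a key such as
    # "input_blocks.0.0.weight" (hypothetical example key). nn.Module stores
    # submodules under string names, so getattr(module, "0") works here.
    import functools

    def get_attr(obj, attr):
        # getattr(getattr(obj, "input_blocks"), "0") ... one step per segment
        return functools.reduce(getattr, attr.split("."), obj)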
comfy/model_base.py

@@ -3,6 +3,7 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
 from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
 from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
 from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
+import comfy.model_management
 import numpy as np
 from enum import Enum
 from . import utils
@@ -93,7 +94,11 @@ class BaseModel(torch.nn.Module):
     def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
         clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
-        unet_state_dict = self.diffusion_model.state_dict()
+        unet_sd = self.diffusion_model.state_dict()
+        unet_state_dict = {}
+        for k in unet_sd:
+            unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
+
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
         vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
         if self.get_dtype() == torch.float16:
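Note: the saving path needs this per-key resolution because, in lowvram mode, accelerate leaves offloaded parameters on the meta device, and meta tensors carry shape and dtype but no data. A short standalone PyTorch demonstration of why such tensors cannot go into a saved or merged checkpoint as-is:

    import torch

    # Meta tensors have shape/dtype but no storage; any attempt to read
    # their values (as serializing a state dict would) fails.
    t = torch.empty(3, device="meta")
    print(t.device)  # meta
    try:
        t.cpu()  # materializing a meta tensor raises
    except (NotImplementedError, RuntimeError) as e:
        print("cannot read meta tensor:", e)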
comfy/model_management.py

@@ ... @@
 import psutil
 from enum import Enum
 from comfy.cli_args import args
+import comfy.utils
 import torch
 import sys
@@ -637,6 +638,13 @@ def soft_empty_cache():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
 
+def resolve_lowvram_weight(weight, model, key):
+    if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
+        key_split = key.split('.')              # I have no idea why they don't just leave the weight there instead of using the meta device.
+        op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
+        weight = op._hf_hook.weights_map[key_split[-1]]
+    return weight
+
 #TODO: might be cleaner to put this somewhere else
 import threading
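Note: the new helper encodes the accelerate convention that both call sites relied on: when a module's weights are offloaded, the parameter tensor is a meta placeholder and the real values sit in the module hook's weights_map; weights already on a real device pass through unchanged. A hypothetical usage sketch (the wrapper name is invented here), mirroring the loop added to model_base.py above:

    import comfy.model_management

    def materialize_state_dict(model):
        # Rebuild a concrete state dict from a model that may have some
        # weights offloaded to the meta device in lowvram mode.
        sd = model.state_dict()
        return {k: comfy.model_management.resolve_lowvram_weight(sd[k], model, k)
                for k in sd}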