Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
c45d1b9b
"vscode:/vscode.git/clone" did not exist on "010202ff0a223a057e30485323c960f1a5ffc137"
Commit
c45d1b9b
authored
Nov 27, 2023
by
comfyanonymous
Browse files
Add a function to load a unet from a state dict.
parent
f30b992b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
4 deletions
+10
-4
comfy/sd.py
comfy/sd.py
+10
-4
No files found.
comfy/sd.py
View file @
c45d1b9b
...
...
@@ -481,20 +481,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
return
(
model_patcher
,
clip
,
vae
,
clipvision
)
def
load_unet
(
unet_path
):
#load unet in diffusers format
sd
=
comfy
.
utils
.
load_torch_file
(
unet_path
)
def
load_unet_state_dict
(
sd
):
#load unet in diffusers format
parameters
=
comfy
.
utils
.
calculate_parameters
(
sd
)
unet_dtype
=
model_management
.
unet_dtype
(
model_params
=
parameters
)
if
"input_blocks.0.0.weight"
in
sd
:
#ldm
model_config
=
model_detection
.
model_config_from_unet
(
sd
,
""
,
unet_dtype
)
if
model_config
is
None
:
r
aise
RuntimeError
(
"ERROR: Could not detect model type of: {}"
.
format
(
unet_path
))
r
eturn
None
new_sd
=
sd
else
:
#diffusers
model_config
=
model_detection
.
model_config_from_diffusers_unet
(
sd
,
unet_dtype
)
if
model_config
is
None
:
print
(
"ERROR UNSUPPORTED UNET"
,
unet_path
)
return
None
diffusers_keys
=
comfy
.
utils
.
unet_to_diffusers
(
model_config
.
unet_config
)
...
...
@@ -514,6 +512,14 @@ def load_unet(unet_path): #load unet in diffusers format
print
(
"left over keys in unet:"
,
left_over
)
return
comfy
.
model_patcher
.
ModelPatcher
(
model
,
load_device
=
model_management
.
get_torch_device
(),
offload_device
=
offload_device
)
def
load_unet
(
unet_path
):
sd
=
comfy
.
utils
.
load_torch_file
(
unet_path
)
model
=
load_unet_state_dict
(
sd
)
if
model
is
None
:
print
(
"ERROR UNSUPPORTED UNET"
,
unet_path
)
raise
RuntimeError
(
"ERROR: Could not detect model type of: {}"
.
format
(
unet_path
))
return
model
def
save_checkpoint
(
output_path
,
model
,
clip
,
vae
,
metadata
=
None
):
model_management
.
load_models_gpu
([
model
,
clip
.
load_model
()])
sd
=
model
.
model
.
state_dict_for_saving
(
clip
.
get_sd
(),
vae
.
get_sd
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment