chenpangpang / ComfyUI / Commits / de142eaa
"...composable_kernel_onnx.git" did not exist on "f584ab0c545ade05ae793a8b36fa282d47d0f698"
Commit de142eaa, authored Jun 09, 2023 by comfyanonymous
Simpler base model code.
parent 4b0b5165
Showing 4 changed files with 164 additions and 75 deletions:
comfy/diffusers_load.py   +2  -26
comfy/model_base.py       +66 -0
comfy/samplers.py         +41 -32
comfy/sd.py               +55 -17
comfy/diffusers_load.py
@@ -4,7 +4,7 @@ import yaml
 import folder_paths
 from comfy.ldm.util import instantiate_from_config
-from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE
+from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint
 import os.path as osp
 import re
 import torch
@@ -84,28 +84,4 @@ def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, emb
     # Put together new checkpoint
     sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
-    clip = None
-    vae = None
-
-    class WeightsLoader(torch.nn.Module):
-        pass
-
-    w = WeightsLoader()
-    load_state_dict_to = []
-    if output_vae:
-        vae = VAE(scale_factor=scale_factor, config=vae_config)
-        w.first_stage_model = vae.first_stage_model
-        load_state_dict_to = [w]
-
-    if output_clip:
-        clip = CLIP(config=clip_config, embedding_directory=embedding_directory)
-        w.cond_stage_model = clip.cond_stage_model
-        load_state_dict_to = [w]
-
-    model = instantiate_from_config(config["model"])
-    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
-
-    if fp16:
-        model = model.half()
-
-    return ModelPatcher(model), clip, vae
+    return load_checkpoint(embedding_directory=embedding_directory, state_dict=sd, config=config)
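
With this change, load_diffusers only assembles the merged state dict and hands everything to comfy.sd.load_checkpoint through the new state_dict/config keyword arguments, instead of building the model, CLIP, and VAE itself. A minimal usage sketch (the directory path is a hypothetical example; the three return values are unchanged by this commit):

    from comfy.diffusers_load import load_diffusers

    # "models/diffusers/example-model" is a placeholder for any
    # diffusers-format checkpoint directory.
    model_patcher, clip, vae = load_diffusers(
        "models/diffusers/example-model",
        fp16=True,
        output_vae=True,
        output_clip=True,
        embedding_directory=None,
    )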
comfy/model_base.py (new file, 0 → 100644)
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
import numpy as np

class BaseModel(torch.nn.Module):
    def __init__(self, unet_config, v_prediction=False):
        super().__init__()

        self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
        self.diffusion_model = UNetModel(**unet_config)
        self.v_prediction = v_prediction
        if self.v_prediction:
            self.parameterization = "v"
        else:
            self.parameterization = "eps"
        if "adm_in_channels" in unet_config:
            self.adm_channels = unet_config["adm_in_channels"]
        else:
            self.adm_channels = 0
        print("v_prediction", v_prediction)
        print("adm", self.adm_channels)

    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        if given_betas is not None:
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end

        self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
        self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
        self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}):
        if c_concat is not None:
            xc = torch.cat([x] + c_concat, dim=1)
        else:
            xc = x
        context = torch.cat(c_crossattn, 1)
        return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options)

    def get_dtype(self):
        return self.diffusion_model.dtype

    def is_adm(self):
        return self.adm_channels > 0

class SD21UNCLIP(BaseModel):
    def __init__(self, unet_config, noise_aug_config, v_prediction=True):
        super().__init__(unet_config, v_prediction)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

class SDInpaint(BaseModel):
    def __init__(self, unet_config, v_prediction=False):
        super().__init__(unet_config, v_prediction)
        self.concat_keys = ("mask", "masked_image")
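
register_schedule precomputes the DDPM schedule buffers that the rest of the code reads from the model. A standalone sketch of the same arithmetic, assuming make_beta_schedule's "linear" mode is the usual LDM one (betas interpolated in sqrt-space):

    import numpy as np
    import torch

    timesteps = 1000
    linear_start, linear_end = 0.00085, 0.012  # BaseModel's defaults above

    # "linear" schedule: interpolate sqrt(beta), then square.
    betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps) ** 2
    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

    # Registered as float32 buffers, exactly as in register_schedule:
    betas_t = torch.tensor(betas, dtype=torch.float32)
    cumprod_t = torch.tensor(alphas_cumprod, dtype=torch.float32)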
comfy/samplers.py
@@ -248,7 +248,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             c['transformer_options'] = transformer_options
-            output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
+            output = model_function(input_x, timestep_, **c).chunk(batch_chunks)
             del input_x
             model_management.throw_exception_if_processing_interrupted()
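
The only change in this hunk swaps cond=c for **c: the batched conditioning dict is now unpacked into keyword arguments, matching the new BaseModel.apply_model(x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}) signature. A toy sketch of the convention, with a hypothetical stand-in function:

    def apply_model(x, t, c_concat=None, c_crossattn=None, c_adm=None,
                    control=None, transformer_options={}):
        # stand-in for BaseModel.apply_model
        return x

    c = {"c_crossattn": [object()], "transformer_options": {}}
    out = apply_model(0, 0, **c)   # new style: dict keys become kwargs
    # old style passed the whole dict: model_function(input_x, timestep_, cond=c)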
@@ -460,36 +460,42 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
             uncond[temp[1]] = [o[0], n]
 
-def encode_adm(noise_augmentor, conds, batch_size, device):
+def encode_adm(conds, batch_size, device, noise_augmentor=None):
     for t in range(len(conds)):
         x = conds[t]
-        if 'adm' in x[1]:
-            adm_inputs = []
-            weights = []
-            noise_aug = []
-            adm_in = x[1]["adm"]
-            for adm_c in adm_in:
-                adm_cond = adm_c[0].image_embeds
-                weight = adm_c[1]
-                noise_augment = adm_c[2]
-                noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
-                c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
-                adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
-                weights.append(weight)
-                noise_aug.append(noise_augment)
-                adm_inputs.append(adm_out)
-
-            if len(noise_aug) > 1:
-                adm_out = torch.stack(adm_inputs).sum(0)
-                #TODO: add a way to control this
-                noise_augment = 0.05
-                noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
-                c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
-                adm_out = torch.cat((c_adm, noise_level_emb), 1)
-        else:
-            adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
-        x[1] = x[1].copy()
-        x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
+        adm_out = None
+        if noise_augmentor is not None:
+            if 'adm' in x[1]:
+                adm_inputs = []
+                weights = []
+                noise_aug = []
+                adm_in = x[1]["adm"]
+                for adm_c in adm_in:
+                    adm_cond = adm_c[0].image_embeds
+                    weight = adm_c[1]
+                    noise_augment = adm_c[2]
+                    noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+                    c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
+                    adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
+                    weights.append(weight)
+                    noise_aug.append(noise_augment)
+                    adm_inputs.append(adm_out)
+
+                if len(noise_aug) > 1:
+                    adm_out = torch.stack(adm_inputs).sum(0)
+                    #TODO: add a way to control this
+                    noise_augment = 0.05
+                    noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+                    c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
+                    adm_out = torch.cat((c_adm, noise_level_emb), 1)
+            else:
+                adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device)
+        else:
+            if 'adm' in x[1]:
+                adm_out = x[1]["adm"].to(device)
+        if adm_out is not None:
+            x[1] = x[1].copy()
+            x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
 
     return conds
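
The rewritten encode_adm keeps the old unCLIP path (noise-augmenting CLIP image embeds) but adds a noise_augmentor=None path for models whose conds already carry a plain "adm" tensor. A runnable sketch of just that new path, using a toy cond list:

    import torch

    device = "cpu"
    batch_size = 2
    conds = [[None, {"adm": torch.zeros(1, 4)}]]   # toy [cond, options] pair

    x = conds[0]
    adm_out = x[1]["adm"].to(device)               # no augmentation applied
    x[1] = x[1].copy()
    x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size)
    print(x[1]["adm_encoded"].shape)               # torch.Size([2, 4])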
@@ -591,14 +597,17 @@ class KSampler:
         apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x])
         apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
 
-        if self.model.model.diffusion_model.dtype == torch.float16:
+        if self.model.get_dtype() == torch.float16:
             precision_scope = torch.autocast
         else:
             precision_scope = contextlib.nullcontext
 
-        if hasattr(self.model, 'noise_augmentor'): #unclip
-            positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device)
-            negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device)
+        if self.model.is_adm():
+            noise_augmentor = None
+            if hasattr(self.model, 'noise_augmentor'): #unclip
+                noise_augmentor = self.model.noise_augmentor
+
+            positive = encode_adm(positive, noise.shape[0], self.device, noise_augmentor)
+            negative = encode_adm(negative, noise.shape[0], self.device, noise_augmentor)
 
         extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
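
KSampler now queries the model through its new interface (get_dtype, is_adm) instead of reaching into self.model.model.diffusion_model, so ADM encoding runs for any model that reports adm_channels > 0, and the noise augmentor is only wired in when the model actually has one. A sketch of the dispatch with hypothetical stand-ins:

    class FakeUnclipModel:
        adm_channels = 2048
        noise_augmentor = object()    # placeholder attribute
        def is_adm(self):
            return self.adm_channels > 0

    model = FakeUnclipModel()
    noise_augmentor = None
    if model.is_adm():
        if hasattr(model, 'noise_augmentor'):  # unclip only
            noise_augmentor = model.noise_augmentor
        # positive = encode_adm(positive, noise.shape[0], device, noise_augmentor)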
comfy/sd.py
@@ -15,8 +15,15 @@ from . import utils
 from . import clip_vision
 from . import gligen
 from . import diffusers_convert
+from . import model_base
 
 def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]):
+    replace_prefix = {"model.diffusion_model.": "diffusion_model."}
+    for rp in replace_prefix:
+        replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), sd.keys())))
+        for x in replace:
+            sd[x[1]] = sd.pop(x[0])
+
     m, u = model.load_state_dict(sd, strict=False)
     k = list(sd.keys())
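
The new block at the top of load_model_weights remaps checkpoint keys from the LatentDiffusion layout (model.diffusion_model.*) to BaseModel's flatter layout (diffusion_model.*). The same logic applied to a toy state dict:

    sd = {
        "model.diffusion_model.input_blocks.0.0.weight": 1,
        "first_stage_model.decoder.conv_in.weight": 2,   # left untouched
    }
    replace_prefix = {"model.diffusion_model.": "diffusion_model."}
    for rp in replace_prefix:
        replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])),
                           filter(lambda a: a.startswith(rp), sd.keys())))
        for x in replace:
            sd[x[1]] = sd.pop(x[0])
    print(sorted(sd))
    # ['diffusion_model.input_blocks.0.0.weight',
    #  'first_stage_model.decoder.conv_in.weight']

This is also why every hard-coded "model.diffusion_model..." target below (LoRA keys, ControlNet keys, ModelPatcher) loses its "model." prefix.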
@@ -182,7 +189,7 @@ def model_lora_keys(model, key_map={}):
     counter = 0
     for b in range(12):
-        tk = "model.diffusion_model.input_blocks.{}.1".format(b)
+        tk = "diffusion_model.input_blocks.{}.1".format(b)
         up_counter = 0
         for c in LORA_UNET_MAP_ATTENTIONS:
             k = "{}.{}.weight".format(tk, c)
@@ -193,13 +200,13 @@ def model_lora_keys(model, key_map={}):
         if up_counter >= 4:
             counter += 1
     for c in LORA_UNET_MAP_ATTENTIONS:
-        k = "model.diffusion_model.middle_block.1.{}.weight".format(c)
+        k = "diffusion_model.middle_block.1.{}.weight".format(c)
         if k in sdk:
             lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c])
             key_map[lora_key] = k
     counter = 3
     for b in range(12):
-        tk = "model.diffusion_model.output_blocks.{}.1".format(b)
+        tk = "diffusion_model.output_blocks.{}.1".format(b)
         up_counter = 0
         for c in LORA_UNET_MAP_ATTENTIONS:
             k = "{}.{}.weight".format(tk, c)
@@ -223,7 +230,7 @@ def model_lora_keys(model, key_map={}):
     ds_counter = 0
     counter = 0
     for b in range(12):
-        tk = "model.diffusion_model.input_blocks.{}.0".format(b)
+        tk = "diffusion_model.input_blocks.{}.0".format(b)
         key_in = False
         for c in LORA_UNET_MAP_RESNET:
             k = "{}.{}.weight".format(tk, c)
@@ -242,7 +249,7 @@ def model_lora_keys(model, key_map={}):
     counter = 0
     for b in range(3):
-        tk = "model.diffusion_model.middle_block.{}".format(b)
+        tk = "diffusion_model.middle_block.{}".format(b)
         key_in = False
         for c in LORA_UNET_MAP_RESNET:
             k = "{}.{}.weight".format(tk, c)
@@ -256,7 +263,7 @@ def model_lora_keys(model, key_map={}):
     counter = 0
     us_counter = 0
     for b in range(12):
-        tk = "model.diffusion_model.output_blocks.{}.0".format(b)
+        tk = "diffusion_model.output_blocks.{}.0".format(b)
         key_in = False
         for c in LORA_UNET_MAP_RESNET:
             k = "{}.{}.weight".format(tk, c)
@@ -332,7 +339,7 @@ class ModelPatcher:
                 patch_list[i] = patch_list[i].to(device)
 
     def model_dtype(self):
-        return self.model.diffusion_model.dtype
+        return self.model.get_dtype()
 
     def add_patches(self, patches, strength=1.0):
         p = {}
@@ -764,7 +771,7 @@ def load_controlnet(ckpt_path, model=None):
         for x in controlnet_data:
             c_m = "control_model."
             if x.startswith(c_m):
-                sd_key = "model.diffusion_model.{}".format(x[len(c_m):])
+                sd_key = "diffusion_model.{}".format(x[len(c_m):])
                 if sd_key in model_sd:
                     cd = controlnet_data[x]
                     cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
@@ -931,9 +938,10 @@ def load_gligen(ckpt_path):
         model = model.half()
     return model
 
-def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
-    with open(config_path, 'r') as stream:
-        config = yaml.safe_load(stream)
+def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
+    if config is None:
+        with open(config_path, 'r') as stream:
+            config = yaml.safe_load(stream)
     model_config_params = config['model']['params']
     clip_config = model_config_params['cond_stage_config']
     scale_factor = model_config_params['scale_factor']
@@ -942,8 +950,19 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
     fp16 = False
     if "unet_config" in model_config_params:
         if "params" in model_config_params["unet_config"]:
-            if "use_fp16" in model_config_params["unet_config"]["params"]:
-                fp16 = model_config_params["unet_config"]["params"]["use_fp16"]
+            unet_config = model_config_params["unet_config"]["params"]
+            if "use_fp16" in unet_config:
+                fp16 = unet_config["use_fp16"]
+
+    noise_aug_config = None
+    if "noise_aug_config" in model_config_params:
+        noise_aug_config = model_config_params["noise_aug_config"]
+
+    v_prediction = False
+
+    if "parameterization" in model_config_params:
+        if model_config_params["parameterization"] == "v":
+            v_prediction = True
 
     clip = None
     vae = None
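
Besides restructuring the use_fp16 lookup, this hunk pulls two new pieces of information out of the YAML model params: an optional noise_aug_config block (unCLIP models) and parameterization: v. A sketch of what the parsed dict might look like for an SD2.x v-model; the values are illustrative assumptions, not a real config:

    # Illustrative shape of config['model']['params'] after yaml.safe_load:
    model_config_params = {
        "unet_config": {"params": {"use_fp16": True}},
        "noise_aug_config": {"target": "...", "params": {}},
        "parameterization": "v",
    }

    unet_config = model_config_params["unet_config"]["params"]
    fp16 = unet_config.get("use_fp16", False)
    noise_aug_config = model_config_params.get("noise_aug_config")
    v_prediction = model_config_params.get("parameterization") == "v"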
@@ -963,9 +982,16 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
         w.cond_stage_model = clip.cond_stage_model
         load_state_dict_to = [w]
 
-    model = instantiate_from_config(config["model"])
-    sd = utils.load_torch_file(ckpt_path)
-    model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
+    if config['model']["target"].endswith("LatentInpaintDiffusion"):
+        model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
+    elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
+        model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
+    else:
+        model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
+
+    if state_dict is None:
+        state_dict = utils.load_torch_file(ckpt_path)
+    model = load_model_weights(model, state_dict, verbose=False, load_state_dict_to=load_state_dict_to)
 
     if fp16:
         model = model.half()
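
load_checkpoint can now be driven either from files, as before, or from objects already in memory via the new state_dict/config keyword arguments, which is the path load_diffusers takes. A usage sketch with hypothetical paths:

    from comfy.sd import load_checkpoint

    # File-driven, unchanged behavior (paths are placeholders):
    model_patcher, clip, vae = load_checkpoint(
        config_path="models/configs/v1-inference.yaml",
        ckpt_path="models/checkpoints/example.ckpt",
    )

    # Object-driven, new in this commit: skip the YAML and checkpoint reads.
    # model_patcher, clip, vae = load_checkpoint(state_dict=sd, config=config)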
@@ -1073,16 +1099,20 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config}
     model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config}
 
+    unclip_model = False
+    inpaint_model = False
     if noise_aug_config is not None: #SD2.x unclip model
         sd_config["noise_aug_config"] = noise_aug_config
         sd_config["image_size"] = 96
         sd_config["embedding_dropout"] = 0.25
         sd_config["conditioning_key"] = 'crossattn-adm'
+        unclip_model = True
         model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion"
     elif unet_config["in_channels"] > 4: #inpainting model
         sd_config["conditioning_key"] = "hybrid"
         sd_config["finetune_keys"] = None
         model_config["target"] = "comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+        inpaint_model = True
     else:
         sd_config["conditioning_key"] = "crossattn"
@@ -1096,13 +1126,21 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         unet_config["num_classes"] = "sequential"
         unet_config["adm_in_channels"] = sd[unclip].shape[1]
 
+    v_prediction = False
     if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction
         k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias"
         out = sd[k]
         if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
+            v_prediction = True
             sd_config["parameterization"] = 'v'
 
-    model = instantiate_from_config(model_config)
+    if inpaint_model:
+        model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
+    elif unclip_model:
+        model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
+    else:
+        model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
+
     model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
 
     if fp16:
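
The v-prediction heuristic itself is untouched: it still thresholds the standard deviation of one known bias tensor at 0.09 (the in-code comment admits this is a guess). What is new is that the result is recorded in a v_prediction flag so the matching model_base class can be constructed directly instead of going through instantiate_from_config. The check in isolation:

    import torch

    # In the real code, `out` is sd[k] for the probed key
    # "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias";
    # here a stand-in tensor for illustration.
    out = torch.randn(1024) * 0.2
    v_prediction = bool(torch.std(out, unbiased=False) > 0.09)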