Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
7931ff0f
Commit
7931ff0f
authored
Sep 01, 2023
by
comfyanonymous
Browse files
Support SDXL inpaint models.
parent
c335fdf2
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
22 additions
and
16 deletions
+22
-16
comfy/model_base.py
comfy/model_base.py
+3
-6
comfy/model_detection.py
comfy/model_detection.py
+5
-1
comfy/sd.py
comfy/sd.py
+4
-3
comfy/supported_models.py
comfy/supported_models.py
+4
-1
comfy/supported_models_base.py
comfy/supported_models_base.py
+6
-5
No files found.
comfy/model_base.py
View file @
7931ff0f
...
...
@@ -111,6 +111,9 @@ class BaseModel(torch.nn.Module):
return
{
**
unet_state_dict
,
**
vae_state_dict
,
**
clip_state_dict
}
def
set_inpaint
(
self
):
self
.
concat_keys
=
(
"mask"
,
"masked_image"
)
def
unclip_adm
(
unclip_conditioning
,
device
,
noise_augmentor
,
noise_augment_merge
=
0.0
):
adm_inputs
=
[]
weights
=
[]
...
...
@@ -148,12 +151,6 @@ class SD21UNCLIP(BaseModel):
else
:
return
unclip_adm
(
unclip_conditioning
,
device
,
self
.
noise_augmentor
,
kwargs
.
get
(
"unclip_noise_augment_merge"
,
0.05
))
class
SDInpaint
(
BaseModel
):
def
__init__
(
self
,
model_config
,
model_type
=
ModelType
.
EPS
,
device
=
None
):
super
().
__init__
(
model_config
,
model_type
,
device
=
device
)
self
.
concat_keys
=
(
"mask"
,
"masked_image"
)
def
sdxl_pooled
(
args
,
noise_augmentor
):
if
"unclip_conditioning"
in
args
:
return
unclip_adm
(
args
.
get
(
"unclip_conditioning"
,
None
),
args
[
"device"
],
noise_augmentor
)[:,:
1280
]
...
...
comfy/model_detection.py
View file @
7931ff0f
...
...
@@ -183,8 +183,12 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
'num_res_blocks'
:
2
,
'attention_resolutions'
:
[],
'transformer_depth'
:
[
0
,
0
,
0
],
'channel_mult'
:
[
1
,
2
,
4
],
'transformer_depth_middle'
:
0
,
'use_linear_in_transformer'
:
True
,
"num_head_channels"
:
64
,
'context_dim'
:
1
}
SDXL_diffusers_inpaint
=
{
'use_checkpoint'
:
False
,
'image_size'
:
32
,
'out_channels'
:
4
,
'use_spatial_transformer'
:
True
,
'legacy'
:
False
,
'num_classes'
:
'sequential'
,
'adm_in_channels'
:
2816
,
'use_fp16'
:
use_fp16
,
'in_channels'
:
9
,
'model_channels'
:
320
,
'num_res_blocks'
:
2
,
'attention_resolutions'
:
[
2
,
4
],
'transformer_depth'
:
[
0
,
2
,
10
],
'channel_mult'
:
[
1
,
2
,
4
],
'transformer_depth_middle'
:
10
,
'use_linear_in_transformer'
:
True
,
'context_dim'
:
2048
,
"num_head_channels"
:
64
}
supported_models
=
[
SDXL
,
SDXL_refiner
,
SD21
,
SD15
,
SD21_uncliph
,
SD21_unclipl
,
SDXL_mid_cnet
,
SDXL_small_cnet
]
supported_models
=
[
SDXL
,
SDXL_refiner
,
SD21
,
SD15
,
SD21_uncliph
,
SD21_unclipl
,
SDXL_mid_cnet
,
SDXL_small_cnet
,
SDXL_diffusers_inpaint
]
for
unet_config
in
supported_models
:
matches
=
True
...
...
comfy/sd.py
View file @
7931ff0f
...
...
@@ -355,13 +355,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
model_config
.
latent_format
=
latent_formats
.
SD15
(
scale_factor
=
scale_factor
)
model_config
.
unet_config
=
unet_config
if
config
[
'model'
][
"target"
].
endswith
(
"LatentInpaintDiffusion"
):
model
=
model_base
.
SDInpaint
(
model_config
,
model_type
=
model_type
)
elif
config
[
'model'
][
"target"
].
endswith
(
"ImageEmbeddingConditionedLatentDiffusion"
):
if
config
[
'model'
][
"target"
].
endswith
(
"ImageEmbeddingConditionedLatentDiffusion"
):
model
=
model_base
.
SD21UNCLIP
(
model_config
,
noise_aug_config
[
"params"
],
model_type
=
model_type
)
else
:
model
=
model_base
.
BaseModel
(
model_config
,
model_type
=
model_type
)
if
config
[
'model'
][
"target"
].
endswith
(
"LatentInpaintDiffusion"
):
model
.
set_inpaint
()
if
fp16
:
model
=
model
.
half
()
...
...
comfy/supported_models.py
View file @
7931ff0f
...
...
@@ -153,7 +153,10 @@ class SDXL(supported_models_base.BASE):
return
model_base
.
ModelType
.
EPS
def
get_model
(
self
,
state_dict
,
prefix
=
""
,
device
=
None
):
return
model_base
.
SDXL
(
self
,
model_type
=
self
.
model_type
(
state_dict
,
prefix
),
device
=
device
)
out
=
model_base
.
SDXL
(
self
,
model_type
=
self
.
model_type
(
state_dict
,
prefix
),
device
=
device
)
if
self
.
inpaint_model
():
out
.
set_inpaint
()
return
out
def
process_clip_state_dict
(
self
,
state_dict
):
keys_to_replace
=
{}
...
...
comfy/supported_models_base.py
View file @
7931ff0f
...
...
@@ -57,12 +57,13 @@ class BASE:
self
.
unet_config
[
x
]
=
self
.
unet_extra_config
[
x
]
def
get_model
(
self
,
state_dict
,
prefix
=
""
,
device
=
None
):
if
self
.
inpaint_model
():
return
model_base
.
SDInpaint
(
self
,
model_type
=
self
.
model_type
(
state_dict
,
prefix
),
device
=
device
)
elif
self
.
noise_aug_config
is
not
None
:
return
model_base
.
SD21UNCLIP
(
self
,
self
.
noise_aug_config
,
model_type
=
self
.
model_type
(
state_dict
,
prefix
),
device
=
device
)
if
self
.
noise_aug_config
is
not
None
:
out
=
model_base
.
SD21UNCLIP
(
self
,
self
.
noise_aug_config
,
model_type
=
self
.
model_type
(
state_dict
,
prefix
),
device
=
device
)
else
:
return
model_base
.
BaseModel
(
self
,
model_type
=
self
.
model_type
(
state_dict
,
prefix
),
device
=
device
)
out
=
model_base
.
BaseModel
(
self
,
model_type
=
self
.
model_type
(
state_dict
,
prefix
),
device
=
device
)
if
self
.
inpaint_model
():
out
.
set_inpaint
()
return
out
def
process_clip_state_dict
(
self
,
state_dict
):
return
state_dict
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment