Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
94a5a67c
Commit
94a5a67c
authored
Mar 29, 2024
by
comfyanonymous
Browse files
Cleanup to support different types of inpaint models.
parent
9bf6061d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
21 deletions
+22
-21
comfy/model_base.py
comfy/model_base.py
+22
-21
No files found.
comfy/model_base.py
View file @
94a5a67c
...
@@ -66,7 +66,8 @@ class BaseModel(torch.nn.Module):
...
@@ -66,7 +66,8 @@ class BaseModel(torch.nn.Module):
self
.
adm_channels
=
unet_config
.
get
(
"adm_in_channels"
,
None
)
self
.
adm_channels
=
unet_config
.
get
(
"adm_in_channels"
,
None
)
if
self
.
adm_channels
is
None
:
if
self
.
adm_channels
is
None
:
self
.
adm_channels
=
0
self
.
adm_channels
=
0
self
.
inpaint_model
=
False
self
.
concat_keys
=
()
logging
.
info
(
"model_type {}"
.
format
(
model_type
.
name
))
logging
.
info
(
"model_type {}"
.
format
(
model_type
.
name
))
logging
.
debug
(
"adm {}"
.
format
(
self
.
adm_channels
))
logging
.
debug
(
"adm {}"
.
format
(
self
.
adm_channels
))
...
@@ -107,8 +108,7 @@ class BaseModel(torch.nn.Module):
...
@@ -107,8 +108,7 @@ class BaseModel(torch.nn.Module):
def
extra_conds
(
self
,
**
kwargs
):
def
extra_conds
(
self
,
**
kwargs
):
out
=
{}
out
=
{}
if
self
.
inpaint_model
:
if
len
(
self
.
concat_keys
)
>
0
:
concat_keys
=
(
"mask"
,
"masked_image"
)
cond_concat
=
[]
cond_concat
=
[]
denoise_mask
=
kwargs
.
get
(
"concat_mask"
,
kwargs
.
get
(
"denoise_mask"
,
None
))
denoise_mask
=
kwargs
.
get
(
"concat_mask"
,
kwargs
.
get
(
"denoise_mask"
,
None
))
concat_latent_image
=
kwargs
.
get
(
"concat_latent_image"
,
None
)
concat_latent_image
=
kwargs
.
get
(
"concat_latent_image"
,
None
)
...
@@ -125,6 +125,7 @@ class BaseModel(torch.nn.Module):
...
@@ -125,6 +125,7 @@ class BaseModel(torch.nn.Module):
concat_latent_image
=
utils
.
resize_to_batch_size
(
concat_latent_image
,
noise
.
shape
[
0
])
concat_latent_image
=
utils
.
resize_to_batch_size
(
concat_latent_image
,
noise
.
shape
[
0
])
if
denoise_mask
is
not
None
:
if
len
(
denoise_mask
.
shape
)
==
len
(
noise
.
shape
):
if
len
(
denoise_mask
.
shape
)
==
len
(
noise
.
shape
):
denoise_mask
=
denoise_mask
[:,:
1
]
denoise_mask
=
denoise_mask
[:,:
1
]
...
@@ -133,16 +134,7 @@ class BaseModel(torch.nn.Module):
...
@@ -133,16 +134,7 @@ class BaseModel(torch.nn.Module):
denoise_mask
=
utils
.
common_upscale
(
denoise_mask
,
noise
.
shape
[
-
1
],
noise
.
shape
[
-
2
],
"bilinear"
,
"center"
)
denoise_mask
=
utils
.
common_upscale
(
denoise_mask
,
noise
.
shape
[
-
1
],
noise
.
shape
[
-
2
],
"bilinear"
,
"center"
)
denoise_mask
=
utils
.
resize_to_batch_size
(
denoise_mask
.
round
(),
noise
.
shape
[
0
])
denoise_mask
=
utils
.
resize_to_batch_size
(
denoise_mask
.
round
(),
noise
.
shape
[
0
])
def
blank_inpaint_image_like
(
latent_image
):
for
ck
in
self
.
concat_keys
:
blank_image
=
torch
.
ones_like
(
latent_image
)
# these are the values for "zero" in pixel space translated to latent space
blank_image
[:,
0
]
*=
0.8223
blank_image
[:,
1
]
*=
-
0.6876
blank_image
[:,
2
]
*=
0.6364
blank_image
[:,
3
]
*=
0.1380
return
blank_image
for
ck
in
concat_keys
:
if
denoise_mask
is
not
None
:
if
denoise_mask
is
not
None
:
if
ck
==
"mask"
:
if
ck
==
"mask"
:
cond_concat
.
append
(
denoise_mask
.
to
(
device
))
cond_concat
.
append
(
denoise_mask
.
to
(
device
))
...
@@ -152,7 +144,7 @@ class BaseModel(torch.nn.Module):
...
@@ -152,7 +144,7 @@ class BaseModel(torch.nn.Module):
if
ck
==
"mask"
:
if
ck
==
"mask"
:
cond_concat
.
append
(
torch
.
ones_like
(
noise
)[:,:
1
])
cond_concat
.
append
(
torch
.
ones_like
(
noise
)[:,:
1
])
elif
ck
==
"masked_image"
:
elif
ck
==
"masked_image"
:
cond_concat
.
append
(
blank_inpaint_image_like
(
noise
))
cond_concat
.
append
(
self
.
blank_inpaint_image_like
(
noise
))
data
=
torch
.
cat
(
cond_concat
,
dim
=
1
)
data
=
torch
.
cat
(
cond_concat
,
dim
=
1
)
out
[
'c_concat'
]
=
comfy
.
conds
.
CONDNoiseShape
(
data
)
out
[
'c_concat'
]
=
comfy
.
conds
.
CONDNoiseShape
(
data
)
...
@@ -221,7 +213,16 @@ class BaseModel(torch.nn.Module):
...
@@ -221,7 +213,16 @@ class BaseModel(torch.nn.Module):
return
unet_state_dict
return
unet_state_dict
def
set_inpaint
(
self
):
def
set_inpaint
(
self
):
self
.
inpaint_model
=
True
self
.
concat_keys
=
(
"mask"
,
"masked_image"
)
def
blank_inpaint_image_like
(
latent_image
):
blank_image
=
torch
.
ones_like
(
latent_image
)
# these are the values for "zero" in pixel space translated to latent space
blank_image
[:,
0
]
*=
0.8223
blank_image
[:,
1
]
*=
-
0.6876
blank_image
[:,
2
]
*=
0.6364
blank_image
[:,
3
]
*=
0.1380
return
blank_image
self
.
blank_inpaint_image_like
=
blank_inpaint_image_like
def
memory_required
(
self
,
input_shape
):
def
memory_required
(
self
,
input_shape
):
if
comfy
.
model_management
.
xformers_enabled
()
or
comfy
.
model_management
.
pytorch_attention_flash_attention
():
if
comfy
.
model_management
.
xformers_enabled
()
or
comfy
.
model_management
.
pytorch_attention_flash_attention
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment