Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
94a5a67c
"src/git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "86eafdf5e417eaa6b03f06e388e527d32e90494b"
Commit
94a5a67c
authored
Mar 29, 2024
by
comfyanonymous
Browse files
Cleanup to support different types of inpaint models.
parent
9bf6061d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
21 deletions
+22
-21
comfy/model_base.py
comfy/model_base.py
+22
-21
No files found.
comfy/model_base.py
View file @
94a5a67c
...
...
@@ -66,7 +66,8 @@ class BaseModel(torch.nn.Module):
self
.
adm_channels
=
unet_config
.
get
(
"adm_in_channels"
,
None
)
if
self
.
adm_channels
is
None
:
self
.
adm_channels
=
0
self
.
inpaint_model
=
False
self
.
concat_keys
=
()
logging
.
info
(
"model_type {}"
.
format
(
model_type
.
name
))
logging
.
debug
(
"adm {}"
.
format
(
self
.
adm_channels
))
...
...
@@ -107,8 +108,7 @@ class BaseModel(torch.nn.Module):
def
extra_conds
(
self
,
**
kwargs
):
out
=
{}
if
self
.
inpaint_model
:
concat_keys
=
(
"mask"
,
"masked_image"
)
if
len
(
self
.
concat_keys
)
>
0
:
cond_concat
=
[]
denoise_mask
=
kwargs
.
get
(
"concat_mask"
,
kwargs
.
get
(
"denoise_mask"
,
None
))
concat_latent_image
=
kwargs
.
get
(
"concat_latent_image"
,
None
)
...
...
@@ -125,24 +125,16 @@ class BaseModel(torch.nn.Module):
concat_latent_image
=
utils
.
resize_to_batch_size
(
concat_latent_image
,
noise
.
shape
[
0
])
if
len
(
denoise_mask
.
shape
)
==
len
(
noise
.
shape
):
denoise_mask
=
denoise_mask
[:,:
1
]
denoise_mask
=
denoise_mask
.
reshape
((
-
1
,
1
,
denoise_mask
.
shape
[
-
2
],
denoise_mask
.
shape
[
-
1
]))
if
denoise_mask
.
shape
[
-
2
:]
!=
noise
.
shape
[
-
2
:]:
denoise_mask
=
utils
.
common_upscale
(
denoise_mask
,
noise
.
shape
[
-
1
],
noise
.
shape
[
-
2
],
"bilinear"
,
"center"
)
denoise_mask
=
utils
.
resize_to_batch_size
(
denoise_mask
.
round
(),
noise
.
shape
[
0
])
if
denoise_mask
is
not
None
:
if
len
(
denoise_mask
.
shape
)
==
len
(
noise
.
shape
):
denoise_mask
=
denoise_mask
[:,:
1
]
def
blank_inpaint_image_like
(
latent_image
):
blank_image
=
torch
.
ones_like
(
latent_image
)
# these are the values for "zero" in pixel space translated to latent space
blank_image
[:,
0
]
*=
0.8223
blank_image
[:,
1
]
*=
-
0.6876
blank_image
[:,
2
]
*=
0.6364
blank_image
[:,
3
]
*=
0.1380
return
blank_image
denoise_mask
=
denoise_mask
.
reshape
((
-
1
,
1
,
denoise_mask
.
shape
[
-
2
],
denoise_mask
.
shape
[
-
1
]))
if
denoise_mask
.
shape
[
-
2
:]
!=
noise
.
shape
[
-
2
:]:
denoise_mask
=
utils
.
common_upscale
(
denoise_mask
,
noise
.
shape
[
-
1
],
noise
.
shape
[
-
2
],
"bilinear"
,
"center"
)
denoise_mask
=
utils
.
resize_to_batch_size
(
denoise_mask
.
round
(),
noise
.
shape
[
0
])
for
ck
in
concat_keys
:
for
ck
in
self
.
concat_keys
:
if
denoise_mask
is
not
None
:
if
ck
==
"mask"
:
cond_concat
.
append
(
denoise_mask
.
to
(
device
))
...
...
@@ -152,7 +144,7 @@ class BaseModel(torch.nn.Module):
if
ck
==
"mask"
:
cond_concat
.
append
(
torch
.
ones_like
(
noise
)[:,:
1
])
elif
ck
==
"masked_image"
:
cond_concat
.
append
(
blank_inpaint_image_like
(
noise
))
cond_concat
.
append
(
self
.
blank_inpaint_image_like
(
noise
))
data
=
torch
.
cat
(
cond_concat
,
dim
=
1
)
out
[
'c_concat'
]
=
comfy
.
conds
.
CONDNoiseShape
(
data
)
...
...
@@ -221,7 +213,16 @@ class BaseModel(torch.nn.Module):
return
unet_state_dict
def
set_inpaint
(
self
):
self
.
inpaint_model
=
True
self
.
concat_keys
=
(
"mask"
,
"masked_image"
)
def
blank_inpaint_image_like
(
latent_image
):
blank_image
=
torch
.
ones_like
(
latent_image
)
# these are the values for "zero" in pixel space translated to latent space
blank_image
[:,
0
]
*=
0.8223
blank_image
[:,
1
]
*=
-
0.6876
blank_image
[:,
2
]
*=
0.6364
blank_image
[:,
3
]
*=
0.1380
return
blank_image
self
.
blank_inpaint_image_like
=
blank_inpaint_image_like
def
memory_required
(
self
,
input_shape
):
if
comfy
.
model_management
.
xformers_enabled
()
or
comfy
.
model_management
.
pytorch_attention_flash_attention
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment