Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
d2cec6cd
"vscode:/vscode.git/clone" did not exist on "968078b149936c8e297c48ddb97673f16e467b83"
Commit
d2cec6cd
authored
Sep 25, 2023
by
comfyanonymous
Browse files
Make mask functions work with batches of masks and images.
parent
046b4fe0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
17 deletions
+19
-17
comfy_extras/nodes_mask.py
comfy_extras/nodes_mask.py
+19
-17
No files found.
comfy_extras/nodes_mask.py
View file @
d2cec6cd
...
@@ -144,8 +144,8 @@ class ImageColorToMask:
...
@@ -144,8 +144,8 @@ class ImageColorToMask:
FUNCTION
=
"image_to_mask"
FUNCTION
=
"image_to_mask"
def
image_to_mask
(
self
,
image
,
color
):
def
image_to_mask
(
self
,
image
,
color
):
temp
=
(
torch
.
clamp
(
image
[
0
]
,
0
,
1.0
)
*
255.0
).
round
().
to
(
torch
.
int
)
temp
=
(
torch
.
clamp
(
image
,
0
,
1.0
)
*
255.0
).
round
().
to
(
torch
.
int
)
temp
=
torch
.
bitwise_left_shift
(
temp
[:,:,
0
],
16
)
+
torch
.
bitwise_left_shift
(
temp
[:,:,
1
],
8
)
+
temp
[:,:,
2
]
temp
=
torch
.
bitwise_left_shift
(
temp
[:,:,
:,
0
],
16
)
+
torch
.
bitwise_left_shift
(
temp
[:,:,
:,
1
],
8
)
+
temp
[:,:,
:,
2
]
mask
=
torch
.
where
(
temp
==
color
,
255
,
0
).
float
()
mask
=
torch
.
where
(
temp
==
color
,
255
,
0
).
float
()
return
(
mask
,)
return
(
mask
,)
...
@@ -167,7 +167,7 @@ class SolidMask:
...
@@ -167,7 +167,7 @@ class SolidMask:
FUNCTION
=
"solid"
FUNCTION
=
"solid"
def
solid
(
self
,
value
,
width
,
height
):
def
solid
(
self
,
value
,
width
,
height
):
out
=
torch
.
full
((
height
,
width
),
value
,
dtype
=
torch
.
float32
,
device
=
"cpu"
)
out
=
torch
.
full
((
1
,
height
,
width
),
value
,
dtype
=
torch
.
float32
,
device
=
"cpu"
)
return
(
out
,)
return
(
out
,)
class
InvertMask
:
class
InvertMask
:
...
@@ -209,7 +209,8 @@ class CropMask:
...
@@ -209,7 +209,8 @@ class CropMask:
FUNCTION
=
"crop"
FUNCTION
=
"crop"
def
crop
(
self
,
mask
,
x
,
y
,
width
,
height
):
def
crop
(
self
,
mask
,
x
,
y
,
width
,
height
):
out
=
mask
[
y
:
y
+
height
,
x
:
x
+
width
]
mask
=
mask
.
reshape
((
-
1
,
mask
.
shape
[
-
2
],
mask
.
shape
[
-
1
]))
out
=
mask
[:,
y
:
y
+
height
,
x
:
x
+
width
]
return
(
out
,)
return
(
out
,)
class
MaskComposite
:
class
MaskComposite
:
...
@@ -232,27 +233,28 @@ class MaskComposite:
...
@@ -232,27 +233,28 @@ class MaskComposite:
FUNCTION
=
"combine"
FUNCTION
=
"combine"
def
combine
(
self
,
destination
,
source
,
x
,
y
,
operation
):
def
combine
(
self
,
destination
,
source
,
x
,
y
,
operation
):
output
=
destination
.
clone
()
output
=
destination
.
reshape
((
-
1
,
destination
.
shape
[
-
2
],
destination
.
shape
[
-
1
])).
clone
()
source
=
source
.
reshape
((
-
1
,
source
.
shape
[
-
2
],
source
.
shape
[
-
1
]))
left
,
top
=
(
x
,
y
,)
left
,
top
=
(
x
,
y
,)
right
,
bottom
=
(
min
(
left
+
source
.
shape
[
1
],
destination
.
shape
[
1
]),
min
(
top
+
source
.
shape
[
0
],
destination
.
shape
[
0
]))
right
,
bottom
=
(
min
(
left
+
source
.
shape
[
-
1
],
destination
.
shape
[
-
1
]),
min
(
top
+
source
.
shape
[
-
2
],
destination
.
shape
[
-
2
]))
visible_width
,
visible_height
=
(
right
-
left
,
bottom
-
top
,)
visible_width
,
visible_height
=
(
right
-
left
,
bottom
-
top
,)
source_portion
=
source
[:
visible_height
,
:
visible_width
]
source_portion
=
source
[:
visible_height
,
:
visible_width
]
destination_portion
=
destination
[
top
:
bottom
,
left
:
right
]
destination_portion
=
destination
[
top
:
bottom
,
left
:
right
]
if
operation
==
"multiply"
:
if
operation
==
"multiply"
:
output
[
top
:
bottom
,
left
:
right
]
=
destination_portion
*
source_portion
output
[
:,
top
:
bottom
,
left
:
right
]
=
destination_portion
*
source_portion
elif
operation
==
"add"
:
elif
operation
==
"add"
:
output
[
top
:
bottom
,
left
:
right
]
=
destination_portion
+
source_portion
output
[
:,
top
:
bottom
,
left
:
right
]
=
destination_portion
+
source_portion
elif
operation
==
"subtract"
:
elif
operation
==
"subtract"
:
output
[
top
:
bottom
,
left
:
right
]
=
destination_portion
-
source_portion
output
[
:,
top
:
bottom
,
left
:
right
]
=
destination_portion
-
source_portion
elif
operation
==
"and"
:
elif
operation
==
"and"
:
output
[
top
:
bottom
,
left
:
right
]
=
torch
.
bitwise_and
(
destination_portion
.
round
().
bool
(),
source_portion
.
round
().
bool
()).
float
()
output
[
:,
top
:
bottom
,
left
:
right
]
=
torch
.
bitwise_and
(
destination_portion
.
round
().
bool
(),
source_portion
.
round
().
bool
()).
float
()
elif
operation
==
"or"
:
elif
operation
==
"or"
:
output
[
top
:
bottom
,
left
:
right
]
=
torch
.
bitwise_or
(
destination_portion
.
round
().
bool
(),
source_portion
.
round
().
bool
()).
float
()
output
[
:,
top
:
bottom
,
left
:
right
]
=
torch
.
bitwise_or
(
destination_portion
.
round
().
bool
(),
source_portion
.
round
().
bool
()).
float
()
elif
operation
==
"xor"
:
elif
operation
==
"xor"
:
output
[
top
:
bottom
,
left
:
right
]
=
torch
.
bitwise_xor
(
destination_portion
.
round
().
bool
(),
source_portion
.
round
().
bool
()).
float
()
output
[
:,
top
:
bottom
,
left
:
right
]
=
torch
.
bitwise_xor
(
destination_portion
.
round
().
bool
(),
source_portion
.
round
().
bool
()).
float
()
output
=
torch
.
clamp
(
output
,
0.0
,
1.0
)
output
=
torch
.
clamp
(
output
,
0.0
,
1.0
)
...
@@ -278,7 +280,7 @@ class FeatherMask:
...
@@ -278,7 +280,7 @@ class FeatherMask:
FUNCTION
=
"feather"
FUNCTION
=
"feather"
def
feather
(
self
,
mask
,
left
,
top
,
right
,
bottom
):
def
feather
(
self
,
mask
,
left
,
top
,
right
,
bottom
):
output
=
mask
.
clone
()
output
=
mask
.
reshape
((
-
1
,
mask
.
shape
[
-
2
],
mask
.
shape
[
-
1
])).
clone
()
left
=
min
(
left
,
output
.
shape
[
1
])
left
=
min
(
left
,
output
.
shape
[
1
])
right
=
min
(
right
,
output
.
shape
[
1
])
right
=
min
(
right
,
output
.
shape
[
1
])
...
@@ -287,19 +289,19 @@ class FeatherMask:
...
@@ -287,19 +289,19 @@ class FeatherMask:
for
x
in
range
(
left
):
for
x
in
range
(
left
):
feather_rate
=
(
x
+
1.0
)
/
left
feather_rate
=
(
x
+
1.0
)
/
left
output
[:,
x
]
*=
feather_rate
output
[:,
:,
x
]
*=
feather_rate
for
x
in
range
(
right
):
for
x
in
range
(
right
):
feather_rate
=
(
x
+
1
)
/
right
feather_rate
=
(
x
+
1
)
/
right
output
[:,
-
x
]
*=
feather_rate
output
[:,
:,
-
x
]
*=
feather_rate
for
y
in
range
(
top
):
for
y
in
range
(
top
):
feather_rate
=
(
y
+
1
)
/
top
feather_rate
=
(
y
+
1
)
/
top
output
[
y
,
:]
*=
feather_rate
output
[
:,
y
,
:]
*=
feather_rate
for
y
in
range
(
bottom
):
for
y
in
range
(
bottom
):
feather_rate
=
(
y
+
1
)
/
bottom
feather_rate
=
(
y
+
1
)
/
bottom
output
[
-
y
,
:]
*=
feather_rate
output
[
:,
-
y
,
:]
*=
feather_rate
return
(
output
,)
return
(
output
,)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment