Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
b4f434ee
Commit
b4f434ee
authored
May 30, 2023
by
space-nuko
Browse files
Preview sampled images with TAESD
parent
2ec980bb
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
323 additions
and
51 deletions
+323
-51
comfy/taesd/taesd.py
comfy/taesd/taesd.py
+65
-0
comfy/utils.py
comfy/utils.py
+2
-2
main.py
main.py
+5
-1
nodes.py
nodes.py
+111
-8
server.py
server.py
+32
-5
web/extensions/core/colorPalette.js
web/extensions/core/colorPalette.js
+1
-0
web/scripts/api.js
web/scripts/api.js
+56
-28
web/scripts/app.js
web/scripts/app.js
+51
-7
No files found.
comfy/taesd/taesd.py
0 → 100644
View file @
b4f434ee
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Stable Diffusion
(DNN for encoding / decoding SD's latent space)
"""
import
torch
import
torch.nn
as
nn
def
conv
(
n_in
,
n_out
,
**
kwargs
):
return
nn
.
Conv2d
(
n_in
,
n_out
,
3
,
padding
=
1
,
**
kwargs
)
class
Clamp
(
nn
.
Module
):
def
forward
(
self
,
x
):
return
torch
.
tanh
(
x
/
3
)
*
3
class
Block
(
nn
.
Module
):
def
__init__
(
self
,
n_in
,
n_out
):
super
().
__init__
()
self
.
conv
=
nn
.
Sequential
(
conv
(
n_in
,
n_out
),
nn
.
ReLU
(),
conv
(
n_out
,
n_out
),
nn
.
ReLU
(),
conv
(
n_out
,
n_out
))
self
.
skip
=
nn
.
Conv2d
(
n_in
,
n_out
,
1
,
bias
=
False
)
if
n_in
!=
n_out
else
nn
.
Identity
()
self
.
fuse
=
nn
.
ReLU
()
def
forward
(
self
,
x
):
return
self
.
fuse
(
self
.
conv
(
x
)
+
self
.
skip
(
x
))
def
Encoder
():
return
nn
.
Sequential
(
conv
(
3
,
64
),
Block
(
64
,
64
),
conv
(
64
,
64
,
stride
=
2
,
bias
=
False
),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
conv
(
64
,
64
,
stride
=
2
,
bias
=
False
),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
conv
(
64
,
64
,
stride
=
2
,
bias
=
False
),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
conv
(
64
,
4
),
)
def
Decoder
():
return
nn
.
Sequential
(
Clamp
(),
conv
(
4
,
64
),
nn
.
ReLU
(),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
nn
.
Upsample
(
scale_factor
=
2
),
conv
(
64
,
64
,
bias
=
False
),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
nn
.
Upsample
(
scale_factor
=
2
),
conv
(
64
,
64
,
bias
=
False
),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
nn
.
Upsample
(
scale_factor
=
2
),
conv
(
64
,
64
,
bias
=
False
),
Block
(
64
,
64
),
conv
(
64
,
3
),
)
class
TAESD
(
nn
.
Module
):
latent_magnitude
=
3
latent_shift
=
0.5
def
__init__
(
self
,
encoder_path
=
"taesd_encoder.pth"
,
decoder_path
=
"taesd_decoder.pth"
):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super
().
__init__
()
self
.
encoder
=
Encoder
()
self
.
decoder
=
Decoder
()
if
encoder_path
is
not
None
:
self
.
encoder
.
load_state_dict
(
torch
.
load
(
encoder_path
,
map_location
=
"cpu"
))
if
decoder_path
is
not
None
:
self
.
decoder
.
load_state_dict
(
torch
.
load
(
decoder_path
,
map_location
=
"cpu"
))
@
staticmethod
def
scale_latents
(
x
):
"""raw latents -> [0, 1]"""
return
x
.
div
(
2
*
TAESD
.
latent_magnitude
).
add
(
TAESD
.
latent_shift
).
clamp
(
0
,
1
)
@
staticmethod
def
unscale_latents
(
x
):
"""[0, 1] -> raw latents"""
return
x
.
sub
(
TAESD
.
latent_shift
).
mul
(
2
*
TAESD
.
latent_magnitude
)
comfy/utils.py
View file @
b4f434ee
...
@@ -197,14 +197,14 @@ class ProgressBar:
...
@@ -197,14 +197,14 @@ class ProgressBar:
self
.
current
=
0
self
.
current
=
0
self
.
hook
=
PROGRESS_BAR_HOOK
self
.
hook
=
PROGRESS_BAR_HOOK
def
update_absolute
(
self
,
value
,
total
=
None
):
def
update_absolute
(
self
,
value
,
total
=
None
,
preview
=
None
):
if
total
is
not
None
:
if
total
is
not
None
:
self
.
total
=
total
self
.
total
=
total
if
value
>
self
.
total
:
if
value
>
self
.
total
:
value
=
self
.
total
value
=
self
.
total
self
.
current
=
value
self
.
current
=
value
if
self
.
hook
is
not
None
:
if
self
.
hook
is
not
None
:
self
.
hook
(
self
.
current
,
self
.
total
)
self
.
hook
(
self
.
current
,
self
.
total
,
preview
)
def
update
(
self
,
value
):
def
update
(
self
,
value
):
self
.
update_absolute
(
self
.
current
+
value
)
self
.
update_absolute
(
self
.
current
+
value
)
main.py
View file @
b4f434ee
...
@@ -26,6 +26,7 @@ import yaml
...
@@ -26,6 +26,7 @@ import yaml
import
execution
import
execution
import
folder_paths
import
folder_paths
import
server
import
server
from
server
import
BinaryEventTypes
from
nodes
import
init_custom_nodes
from
nodes
import
init_custom_nodes
...
@@ -40,8 +41,11 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
...
@@ -40,8 +41,11 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await
asyncio
.
gather
(
server
.
start
(
address
,
port
,
verbose
,
call_on_start
),
server
.
publish_loop
())
await
asyncio
.
gather
(
server
.
start
(
address
,
port
,
verbose
,
call_on_start
),
server
.
publish_loop
())
def
hijack_progress
(
server
):
def
hijack_progress
(
server
):
def
hook
(
value
,
total
):
def
hook
(
value
,
total
,
preview_bytes_jpeg
):
server
.
send_sync
(
"progress"
,
{
"value"
:
value
,
"max"
:
total
},
server
.
client_id
)
server
.
send_sync
(
"progress"
,
{
"value"
:
value
,
"max"
:
total
},
server
.
client_id
)
if
preview_bytes_jpeg
is
not
None
:
server
.
send_sync
(
BinaryEventTypes
.
PREVIEW_IMAGE
,
preview_bytes_jpeg
,
server
.
client_id
)
pass
comfy
.
utils
.
set_progress_bar_global_hook
(
hook
)
comfy
.
utils
.
set_progress_bar_global_hook
(
hook
)
def
cleanup_temp
():
def
cleanup_temp
():
...
...
nodes.py
View file @
b4f434ee
...
@@ -7,6 +7,8 @@ import hashlib
...
@@ -7,6 +7,8 @@ import hashlib
import
traceback
import
traceback
import
math
import
math
import
time
import
time
import
struct
from
io
import
BytesIO
from
PIL
import
Image
,
ImageOps
from
PIL
import
Image
,
ImageOps
from
PIL.PngImagePlugin
import
PngInfo
from
PIL.PngImagePlugin
import
PngInfo
...
@@ -22,6 +24,7 @@ import comfy.samplers
...
@@ -22,6 +24,7 @@ import comfy.samplers
import
comfy.sample
import
comfy.sample
import
comfy.sd
import
comfy.sd
import
comfy.utils
import
comfy.utils
from
comfy.taesd.taesd
import
TAESD
import
comfy.clip_vision
import
comfy.clip_vision
...
@@ -38,6 +41,7 @@ def interrupt_processing(value=True):
...
@@ -38,6 +41,7 @@ def interrupt_processing(value=True):
comfy
.
model_management
.
interrupt_current_processing
(
value
)
comfy
.
model_management
.
interrupt_current_processing
(
value
)
MAX_RESOLUTION
=
8192
MAX_RESOLUTION
=
8192
MAX_PREVIEW_RESOLUTION
=
512
class
CLIPTextEncode
:
class
CLIPTextEncode
:
@
classmethod
@
classmethod
...
@@ -171,6 +175,21 @@ class VAEDecodeTiled:
...
@@ -171,6 +175,21 @@ class VAEDecodeTiled:
def
decode
(
self
,
vae
,
samples
):
def
decode
(
self
,
vae
,
samples
):
return
(
vae
.
decode_tiled
(
samples
[
"samples"
]),
)
return
(
vae
.
decode_tiled
(
samples
[
"samples"
]),
)
class
TAESDDecode
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"samples"
:
(
"LATENT"
,
),
"taesd"
:
(
"TAESD"
,
)}}
RETURN_TYPES
=
(
"IMAGE"
,)
FUNCTION
=
"decode"
CATEGORY
=
"latent"
def
decode
(
self
,
taesd
,
samples
):
device
=
comfy
.
model_management
.
get_torch_device
()
# [B, C, H, W] -> [B, H, W, C]
pixels
=
taesd
.
decoder
(
samples
[
"samples"
].
to
(
device
)).
permute
(
0
,
2
,
3
,
1
).
detach
().
clamp
(
0
,
1
)
return
(
pixels
,
)
class
VAEEncode
:
class
VAEEncode
:
@
classmethod
@
classmethod
def
INPUT_TYPES
(
s
):
def
INPUT_TYPES
(
s
):
...
@@ -248,6 +267,21 @@ class VAEEncodeForInpaint:
...
@@ -248,6 +267,21 @@ class VAEEncodeForInpaint:
return
({
"samples"
:
t
,
"noise_mask"
:
(
mask_erosion
[:,:,:
x
,:
y
].
round
())},
)
return
({
"samples"
:
t
,
"noise_mask"
:
(
mask_erosion
[:,:,:
x
,:
y
].
round
())},
)
class
TAESDEncode
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"pixels"
:
(
"IMAGE"
,
),
"taesd"
:
(
"TAESD"
,
)}}
RETURN_TYPES
=
(
"LATENT"
,)
FUNCTION
=
"encode"
CATEGORY
=
"latent"
def
encode
(
self
,
taesd
,
pixels
):
device
=
comfy
.
model_management
.
get_torch_device
()
# [B, H, W, C] -> [B, C, H, W]
samples
=
taesd
.
encoder
(
pixels
.
permute
(
0
,
3
,
1
,
2
).
to
(
device
)).
to
(
device
)
return
({
"samples"
:
samples
},
)
class
SaveLatent
:
class
SaveLatent
:
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -464,6 +498,26 @@ class VAELoader:
...
@@ -464,6 +498,26 @@ class VAELoader:
vae
=
comfy
.
sd
.
VAE
(
ckpt_path
=
vae_path
)
vae
=
comfy
.
sd
.
VAE
(
ckpt_path
=
vae_path
)
return
(
vae
,)
return
(
vae
,)
class
TAESDLoader
:
@
classmethod
def
INPUT_TYPES
(
s
):
model_list
=
folder_paths
.
get_filename_list
(
"taesd"
)
return
{
"required"
:
{
"encoder_name"
:
(
model_list
,
{
"default"
:
"taesd_encoder.pth"
}),
"decoder_name"
:
(
model_list
,
{
"default"
:
"taesd_decoder.pth"
})
}}
RETURN_TYPES
=
(
"TAESD"
,)
FUNCTION
=
"load_taesd"
CATEGORY
=
"loaders"
def
load_taesd
(
self
,
encoder_name
,
decoder_name
):
device
=
comfy
.
model_management
.
get_torch_device
()
encoder_path
=
folder_paths
.
get_full_path
(
"taesd"
,
encoder_name
)
decoder_path
=
folder_paths
.
get_full_path
(
"taesd"
,
decoder_name
)
taesd
=
TAESD
(
encoder_path
,
decoder_path
).
to
(
device
)
return
(
taesd
,)
class
ControlNetLoader
:
class
ControlNetLoader
:
@
classmethod
@
classmethod
def
INPUT_TYPES
(
s
):
def
INPUT_TYPES
(
s
):
...
@@ -931,7 +985,37 @@ class SetLatentNoiseMask:
...
@@ -931,7 +985,37 @@ class SetLatentNoiseMask:
s
[
"noise_mask"
]
=
mask
.
reshape
((
-
1
,
1
,
mask
.
shape
[
-
2
],
mask
.
shape
[
-
1
]))
s
[
"noise_mask"
]
=
mask
.
reshape
((
-
1
,
1
,
mask
.
shape
[
-
2
],
mask
.
shape
[
-
1
]))
return
(
s
,)
return
(
s
,)
def
common_ksampler
(
model
,
seed
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent
,
denoise
=
1.0
,
disable_noise
=
False
,
start_step
=
None
,
last_step
=
None
,
force_full_denoise
=
False
):
def
decode_latent_to_preview_image
(
taesd
,
device
,
preview_format
,
x0
):
x_sample
=
taesd
.
decoder
(
x0
.
to
(
device
))[
0
].
detach
()
x_sample
=
taesd
.
unscale_latents
(
x_sample
)
# returns value in [-2, 2]
x_sample
=
x_sample
*
0.5
x_sample
=
torch
.
clamp
((
x_sample
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
x_sample
=
255.
*
np
.
moveaxis
(
x_sample
.
cpu
().
numpy
(),
0
,
2
)
x_sample
=
x_sample
.
astype
(
np
.
uint8
)
preview_image
=
Image
.
fromarray
(
x_sample
)
if
preview_image
.
size
[
0
]
>
MAX_PREVIEW_RESOLUTION
or
preview_image
.
size
[
1
]
>
MAX_PREVIEW_RESOLUTION
:
preview_image
.
thumbnail
((
MAX_PREVIEW_RESOLUTION
,
MAX_PREVIEW_RESOLUTION
),
Image
.
ANTIALIAS
)
preview_type
=
1
if
preview_format
==
"JPEG"
:
preview_type
=
1
elif
preview_format
==
"PNG"
:
preview_type
=
2
bytesIO
=
BytesIO
()
header
=
struct
.
pack
(
">I"
,
preview_type
)
bytesIO
.
write
(
header
)
preview_image
.
save
(
bytesIO
,
format
=
preview_format
)
preview_bytes
=
bytesIO
.
getvalue
()
return
preview_bytes
def
common_ksampler
(
model
,
seed
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent
,
denoise
=
1.0
,
disable_noise
=
False
,
start_step
=
None
,
last_step
=
None
,
force_full_denoise
=
False
,
taesd
=
None
):
device
=
comfy
.
model_management
.
get_torch_device
()
device
=
comfy
.
model_management
.
get_torch_device
()
latent_image
=
latent
[
"samples"
]
latent_image
=
latent
[
"samples"
]
...
@@ -945,9 +1029,16 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
...
@@ -945,9 +1029,16 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if
"noise_mask"
in
latent
:
if
"noise_mask"
in
latent
:
noise_mask
=
latent
[
"noise_mask"
]
noise_mask
=
latent
[
"noise_mask"
]
preview_format
=
"JPEG"
if
preview_format
not
in
[
"JPEG"
,
"PNG"
]:
preview_format
=
"JPEG"
pbar
=
comfy
.
utils
.
ProgressBar
(
steps
)
pbar
=
comfy
.
utils
.
ProgressBar
(
steps
)
def
callback
(
step
,
x0
,
x
,
total_steps
):
def
callback
(
step
,
x0
,
x
,
total_steps
):
pbar
.
update_absolute
(
step
+
1
,
total_steps
)
preview_bytes
=
None
if
taesd
:
preview_bytes
=
decode_latent_to_preview_image
(
taesd
,
device
,
preview_format
,
x0
)
pbar
.
update_absolute
(
step
+
1
,
total_steps
,
preview_bytes
)
samples
=
comfy
.
sample
.
sample
(
model
,
noise
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
samples
=
comfy
.
sample
.
sample
(
model
,
noise
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
denoise
=
denoise
,
disable_noise
=
disable_noise
,
start_step
=
start_step
,
last_step
=
last_step
,
denoise
=
denoise
,
disable_noise
=
disable_noise
,
start_step
=
start_step
,
last_step
=
last_step
,
...
@@ -970,15 +1061,18 @@ class KSampler:
...
@@ -970,15 +1061,18 @@ class KSampler:
"negative"
:
(
"CONDITIONING"
,
),
"negative"
:
(
"CONDITIONING"
,
),
"latent_image"
:
(
"LATENT"
,
),
"latent_image"
:
(
"LATENT"
,
),
"denoise"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.01
}),
"denoise"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.01
}),
}}
},
"optional"
:
{
"taesd"
:
(
"TAESD"
,)
}}
RETURN_TYPES
=
(
"LATENT"
,)
RETURN_TYPES
=
(
"LATENT"
,)
FUNCTION
=
"sample"
FUNCTION
=
"sample"
CATEGORY
=
"sampling"
CATEGORY
=
"sampling"
def
sample
(
self
,
model
,
seed
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
denoise
=
1.0
):
def
sample
(
self
,
model
,
seed
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
denoise
=
1.0
,
taesd
=
None
):
return
common_ksampler
(
model
,
seed
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
denoise
=
denoise
)
return
common_ksampler
(
model
,
seed
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
denoise
=
denoise
,
taesd
=
taesd
)
class
KSamplerAdvanced
:
class
KSamplerAdvanced
:
@
classmethod
@
classmethod
...
@@ -997,21 +1091,24 @@ class KSamplerAdvanced:
...
@@ -997,21 +1091,24 @@ class KSamplerAdvanced:
"start_at_step"
:
(
"INT"
,
{
"default"
:
0
,
"min"
:
0
,
"max"
:
10000
}),
"start_at_step"
:
(
"INT"
,
{
"default"
:
0
,
"min"
:
0
,
"max"
:
10000
}),
"end_at_step"
:
(
"INT"
,
{
"default"
:
10000
,
"min"
:
0
,
"max"
:
10000
}),
"end_at_step"
:
(
"INT"
,
{
"default"
:
10000
,
"min"
:
0
,
"max"
:
10000
}),
"return_with_leftover_noise"
:
([
"disable"
,
"enable"
],
),
"return_with_leftover_noise"
:
([
"disable"
,
"enable"
],
),
}}
},
"optional"
:
{
"taesd"
:
(
"TAESD"
,)
}}
RETURN_TYPES
=
(
"LATENT"
,)
RETURN_TYPES
=
(
"LATENT"
,)
FUNCTION
=
"sample"
FUNCTION
=
"sample"
CATEGORY
=
"sampling"
CATEGORY
=
"sampling"
def
sample
(
self
,
model
,
add_noise
,
noise_seed
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
start_at_step
,
end_at_step
,
return_with_leftover_noise
,
denoise
=
1.0
):
def
sample
(
self
,
model
,
add_noise
,
noise_seed
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
start_at_step
,
end_at_step
,
return_with_leftover_noise
,
denoise
=
1.0
,
taesd
=
None
):
force_full_denoise
=
True
force_full_denoise
=
True
if
return_with_leftover_noise
==
"enable"
:
if
return_with_leftover_noise
==
"enable"
:
force_full_denoise
=
False
force_full_denoise
=
False
disable_noise
=
False
disable_noise
=
False
if
add_noise
==
"disable"
:
if
add_noise
==
"disable"
:
disable_noise
=
True
disable_noise
=
True
return
common_ksampler
(
model
,
noise_seed
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
denoise
=
denoise
,
disable_noise
=
disable_noise
,
start_step
=
start_at_step
,
last_step
=
end_at_step
,
force_full_denoise
=
force_full_denoise
)
return
common_ksampler
(
model
,
noise_seed
,
steps
,
cfg
,
sampler_name
,
scheduler
,
positive
,
negative
,
latent_image
,
denoise
=
denoise
,
disable_noise
=
disable_noise
,
start_step
=
start_at_step
,
last_step
=
end_at_step
,
force_full_denoise
=
force_full_denoise
,
taesd
=
taesd
)
class
SaveImage
:
class
SaveImage
:
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -1270,6 +1367,9 @@ NODE_CLASS_MAPPINGS = {
...
@@ -1270,6 +1367,9 @@ NODE_CLASS_MAPPINGS = {
"VAEEncode"
:
VAEEncode
,
"VAEEncode"
:
VAEEncode
,
"VAEEncodeForInpaint"
:
VAEEncodeForInpaint
,
"VAEEncodeForInpaint"
:
VAEEncodeForInpaint
,
"VAELoader"
:
VAELoader
,
"VAELoader"
:
VAELoader
,
"TAESDDecode"
:
TAESDDecode
,
"TAESDEncode"
:
TAESDEncode
,
"TAESDLoader"
:
TAESDLoader
,
"EmptyLatentImage"
:
EmptyLatentImage
,
"EmptyLatentImage"
:
EmptyLatentImage
,
"LatentUpscale"
:
LatentUpscale
,
"LatentUpscale"
:
LatentUpscale
,
"LatentUpscaleBy"
:
LatentUpscaleBy
,
"LatentUpscaleBy"
:
LatentUpscaleBy
,
...
@@ -1324,6 +1424,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
...
@@ -1324,6 +1424,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointLoader"
:
"Load Checkpoint (With Config)"
,
"CheckpointLoader"
:
"Load Checkpoint (With Config)"
,
"CheckpointLoaderSimple"
:
"Load Checkpoint"
,
"CheckpointLoaderSimple"
:
"Load Checkpoint"
,
"VAELoader"
:
"Load VAE"
,
"VAELoader"
:
"Load VAE"
,
"TAESDLoader"
:
"Load TAESD"
,
"LoraLoader"
:
"Load LoRA"
,
"LoraLoader"
:
"Load LoRA"
,
"CLIPLoader"
:
"Load CLIP"
,
"CLIPLoader"
:
"Load CLIP"
,
"ControlNetLoader"
:
"Load ControlNet Model"
,
"ControlNetLoader"
:
"Load ControlNet Model"
,
...
@@ -1346,6 +1447,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
...
@@ -1346,6 +1447,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"SetLatentNoiseMask"
:
"Set Latent Noise Mask"
,
"SetLatentNoiseMask"
:
"Set Latent Noise Mask"
,
"VAEDecode"
:
"VAE Decode"
,
"VAEDecode"
:
"VAE Decode"
,
"VAEEncode"
:
"VAE Encode"
,
"VAEEncode"
:
"VAE Encode"
,
"TAESDDecode"
:
"TAESD Decode"
,
"TAESDEncode"
:
"TAESD Encode"
,
"LatentRotate"
:
"Rotate Latent"
,
"LatentRotate"
:
"Rotate Latent"
,
"LatentFlip"
:
"Flip Latent"
,
"LatentFlip"
:
"Flip Latent"
,
"LatentCrop"
:
"Crop Latent"
,
"LatentCrop"
:
"Crop Latent"
,
...
...
server.py
View file @
b4f434ee
...
@@ -7,6 +7,7 @@ import execution
...
@@ -7,6 +7,7 @@ import execution
import
uuid
import
uuid
import
json
import
json
import
glob
import
glob
import
struct
from
PIL
import
Image
from
PIL
import
Image
from
io
import
BytesIO
from
io
import
BytesIO
...
@@ -25,6 +26,11 @@ from comfy.cli_args import args
...
@@ -25,6 +26,11 @@ from comfy.cli_args import args
import
comfy.utils
import
comfy.utils
import
comfy.model_management
import
comfy.model_management
class
BinaryEventTypes
:
PREVIEW_IMAGE
=
1
@
web
.
middleware
@
web
.
middleware
async
def
cache_control
(
request
:
web
.
Request
,
handler
):
async
def
cache_control
(
request
:
web
.
Request
,
handler
):
response
:
web
.
Response
=
await
handler
(
request
)
response
:
web
.
Response
=
await
handler
(
request
)
...
@@ -457,16 +463,37 @@ class PromptServer():
...
@@ -457,16 +463,37 @@ class PromptServer():
return
prompt_info
return
prompt_info
async
def
send
(
self
,
event
,
data
,
sid
=
None
):
async
def
send
(
self
,
event
,
data
,
sid
=
None
):
if
isinstance
(
data
,
(
bytes
,
bytearray
)):
await
self
.
send_bytes
(
event
,
data
,
sid
)
else
:
await
self
.
send_json
(
event
,
data
,
sid
)
def
encode_bytes
(
self
,
event
,
data
):
if
not
isinstance
(
event
,
int
):
raise
RuntimeError
(
f
"Binary event types must be integers, got
{
event
}
"
)
packed
=
struct
.
pack
(
">I"
,
event
)
message
=
bytearray
(
packed
)
message
.
extend
(
data
)
return
message
async
def
send_bytes
(
self
,
event
,
data
,
sid
=
None
):
message
=
self
.
encode_bytes
(
event
,
data
)
if
sid
is
None
:
for
ws
in
self
.
sockets
.
values
():
await
ws
.
send_bytes
(
message
)
elif
sid
in
self
.
sockets
:
await
self
.
sockets
[
sid
].
send_bytes
(
message
)
async
def
send_json
(
self
,
event
,
data
,
sid
=
None
):
message
=
{
"type"
:
event
,
"data"
:
data
}
message
=
{
"type"
:
event
,
"data"
:
data
}
if
isinstance
(
message
,
str
)
==
False
:
message
=
json
.
dumps
(
message
)
if
sid
is
None
:
if
sid
is
None
:
for
ws
in
self
.
sockets
.
values
():
for
ws
in
self
.
sockets
.
values
():
await
ws
.
send_
str
(
message
)
await
ws
.
send_
json
(
message
)
elif
sid
in
self
.
sockets
:
elif
sid
in
self
.
sockets
:
await
self
.
sockets
[
sid
].
send_
str
(
message
)
await
self
.
sockets
[
sid
].
send_
json
(
message
)
def
send_sync
(
self
,
event
,
data
,
sid
=
None
):
def
send_sync
(
self
,
event
,
data
,
sid
=
None
):
self
.
loop
.
call_soon_threadsafe
(
self
.
loop
.
call_soon_threadsafe
(
...
...
web/extensions/core/colorPalette.js
View file @
b4f434ee
...
@@ -21,6 +21,7 @@ const colorPalettes = {
...
@@ -21,6 +21,7 @@ const colorPalettes = {
"
MODEL
"
:
"
#B39DDB
"
,
// light lavender-purple
"
MODEL
"
:
"
#B39DDB
"
,
// light lavender-purple
"
STYLE_MODEL
"
:
"
#C2FFAE
"
,
// light green-yellow
"
STYLE_MODEL
"
:
"
#C2FFAE
"
,
// light green-yellow
"
VAE
"
:
"
#FF6E6E
"
,
// bright red
"
VAE
"
:
"
#FF6E6E
"
,
// bright red
"
TAESD
"
:
"
#DCC274
"
,
// cheesecake
},
},
"
litegraph_base
"
:
{
"
litegraph_base
"
:
{
"
NODE_TITLE_COLOR
"
:
"
#999
"
,
"
NODE_TITLE_COLOR
"
:
"
#999
"
,
...
...
web/scripts/api.js
View file @
b4f434ee
...
@@ -42,6 +42,7 @@ class ComfyApi extends EventTarget {
...
@@ -42,6 +42,7 @@ class ComfyApi extends EventTarget {
this
.
socket
=
new
WebSocket
(
this
.
socket
=
new
WebSocket
(
`ws
${
window
.
location
.
protocol
===
"
https:
"
?
"
s
"
:
""
}
://
${
location
.
host
}
/ws
${
existingSession
}
`
`ws
${
window
.
location
.
protocol
===
"
https:
"
?
"
s
"
:
""
}
://
${
location
.
host
}
/ws
${
existingSession
}
`
);
);
this
.
socket
.
binaryType
=
"
arraybuffer
"
;
this
.
socket
.
addEventListener
(
"
open
"
,
()
=>
{
this
.
socket
.
addEventListener
(
"
open
"
,
()
=>
{
opened
=
true
;
opened
=
true
;
...
@@ -70,39 +71,66 @@ class ComfyApi extends EventTarget {
...
@@ -70,39 +71,66 @@ class ComfyApi extends EventTarget {
this
.
socket
.
addEventListener
(
"
message
"
,
(
event
)
=>
{
this
.
socket
.
addEventListener
(
"
message
"
,
(
event
)
=>
{
try
{
try
{
const
msg
=
JSON
.
parse
(
event
.
data
);
if
(
event
.
data
instanceof
ArrayBuffer
)
{
switch
(
msg
.
type
)
{
const
view
=
new
DataView
(
event
.
data
);
case
"
status
"
:
const
eventType
=
view
.
getUint32
(
0
);
if
(
msg
.
data
.
sid
)
{
const
buffer
=
event
.
data
.
slice
(
4
);
this
.
clientId
=
msg
.
data
.
sid
;
console
.
error
(
"
BINARY
"
,
eventType
);
window
.
name
=
this
.
clientId
;
switch
(
eventType
)
{
case
1
:
const
view2
=
new
DataView
(
event
.
data
);
const
imageType
=
view2
.
getUint32
(
0
)
let
imageMime
switch
(
imageType
)
{
case
1
:
default
:
imageMime
=
"
image/jpeg
"
;
break
;
case
2
:
imageMime
=
"
image/png
"
}
}
this
.
dispatchEvent
(
new
CustomEvent
(
"
status
"
,
{
detail
:
msg
.
data
.
status
}));
const
jpegBlob
=
new
Blob
([
buffer
.
slice
(
4
)],
{
type
:
imageMime
});
break
;
this
.
dispatchEvent
(
new
CustomEvent
(
"
b_preview
"
,
{
detail
:
jpegBlob
}));
case
"
progress
"
:
this
.
dispatchEvent
(
new
CustomEvent
(
"
progress
"
,
{
detail
:
msg
.
data
}));
break
;
case
"
executing
"
:
this
.
dispatchEvent
(
new
CustomEvent
(
"
executing
"
,
{
detail
:
msg
.
data
.
node
}));
break
;
case
"
executed
"
:
this
.
dispatchEvent
(
new
CustomEvent
(
"
executed
"
,
{
detail
:
msg
.
data
}));
break
;
case
"
execution_start
"
:
this
.
dispatchEvent
(
new
CustomEvent
(
"
execution_start
"
,
{
detail
:
msg
.
data
}));
break
;
case
"
execution_error
"
:
this
.
dispatchEvent
(
new
CustomEvent
(
"
execution_error
"
,
{
detail
:
msg
.
data
}));
break
;
break
;
default
:
default
:
if
(
this
.
#
registered
.
has
(
msg
.
type
))
{
throw
new
Error
(
`Unknown binary websocket message of type
${
eventType
}
`
);
this
.
dispatchEvent
(
new
CustomEvent
(
msg
.
type
,
{
detail
:
msg
.
data
}));
}
}
else
{
}
throw
new
Error
(
"
Unknown message type
"
);
else
{
}
const
msg
=
JSON
.
parse
(
event
.
data
);
switch
(
msg
.
type
)
{
case
"
status
"
:
if
(
msg
.
data
.
sid
)
{
this
.
clientId
=
msg
.
data
.
sid
;
window
.
name
=
this
.
clientId
;
}
this
.
dispatchEvent
(
new
CustomEvent
(
"
status
"
,
{
detail
:
msg
.
data
.
status
}));
break
;
case
"
progress
"
:
this
.
dispatchEvent
(
new
CustomEvent
(
"
progress
"
,
{
detail
:
msg
.
data
}));
break
;
case
"
executing
"
:
this
.
dispatchEvent
(
new
CustomEvent
(
"
executing
"
,
{
detail
:
msg
.
data
.
node
}));
break
;
case
"
executed
"
:
this
.
dispatchEvent
(
new
CustomEvent
(
"
executed
"
,
{
detail
:
msg
.
data
}));
break
;
case
"
execution_start
"
:
this
.
dispatchEvent
(
new
CustomEvent
(
"
execution_start
"
,
{
detail
:
msg
.
data
}));
break
;
case
"
execution_error
"
:
this
.
dispatchEvent
(
new
CustomEvent
(
"
execution_error
"
,
{
detail
:
msg
.
data
}));
break
;
default
:
if
(
this
.
#
registered
.
has
(
msg
.
type
))
{
this
.
dispatchEvent
(
new
CustomEvent
(
msg
.
type
,
{
detail
:
msg
.
data
}));
}
else
{
throw
new
Error
(
`Unknown message type
${
msg
.
type
}
`
);
}
}
}
}
}
catch
(
error
)
{
}
catch
(
error
)
{
console
.
warn
(
"
Unhandled message:
"
,
event
.
data
);
console
.
warn
(
"
Unhandled message:
"
,
event
.
data
,
error
);
}
}
});
});
}
}
...
...
web/scripts/app.js
View file @
b4f434ee
...
@@ -44,6 +44,12 @@ export class ComfyApp {
...
@@ -44,6 +44,12 @@ export class ComfyApp {
*/
*/
this
.
nodeOutputs
=
{};
this
.
nodeOutputs
=
{};
/**
* Stores the preview image data for each node
* @type {Record<string, Image>}
*/
this
.
nodePreviewImages
=
{};
/**
/**
* If the shift key on the keyboard is pressed
* If the shift key on the keyboard is pressed
* @type {boolean}
* @type {boolean}
...
@@ -367,29 +373,52 @@ export class ComfyApp {
...
@@ -367,29 +373,52 @@ export class ComfyApp {
node
.
prototype
.
onDrawBackground
=
function
(
ctx
)
{
node
.
prototype
.
onDrawBackground
=
function
(
ctx
)
{
if
(
!
this
.
flags
.
collapsed
)
{
if
(
!
this
.
flags
.
collapsed
)
{
let
imgURLs
=
[]
let
imagesChanged
=
false
const
output
=
app
.
nodeOutputs
[
this
.
id
+
""
];
const
output
=
app
.
nodeOutputs
[
this
.
id
+
""
];
if
(
output
&&
output
.
images
)
{
if
(
output
&&
output
.
images
)
{
if
(
this
.
images
!==
output
.
images
)
{
if
(
this
.
images
!==
output
.
images
)
{
this
.
images
=
output
.
images
;
this
.
images
=
output
.
images
;
this
.
imgs
=
null
;
imagesChanged
=
true
;
this
.
imageIndex
=
null
;
imgURLs
=
imgURLs
.
concat
(
output
.
images
.
map
(
params
=>
{
return
"
/view?
"
+
new
URLSearchParams
(
src
).
toString
()
+
app
.
getPreviewFormatParam
();
}))
}
}
const
preview
=
app
.
nodePreviewImages
[
this
.
id
+
""
]
if
(
this
.
preview
!==
preview
)
{
this
.
preview
=
preview
imagesChanged
=
true
;
if
(
preview
!=
null
)
{
imgURLs
.
push
(
preview
);
}
}
if
(
imagesChanged
)
{
this
.
imageIndex
=
null
;
if
(
imgURLs
.
length
>
0
)
{
Promise
.
all
(
Promise
.
all
(
output
.
image
s
.
map
((
src
)
=>
{
imgURL
s
.
map
((
src
)
=>
{
return
new
Promise
((
r
)
=>
{
return
new
Promise
((
r
)
=>
{
const
img
=
new
Image
();
const
img
=
new
Image
();
img
.
onload
=
()
=>
r
(
img
);
img
.
onload
=
()
=>
r
(
img
);
img
.
onerror
=
()
=>
r
(
null
);
img
.
onerror
=
()
=>
r
(
null
);
img
.
src
=
"
/view?
"
+
new
URLSearchParams
(
src
).
toString
()
+
app
.
getPreviewFormatParam
();
img
.
src
=
src
});
});
})
})
).
then
((
imgs
)
=>
{
).
then
((
imgs
)
=>
{
if
(
this
.
images
===
output
.
images
)
{
if
(
(
!
output
||
this
.
images
===
output
.
images
)
&&
(
!
preview
||
this
.
preview
===
preview
))
{
this
.
imgs
=
imgs
.
filter
(
Boolean
);
this
.
imgs
=
imgs
.
filter
(
Boolean
);
this
.
setSizeForImage
?.();
this
.
setSizeForImage
?.();
app
.
graph
.
setDirtyCanvas
(
true
);
app
.
graph
.
setDirtyCanvas
(
true
);
}
}
});
});
}
}
else
{
this
.
imgs
=
null
;
}
}
}
if
(
this
.
imgs
&&
this
.
imgs
.
length
)
{
if
(
this
.
imgs
&&
this
.
imgs
.
length
)
{
...
@@ -901,17 +930,20 @@ export class ComfyApp {
...
@@ -901,17 +930,20 @@ export class ComfyApp {
this
.
progress
=
null
;
this
.
progress
=
null
;
this
.
runningNodeId
=
detail
;
this
.
runningNodeId
=
detail
;
this
.
graph
.
setDirtyCanvas
(
true
,
false
);
this
.
graph
.
setDirtyCanvas
(
true
,
false
);
delete
this
.
nodePreviewImages
[
this
.
runningNodeId
]
});
});
api
.
addEventListener
(
"
executed
"
,
({
detail
})
=>
{
api
.
addEventListener
(
"
executed
"
,
({
detail
})
=>
{
this
.
nodeOutputs
[
detail
.
node
]
=
detail
.
output
;
this
.
nodeOutputs
[
detail
.
node
]
=
detail
.
output
;
const
node
=
this
.
graph
.
getNodeById
(
detail
.
node
);
const
node
=
this
.
graph
.
getNodeById
(
detail
.
node
);
if
(
node
?.
onExecuted
)
{
if
(
node
)
{
node
.
onExecuted
(
detail
.
output
);
if
(
node
.
onExecuted
)
node
.
onExecuted
(
detail
.
output
);
}
}
});
});
api
.
addEventListener
(
"
execution_start
"
,
({
detail
})
=>
{
api
.
addEventListener
(
"
execution_start
"
,
({
detail
})
=>
{
this
.
runningNodeId
=
null
;
this
.
lastExecutionError
=
null
this
.
lastExecutionError
=
null
});
});
...
@@ -922,6 +954,16 @@ export class ComfyApp {
...
@@ -922,6 +954,16 @@ export class ComfyApp {
this
.
canvas
.
draw
(
true
,
true
);
this
.
canvas
.
draw
(
true
,
true
);
});
});
api
.
addEventListener
(
"
b_preview
"
,
({
detail
})
=>
{
const
id
=
this
.
runningNodeId
if
(
id
==
null
)
return
;
const
blob
=
detail
const
blobUrl
=
URL
.
createObjectURL
(
blob
)
this
.
nodePreviewImages
[
id
]
=
[
blobUrl
]
});
api
.
init
();
api
.
init
();
}
}
...
@@ -1465,8 +1507,10 @@ export class ComfyApp {
...
@@ -1465,8 +1507,10 @@ export class ComfyApp {
*/
*/
clean
()
{
clean
()
{
this
.
nodeOutputs
=
{};
this
.
nodeOutputs
=
{};
this
.
nodePreviewImages
=
{}
this
.
lastPromptError
=
null
;
this
.
lastPromptError
=
null
;
this
.
lastExecutionError
=
null
;
this
.
lastExecutionError
=
null
;
this
.
runningNodeId
=
null
;
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment