chenpangpang / ComfyUI · Commits

Commit b4f434ee
Authored May 30, 2023 by space-nuko

    Preview sampled images with TAESD

Parent: 2ec980bb

Showing 8 changed files with 323 additions and 51 deletions (+323 -51)
comfy/taesd/taesd.py                 +65  -0
comfy/utils.py                       +2   -2
main.py                              +5   -1
nodes.py                             +111 -8
server.py                            +32  -5
web/extensions/core/colorPalette.js  +1   -0
web/scripts/api.js                   +56  -28
web/scripts/app.js                   +51  -7
comfy/taesd/taesd.py  0 → 100644  (new file)

#!/usr/bin/env python3
"""
Tiny AutoEncoder for Stable Diffusion
(DNN for encoding / decoding SD's latent space)
"""
import torch
import torch.nn as nn

def conv(n_in, n_out, **kwargs):
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)

class Clamp(nn.Module):
    def forward(self, x):
        return torch.tanh(x / 3) * 3

class Block(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        self.fuse = nn.ReLU()
    def forward(self, x):
        return self.fuse(self.conv(x) + self.skip(x))

def Encoder():
    return nn.Sequential(
        conv(3, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
        conv(64, 4),
    )

def Decoder():
    return nn.Sequential(
        Clamp(), conv(4, 64), nn.ReLU(),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), conv(64, 3),
    )

class TAESD(nn.Module):
    latent_magnitude = 3
    latent_shift = 0.5

    def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
        """Initialize pretrained TAESD on the given device from the given checkpoints."""
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        if encoder_path is not None:
            self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu"))
        if decoder_path is not None:
            self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu"))

    @staticmethod
    def scale_latents(x):
        """raw latents -> [0, 1]"""
        return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)

    @staticmethod
    def unscale_latents(x):
        """[0, 1] -> raw latents"""
        return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
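
For orientation, a minimal usage sketch of the class above. The checkpoint filenames, the 512×512 input, and the eval()/no_grad() usage are illustrative assumptions, not part of the commit:

    import torch
    from comfy.taesd.taesd import TAESD

    # Assumes the two TAESD checkpoints have been downloaded locally.
    taesd = TAESD("taesd_encoder.pth", "taesd_decoder.pth").eval()

    with torch.no_grad():
        image = torch.rand(1, 3, 512, 512)            # RGB in [0, 1], NCHW
        latent = taesd.encoder(image)                 # -> [1, 4, 64, 64] after three stride-2 convs
        approx = taesd.decoder(latent).clamp(0, 1)    # fast approximate reconstruction
        stored = TAESD.scale_latents(latent)          # raw latents -> [0, 1]
        restored = TAESD.unscale_latents(stored)      # [0, 1] -> raw latents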
comfy/utils.py

@@ -197,14 +197,14 @@ class ProgressBar:
         self.current = 0
         self.hook = PROGRESS_BAR_HOOK
 
-    def update_absolute(self, value, total=None):
+    def update_absolute(self, value, total=None, preview=None):
         if total is not None:
             self.total = total
         if value > self.total:
             value = self.total
         self.current = value
         if self.hook is not None:
-            self.hook(self.current, self.total)
+            self.hook(self.current, self.total, preview)
 
     def update(self, value):
         self.update_absolute(self.current + value)
main.py

@@ -26,6 +26,7 @@ import yaml
 import execution
 import folder_paths
 import server
+from server import BinaryEventTypes
 from nodes import init_custom_nodes

@@ -40,8 +41,11 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None):
     await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
 
 def hijack_progress(server):
-    def hook(value, total):
+    def hook(value, total, preview_bytes_jpeg):
         server.send_sync("progress", {"value": value, "max": total}, server.client_id)
+        if preview_bytes_jpeg is not None:
+            server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes_jpeg, server.client_id)
         pass
     comfy.utils.set_progress_bar_global_hook(hook)
 
 def cleanup_temp():
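
The preview path wired up above runs: sampler callback -> ProgressBar.update_absolute(value, total, preview) -> global hook -> binary PREVIEW_IMAGE event. A stand-alone sketch of the new three-argument hook contract; the demo hook below is illustrative, not ComfyUI code:

    import comfy.utils

    def demo_hook(value, total, preview):
        # New signature: progress value, total steps, and optional encoded preview bytes (or None).
        size = len(preview) if preview is not None else 0
        print(f"step {value}/{total}, preview: {size} bytes")

    comfy.utils.set_progress_bar_global_hook(demo_hook)
    pbar = comfy.utils.ProgressBar(20)
    pbar.update_absolute(1, 20)                 # preview defaults to None
    pbar.update_absolute(2, 20, b"jpeg-bytes")  # stand-in for decode_latent_to_preview_image output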
nodes.py

@@ -7,6 +7,8 @@ import hashlib
 import traceback
 import math
 import time
+import struct
+from io import BytesIO
 
 from PIL import Image, ImageOps
 from PIL.PngImagePlugin import PngInfo

@@ -22,6 +24,7 @@ import comfy.samplers
 import comfy.sample
 import comfy.sd
 import comfy.utils
+from comfy.taesd.taesd import TAESD
 
 import comfy.clip_vision

@@ -38,6 +41,7 @@ def interrupt_processing(value=True):
     comfy.model_management.interrupt_current_processing(value)
 
 MAX_RESOLUTION = 8192
+MAX_PREVIEW_RESOLUTION = 512
 
 class CLIPTextEncode:
     @classmethod

@@ -171,6 +175,21 @@ class VAEDecodeTiled:
     def decode(self, vae, samples):
         return (vae.decode_tiled(samples["samples"]), )
 
+class TAESDDecode:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"samples": ("LATENT", ), "taesd": ("TAESD", )}}
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "decode"
+
+    CATEGORY = "latent"
+
+    def decode(self, taesd, samples):
+        device = comfy.model_management.get_torch_device()
+        # [B, C, H, W] -> [B, H, W, C]
+        pixels = taesd.decoder(samples["samples"].to(device)).permute(0, 2, 3, 1).detach().clamp(0, 1)
+        return (pixels, )
+
 class VAEEncode:
     @classmethod
     def INPUT_TYPES(s):

@@ -248,6 +267,21 @@ class VAEEncodeForInpaint:
         return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
 
+class TAESDEncode:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"pixels": ("IMAGE", ), "taesd": ("TAESD", )}}
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "encode"
+
+    CATEGORY = "latent"
+
+    def encode(self, taesd, pixels):
+        device = comfy.model_management.get_torch_device()
+        # [B, H, W, C] -> [B, C, H, W]
+        samples = taesd.encoder(pixels.permute(0, 3, 1, 2).to(device)).to(device)
+        return ({"samples": samples}, )
+
 class SaveLatent:
     def __init__(self):

@@ -464,6 +498,26 @@ class VAELoader:
         vae = comfy.sd.VAE(ckpt_path=vae_path)
         return (vae,)
 
+class TAESDLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        model_list = folder_paths.get_filename_list("taesd")
+        return {"required": {
+            "encoder_name": (model_list, {"default": "taesd_encoder.pth"}),
+            "decoder_name": (model_list, {"default": "taesd_decoder.pth"})
+        }}
+
+    RETURN_TYPES = ("TAESD",)
+    FUNCTION = "load_taesd"
+
+    CATEGORY = "loaders"
+
+    def load_taesd(self, encoder_name, decoder_name):
+        device = comfy.model_management.get_torch_device()
+        encoder_path = folder_paths.get_full_path("taesd", encoder_name)
+        decoder_path = folder_paths.get_full_path("taesd", decoder_name)
+        taesd = TAESD(encoder_path, decoder_path).to(device)
+        return (taesd,)
+
 class ControlNetLoader:
     @classmethod
     def INPUT_TYPES(s):

@@ -931,7 +985,37 @@ class SetLatentNoiseMask:
         s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
         return (s,)
 
-def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
+def decode_latent_to_preview_image(taesd, device, preview_format, x0):
+    x_sample = taesd.decoder(x0.to(device))[0].detach()
+    x_sample = taesd.unscale_latents(x_sample) # returns value in [-2, 2]
+    x_sample = x_sample * 0.5
+    x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+    x_sample = x_sample.astype(np.uint8)
+    preview_image = Image.fromarray(x_sample)
+
+    if preview_image.size[0] > MAX_PREVIEW_RESOLUTION or preview_image.size[1] > MAX_PREVIEW_RESOLUTION:
+        preview_image.thumbnail((MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.ANTIALIAS)
+
+    preview_type = 1
+    if preview_format == "JPEG":
+        preview_type = 1
+    elif preview_format == "PNG":
+        preview_type = 2
+
+    bytesIO = BytesIO()
+    header = struct.pack(">I", preview_type)
+    bytesIO.write(header)
+    preview_image.save(bytesIO, format=preview_format)
+    preview_bytes = bytesIO.getvalue()
+    return preview_bytes
+
+def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, taesd=None):
     device = comfy.model_management.get_torch_device()
     latent_image = latent["samples"]

@@ -945,9 +1029,16 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
     if "noise_mask" in latent:
         noise_mask = latent["noise_mask"]
 
+    preview_format = "JPEG"
+    if preview_format not in ["JPEG", "PNG"]:
+        preview_format = "JPEG"
+
     pbar = comfy.utils.ProgressBar(steps)
     def callback(step, x0, x, total_steps):
-        pbar.update_absolute(step + 1, total_steps)
+        preview_bytes = None
+        if taesd:
+            preview_bytes = decode_latent_to_preview_image(taesd, device, preview_format, x0)
+        pbar.update_absolute(step + 1, total_steps, preview_bytes)
 
     samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,

@@ -970,15 +1061,18 @@ class KSampler:
                     "negative": ("CONDITIONING", ),
                     "latent_image": ("LATENT", ),
                     "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
-                     }}
+                     },
+                "optional": {"taesd": ("TAESD",)
+                }}
 
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "sample"
 
     CATEGORY = "sampling"
 
-    def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
-        return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
+    def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, taesd=None):
+        return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, taesd=taesd)
 
 class KSamplerAdvanced:
     @classmethod

@@ -997,21 +1091,24 @@ class KSamplerAdvanced:
                     "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
                     "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
                     "return_with_leftover_noise": (["disable", "enable"], ),
-                     }}
+                     },
+                "optional": {"taesd": ("TAESD",)
+                }}
 
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "sample"
 
     CATEGORY = "sampling"
 
-    def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0):
+    def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, taesd=None):
         force_full_denoise = True
         if return_with_leftover_noise == "enable":
             force_full_denoise = False
         disable_noise = False
         if add_noise == "disable":
             disable_noise = True
-        return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
+        return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise, taesd=taesd)
 
 class SaveImage:
     def __init__(self):

@@ -1270,6 +1367,9 @@ NODE_CLASS_MAPPINGS = {
     "VAEEncode": VAEEncode,
     "VAEEncodeForInpaint": VAEEncodeForInpaint,
     "VAELoader": VAELoader,
+    "TAESDDecode": TAESDDecode,
+    "TAESDEncode": TAESDEncode,
+    "TAESDLoader": TAESDLoader,
     "EmptyLatentImage": EmptyLatentImage,
     "LatentUpscale": LatentUpscale,
     "LatentUpscaleBy": LatentUpscaleBy,

@@ -1324,6 +1424,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "CheckpointLoader": "Load Checkpoint (With Config)",
     "CheckpointLoaderSimple": "Load Checkpoint",
     "VAELoader": "Load VAE",
+    "TAESDLoader": "Load TAESD",
     "LoraLoader": "Load LoRA",
     "CLIPLoader": "Load CLIP",
     "ControlNetLoader": "Load ControlNet Model",

@@ -1346,6 +1447,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "SetLatentNoiseMask": "Set Latent Noise Mask",
     "VAEDecode": "VAE Decode",
     "VAEEncode": "VAE Encode",
+    "TAESDDecode": "TAESD Decode",
+    "TAESDEncode": "TAESD Encode",
     "LatentRotate": "Rotate Latent",
     "LatentFlip": "Flip Latent",
     "LatentCrop": "Crop Latent",
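
Note that decode_latent_to_preview_image above frames its output as a 4-byte big-endian image-type tag (1 = JPEG, 2 = PNG) followed by the encoded image. A small sketch of reading that payload back, purely for illustration (read_preview_payload is not part of the commit):

    import struct
    from io import BytesIO
    from PIL import Image

    def read_preview_payload(preview_bytes):
        (image_type,) = struct.unpack(">I", preview_bytes[:4])  # header written in decode_latent_to_preview_image
        fmt = {1: "JPEG", 2: "PNG"}.get(image_type, "JPEG")
        image = Image.open(BytesIO(preview_bytes[4:]))          # the rest is the encoded image
        return fmt, image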
server.py

@@ -7,6 +7,7 @@ import execution
 import uuid
 import json
 import glob
+import struct
 
 from PIL import Image
 from io import BytesIO

@@ -25,6 +26,11 @@ from comfy.cli_args import args
 import comfy.utils
 import comfy.model_management
 
+class BinaryEventTypes:
+    PREVIEW_IMAGE = 1
+
+
 @web.middleware
 async def cache_control(request: web.Request, handler):
     response: web.Response = await handler(request)

@@ -457,16 +463,37 @@ class PromptServer():
         return prompt_info
 
     async def send(self, event, data, sid=None):
+        if isinstance(data, (bytes, bytearray)):
+            await self.send_bytes(event, data, sid)
+        else:
+            await self.send_json(event, data, sid)
+
+    def encode_bytes(self, event, data):
+        if not isinstance(event, int):
+            raise RuntimeError(f"Binary event types must be integers, got {event}")
+
+        packed = struct.pack(">I", event)
+        message = bytearray(packed)
+        message.extend(data)
+        return message
+
+    async def send_bytes(self, event, data, sid=None):
+        message = self.encode_bytes(event, data)
+
+        if sid is None:
+            for ws in self.sockets.values():
+                await ws.send_bytes(message)
+        elif sid in self.sockets:
+            await self.sockets[sid].send_bytes(message)
+
+    async def send_json(self, event, data, sid=None):
         message = {"type": event, "data": data}
-        if isinstance(message, str) == False:
-            message = json.dumps(message)
 
         if sid is None:
             for ws in self.sockets.values():
-                await ws.send_str(message)
+                await ws.send_json(message)
         elif sid in self.sockets:
-            await self.sockets[sid].send_str(message)
+            await self.sockets[sid].send_json(message)
 
     def send_sync(self, event, data, sid=None):
         self.loop.call_soon_threadsafe(
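
On the wire, encode_bytes prepends one more 4-byte big-endian field, the binary event type, so a complete preview frame is [event u32][image-type u32][image bytes]; this is what web/scripts/api.js below reads with DataView.getUint32. A Python sketch of parsing such a frame, for illustration only (parse_binary_frame is not part of the commit):

    import struct

    def parse_binary_frame(frame: bytes):
        (event_type,) = struct.unpack(">I", frame[:4])        # outer header from PromptServer.encode_bytes
        payload = frame[4:]
        if event_type == 1:                                   # BinaryEventTypes.PREVIEW_IMAGE
            (image_type,) = struct.unpack(">I", payload[:4])  # inner header from nodes.py (1 = JPEG, 2 = PNG)
            return event_type, image_type, payload[4:]
        raise ValueError(f"unknown binary event type {event_type}")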
web/extensions/core/colorPalette.js

@@ -21,6 +21,7 @@ const colorPalettes = {
 			"MODEL": "#B39DDB", // light lavender-purple
 			"STYLE_MODEL": "#C2FFAE", // light green-yellow
 			"VAE": "#FF6E6E", // bright red
+			"TAESD": "#DCC274", // cheesecake
 		},
 	"litegraph_base": {
 		"NODE_TITLE_COLOR": "#999",
web/scripts/api.js

@@ -42,6 +42,7 @@ class ComfyApi extends EventTarget {
 		this.socket = new WebSocket(
 			`ws${window.location.protocol === "https:" ? "s" : ""}://${location.host}/ws${existingSession}`
 		);
+		this.socket.binaryType = "arraybuffer";
 
 		this.socket.addEventListener("open", () => {
 			opened = true;

@@ -70,39 +71,66 @@ class ComfyApi extends EventTarget {
 		this.socket.addEventListener("message", (event) => {
 			try {
-				const msg = JSON.parse(event.data);
-				switch (msg.type) {
-					case "status":
-						if (msg.data.sid) {
-							this.clientId = msg.data.sid;
-							window.name = this.clientId;
+				if (event.data instanceof ArrayBuffer) {
+					const view = new DataView(event.data);
+					const eventType = view.getUint32(0);
+					const buffer = event.data.slice(4);
+					console.error("BINARY", eventType)
+					switch (eventType) {
+					case 1:
+						const view2 = new DataView(event.data);
+						const imageType = view2.getUint32(0)
+						let imageMime
+						switch (imageType) {
+							case 1:
+							default:
+								imageMime = "image/jpeg";
+								break;
+							case 2:
+								imageMime = "image/png"
 						}
-						this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
-						break;
-					case "progress":
-						this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
-						break;
-					case "executing":
-						this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
-						break;
-					case "executed":
-						this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
-						break;
-					case "execution_start":
-						this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
-						break;
-					case "execution_error":
-						this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
+						const jpegBlob = new Blob([buffer.slice(4)], { type: imageMime });
+						this.dispatchEvent(new CustomEvent("b_preview", { detail: jpegBlob }));
 						break;
 					default:
-						if (this.#registered.has(msg.type)) {
-							this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
-						} else {
-							throw new Error("Unknown message type");
-						}
+						throw new Error(`Unknown binary websocket message of type ${eventType}`);
+					}
+				}
+				else {
+					const msg = JSON.parse(event.data);
+					switch (msg.type) {
+						case "status":
+							if (msg.data.sid) {
+								this.clientId = msg.data.sid;
+								window.name = this.clientId;
+							}
+							this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status }));
+							break;
+						case "progress":
+							this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
+							break;
+						case "executing":
+							this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
+							break;
+						case "executed":
+							this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));
+							break;
+						case "execution_start":
+							this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data }));
+							break;
+						case "execution_error":
+							this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data }));
+							break;
+						default:
+							if (this.#registered.has(msg.type)) {
+								this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data }));
+							} else {
+								throw new Error(`Unknown message type ${msg.type}`);
+							}
+					}
 				}
 			} catch (error) {
-				console.warn("Unhandled message:", event.data);
+				console.warn("Unhandled message:", event.data, error);
 			}
 		});
 	}
web/scripts/app.js

@@ -44,6 +44,12 @@ export class ComfyApp {
 		 */
 		this.nodeOutputs = {};
 
+		/**
+		 * Stores the preview image data for each node
+		 * @type {Record<string, Image>}
+		 */
+		this.nodePreviewImages = {};
+
 		/**
 		 * If the shift key on the keyboard is pressed
 		 * @type {boolean}

@@ -367,29 +373,52 @@ export class ComfyApp {
 		node.prototype.onDrawBackground = function (ctx) {
 			if (!this.flags.collapsed) {
+				let imgURLs = []
+				let imagesChanged = false
+
 				const output = app.nodeOutputs[this.id + ""];
 				if (output && output.images) {
 					if (this.images !== output.images) {
 						this.images = output.images;
-						this.imgs = null;
-						this.imageIndex = null;
+						imagesChanged = true;
+						imgURLs = imgURLs.concat(output.images.map(params => {
+							return "/view?" + new URLSearchParams(src).toString() + app.getPreviewFormatParam();
+						}))
+					}
+				}
+
+				const preview = app.nodePreviewImages[this.id + ""]
+				if (this.preview !== preview) {
+					this.preview = preview
+					imagesChanged = true;
+					if (preview != null) {
+						imgURLs.push(preview);
+					}
+				}
+
+				if (imagesChanged) {
+					this.imageIndex = null;
+					if (imgURLs.length > 0) {
 						Promise.all(
-							output.images.map((src) => {
+							imgURLs.map((src) => {
 								return new Promise((r) => {
 									const img = new Image();
 									img.onload = () => r(img);
 									img.onerror = () => r(null);
-									img.src = "/view?" + new URLSearchParams(src).toString() + app.getPreviewFormatParam();
+									img.src = src
 								});
 							})
 						).then((imgs) => {
-							if (this.images === output.images) {
+							if ((!output || this.images === output.images) && (!preview || this.preview === preview)) {
 								this.imgs = imgs.filter(Boolean);
 								this.setSizeForImage?.();
 								app.graph.setDirtyCanvas(true);
 							}
 						});
 					}
+					else {
+						this.imgs = null;
+					}
 				}
 
 				if (this.imgs && this.imgs.length) {

@@ -901,17 +930,20 @@ export class ComfyApp {
 			this.progress = null;
 			this.runningNodeId = detail;
 			this.graph.setDirtyCanvas(true, false);
+			delete this.nodePreviewImages[this.runningNodeId]
 		});
 
 		api.addEventListener("executed", ({ detail }) => {
 			this.nodeOutputs[detail.node] = detail.output;
 			const node = this.graph.getNodeById(detail.node);
-			if (node?.onExecuted) {
-				node.onExecuted(detail.output);
+			if (node) {
+				if (node.onExecuted)
+					node.onExecuted(detail.output);
 			}
 		});
 
 		api.addEventListener("execution_start", ({ detail }) => {
 			this.runningNodeId = null;
 			this.lastExecutionError = null
 		});

@@ -922,6 +954,16 @@ export class ComfyApp {
 			this.canvas.draw(true, true);
 		});
 
+		api.addEventListener("b_preview", ({ detail }) => {
+			const id = this.runningNodeId
+			if (id == null)
+				return;
+
+			const blob = detail
+			const blobUrl = URL.createObjectURL(blob)
+			this.nodePreviewImages[id] = [blobUrl]
+		});
+
 		api.init();
 	}

@@ -1465,8 +1507,10 @@ export class ComfyApp {
 	 */
 	clean() {
 		this.nodeOutputs = {};
+		this.nodePreviewImages = {}
 		this.lastPromptError = null;
 		this.lastExecutionError = null;
 		this.runningNodeId = null;
 	}
 }