Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
c92f3dca
Unverified
Commit
c92f3dca
authored
Dec 02, 2023
by
Jairo Correa
Committed by
GitHub
Dec 02, 2023
Browse files
Merge branch 'master' into image-cache
parents
006b24cc
2995a247
Changes
57
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1766 additions
and
84 deletions
+1766
-84
comfy/supported_models.py
comfy/supported_models.py
+38
-2
comfy/supported_models_base.py
comfy/supported_models_base.py
+7
-1
comfy/taesd/taesd.py
comfy/taesd/taesd.py
+14
-5
comfy/utils.py
comfy/utils.py
+10
-8
comfy_extras/nodes_custom_sampler.py
comfy_extras/nodes_custom_sampler.py
+52
-13
comfy_extras/nodes_images.py
comfy_extras/nodes_images.py
+175
-0
comfy_extras/nodes_latent.py
comfy_extras/nodes_latent.py
+36
-0
comfy_extras/nodes_model_advanced.py
comfy_extras/nodes_model_advanced.py
+48
-11
comfy_extras/nodes_model_downscale.py
comfy_extras/nodes_model_downscale.py
+53
-0
comfy_extras/nodes_video_model.py
comfy_extras/nodes_video_model.py
+89
-0
execution.py
execution.py
+19
-4
folder_paths.py
folder_paths.py
+10
-3
latent_preview.py
latent_preview.py
+1
-4
main.py
main.py
+30
-11
nodes.py
nodes.py
+76
-8
server.py
server.py
+5
-2
tests-ui/setup.js
tests-ui/setup.js
+1
-0
tests-ui/tests/extensions.test.js
tests-ui/tests/extensions.test.js
+196
-0
tests-ui/tests/groupNode.test.js
tests-ui/tests/groupNode.test.js
+818
-0
tests-ui/tests/widgetInputs.test.js
tests-ui/tests/widgetInputs.test.js
+88
-12
No files found.
comfy/supported_models.py
View file @
c92f3dca
...
...
@@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE):
"model_channels"
:
320
,
"use_linear_in_transformer"
:
False
,
"adm_in_channels"
:
None
,
"use_temporal_attention"
:
False
,
}
unet_extra_config
=
{
...
...
@@ -56,6 +57,7 @@ class SD20(supported_models_base.BASE):
"model_channels"
:
320
,
"use_linear_in_transformer"
:
True
,
"adm_in_channels"
:
None
,
"use_temporal_attention"
:
False
,
}
latent_format
=
latent_formats
.
SD15
...
...
@@ -69,6 +71,10 @@ class SD20(supported_models_base.BASE):
return
model_base
.
ModelType
.
EPS
def
process_clip_state_dict
(
self
,
state_dict
):
replace_prefix
=
{}
replace_prefix
[
"conditioner.embedders.0.model."
]
=
"cond_stage_model.model."
#SD2 in sgm format
state_dict
=
utils
.
state_dict_prefix_replace
(
state_dict
,
replace_prefix
)
state_dict
=
utils
.
transformers_convert
(
state_dict
,
"cond_stage_model.model."
,
"cond_stage_model.clip_h.transformer.text_model."
,
24
)
return
state_dict
...
...
@@ -88,6 +94,7 @@ class SD21UnclipL(SD20):
"model_channels"
:
320
,
"use_linear_in_transformer"
:
True
,
"adm_in_channels"
:
1536
,
"use_temporal_attention"
:
False
,
}
clip_vision_prefix
=
"embedder.model.visual."
...
...
@@ -100,6 +107,7 @@ class SD21UnclipH(SD20):
"model_channels"
:
320
,
"use_linear_in_transformer"
:
True
,
"adm_in_channels"
:
2048
,
"use_temporal_attention"
:
False
,
}
clip_vision_prefix
=
"embedder.model.visual."
...
...
@@ -112,6 +120,7 @@ class SDXLRefiner(supported_models_base.BASE):
"context_dim"
:
1280
,
"adm_in_channels"
:
2560
,
"transformer_depth"
:
[
0
,
0
,
4
,
4
,
4
,
4
,
0
,
0
],
"use_temporal_attention"
:
False
,
}
latent_format
=
latent_formats
.
SDXL
...
...
@@ -148,7 +157,8 @@ class SDXL(supported_models_base.BASE):
"use_linear_in_transformer"
:
True
,
"transformer_depth"
:
[
0
,
0
,
2
,
2
,
10
,
10
],
"context_dim"
:
2048
,
"adm_in_channels"
:
2816
"adm_in_channels"
:
2816
,
"use_temporal_attention"
:
False
,
}
latent_format
=
latent_formats
.
SDXL
...
...
@@ -203,8 +213,34 @@ class SSD1B(SDXL):
"use_linear_in_transformer"
:
True
,
"transformer_depth"
:
[
0
,
0
,
2
,
2
,
4
,
4
],
"context_dim"
:
2048
,
"adm_in_channels"
:
2816
"adm_in_channels"
:
2816
,
"use_temporal_attention"
:
False
,
}
class
SVD_img2vid
(
supported_models_base
.
BASE
):
unet_config
=
{
"model_channels"
:
320
,
"in_channels"
:
8
,
"use_linear_in_transformer"
:
True
,
"transformer_depth"
:
[
1
,
1
,
1
,
1
,
1
,
1
,
0
,
0
],
"context_dim"
:
1024
,
"adm_in_channels"
:
768
,
"use_temporal_attention"
:
True
,
"use_temporal_resblock"
:
True
}
clip_vision_prefix
=
"conditioner.embedders.0.open_clip.model.visual."
latent_format
=
latent_formats
.
SD15
sampling_settings
=
{
"sigma_max"
:
700.0
,
"sigma_min"
:
0.002
}
def
get_model
(
self
,
state_dict
,
prefix
=
""
,
device
=
None
):
out
=
model_base
.
SVD_img2vid
(
self
,
device
=
device
)
return
out
def
clip_target
(
self
):
return
None
models
=
[
SD15
,
SD20
,
SD21UnclipL
,
SD21UnclipH
,
SDXLRefiner
,
SDXL
,
SSD1B
]
models
+=
[
SVD_img2vid
]
comfy/supported_models_base.py
View file @
c92f3dca
...
...
@@ -19,7 +19,7 @@ class BASE:
clip_prefix
=
[]
clip_vision_prefix
=
None
noise_aug_config
=
None
beta_schedule
=
"linear"
sampling_settings
=
{}
latent_format
=
latent_formats
.
LatentFormat
@
classmethod
...
...
@@ -53,6 +53,12 @@ class BASE:
def
process_clip_state_dict
(
self
,
state_dict
):
return
state_dict
def
process_unet_state_dict
(
self
,
state_dict
):
return
state_dict
def
process_vae_state_dict
(
self
,
state_dict
):
return
state_dict
def
process_clip_state_dict_for_saving
(
self
,
state_dict
):
replace_prefix
=
{
""
:
"cond_stage_model."
}
return
utils
.
state_dict_prefix_replace
(
state_dict
,
replace_prefix
)
...
...
comfy/taesd/taesd.py
View file @
c92f3dca
...
...
@@ -46,15 +46,16 @@ class TAESD(nn.Module):
latent_magnitude
=
3
latent_shift
=
0.5
def
__init__
(
self
,
encoder_path
=
"taesd_encoder.pth"
,
decoder_path
=
"taesd_decoder.pth"
):
def
__init__
(
self
,
encoder_path
=
None
,
decoder_path
=
None
):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super
().
__init__
()
self
.
encoder
=
Encoder
()
self
.
decoder
=
Decoder
()
self
.
taesd_encoder
=
Encoder
()
self
.
taesd_decoder
=
Decoder
()
self
.
vae_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
1.0
))
if
encoder_path
is
not
None
:
self
.
encoder
.
load_state_dict
(
comfy
.
utils
.
load_torch_file
(
encoder_path
,
safe_load
=
True
))
self
.
taesd_
encoder
.
load_state_dict
(
comfy
.
utils
.
load_torch_file
(
encoder_path
,
safe_load
=
True
))
if
decoder_path
is
not
None
:
self
.
decoder
.
load_state_dict
(
comfy
.
utils
.
load_torch_file
(
decoder_path
,
safe_load
=
True
))
self
.
taesd_
decoder
.
load_state_dict
(
comfy
.
utils
.
load_torch_file
(
decoder_path
,
safe_load
=
True
))
@
staticmethod
def
scale_latents
(
x
):
...
...
@@ -65,3 +66,11 @@ class TAESD(nn.Module):
def
unscale_latents
(
x
):
"""[0, 1] -> raw latents"""
return
x
.
sub
(
TAESD
.
latent_shift
).
mul
(
2
*
TAESD
.
latent_magnitude
)
def
decode
(
self
,
x
):
x_sample
=
self
.
taesd_decoder
(
x
*
self
.
vae_scale
)
x_sample
=
x_sample
.
sub
(
0.5
).
mul
(
2
)
return
x_sample
def
encode
(
self
,
x
):
return
self
.
taesd_encoder
(
x
*
0.5
+
0.5
)
/
self
.
vae_scale
comfy/utils.py
View file @
c92f3dca
...
...
@@ -258,7 +258,7 @@ def set_attr(obj, attr, value):
for
name
in
attrs
[:
-
1
]:
obj
=
getattr
(
obj
,
name
)
prev
=
getattr
(
obj
,
attrs
[
-
1
])
setattr
(
obj
,
attrs
[
-
1
],
torch
.
nn
.
Parameter
(
value
))
setattr
(
obj
,
attrs
[
-
1
],
torch
.
nn
.
Parameter
(
value
,
requires_grad
=
False
))
del
prev
def
copy_to_param
(
obj
,
attr
,
value
):
...
...
@@ -307,23 +307,25 @@ def bislerp(samples, width, height):
res
[
dot
<
1e-5
-
1
]
=
(
b1
*
(
1.0
-
r
)
+
b2
*
r
)[
dot
<
1e-5
-
1
]
return
res
def
generate_bilinear_data
(
length_old
,
length_new
):
coords_1
=
torch
.
arange
(
length_old
).
reshape
((
1
,
1
,
1
,
-
1
))
.
to
(
torch
.
float32
)
def
generate_bilinear_data
(
length_old
,
length_new
,
device
):
coords_1
=
torch
.
arange
(
length_old
,
dtype
=
torch
.
float32
,
device
=
device
).
reshape
((
1
,
1
,
1
,
-
1
))
coords_1
=
torch
.
nn
.
functional
.
interpolate
(
coords_1
,
size
=
(
1
,
length_new
),
mode
=
"bilinear"
)
ratios
=
coords_1
-
coords_1
.
floor
()
coords_1
=
coords_1
.
to
(
torch
.
int64
)
coords_2
=
torch
.
arange
(
length_old
).
reshape
((
1
,
1
,
1
,
-
1
))
.
to
(
torch
.
float32
)
+
1
coords_2
=
torch
.
arange
(
length_old
,
dtype
=
torch
.
float32
,
device
=
device
).
reshape
((
1
,
1
,
1
,
-
1
))
+
1
coords_2
[:,:,:,
-
1
]
-=
1
coords_2
=
torch
.
nn
.
functional
.
interpolate
(
coords_2
,
size
=
(
1
,
length_new
),
mode
=
"bilinear"
)
coords_2
=
coords_2
.
to
(
torch
.
int64
)
return
ratios
,
coords_1
,
coords_2
orig_dtype
=
samples
.
dtype
samples
=
samples
.
float
()
n
,
c
,
h
,
w
=
samples
.
shape
h_new
,
w_new
=
(
height
,
width
)
#linear w
ratios
,
coords_1
,
coords_2
=
generate_bilinear_data
(
w
,
w_new
)
ratios
,
coords_1
,
coords_2
=
generate_bilinear_data
(
w
,
w_new
,
samples
.
device
)
coords_1
=
coords_1
.
expand
((
n
,
c
,
h
,
-
1
))
coords_2
=
coords_2
.
expand
((
n
,
c
,
h
,
-
1
))
ratios
=
ratios
.
expand
((
n
,
1
,
h
,
-
1
))
...
...
@@ -336,7 +338,7 @@ def bislerp(samples, width, height):
result
=
result
.
reshape
(
n
,
h
,
w_new
,
c
).
movedim
(
-
1
,
1
)
#linear h
ratios
,
coords_1
,
coords_2
=
generate_bilinear_data
(
h
,
h_new
)
ratios
,
coords_1
,
coords_2
=
generate_bilinear_data
(
h
,
h_new
,
samples
.
device
)
coords_1
=
coords_1
.
reshape
((
1
,
1
,
-
1
,
1
)).
expand
((
n
,
c
,
-
1
,
w_new
))
coords_2
=
coords_2
.
reshape
((
1
,
1
,
-
1
,
1
)).
expand
((
n
,
c
,
-
1
,
w_new
))
ratios
=
ratios
.
reshape
((
1
,
1
,
-
1
,
1
)).
expand
((
n
,
1
,
-
1
,
w_new
))
...
...
@@ -347,7 +349,7 @@ def bislerp(samples, width, height):
result
=
slerp
(
pass_1
,
pass_2
,
ratios
)
result
=
result
.
reshape
(
n
,
h_new
,
w_new
,
c
).
movedim
(
-
1
,
1
)
return
result
return
result
.
to
(
orig_dtype
)
def
lanczos
(
samples
,
width
,
height
):
images
=
[
Image
.
fromarray
(
np
.
clip
(
255.
*
image
.
movedim
(
0
,
-
1
).
cpu
().
numpy
(),
0
,
255
).
astype
(
np
.
uint8
))
for
image
in
samples
]
...
...
comfy_extras/nodes_custom_sampler.py
View file @
c92f3dca
...
...
@@ -16,7 +16,7 @@ class BasicScheduler:
}
}
RETURN_TYPES
=
(
"SIGMAS"
,)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/schedulers
"
FUNCTION
=
"get_sigmas"
...
...
@@ -36,7 +36,7 @@ class KarrasScheduler:
}
}
RETURN_TYPES
=
(
"SIGMAS"
,)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/schedulers
"
FUNCTION
=
"get_sigmas"
...
...
@@ -54,7 +54,7 @@ class ExponentialScheduler:
}
}
RETURN_TYPES
=
(
"SIGMAS"
,)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/schedulers
"
FUNCTION
=
"get_sigmas"
...
...
@@ -73,7 +73,7 @@ class PolyexponentialScheduler:
}
}
RETURN_TYPES
=
(
"SIGMAS"
,)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/schedulers
"
FUNCTION
=
"get_sigmas"
...
...
@@ -81,6 +81,25 @@ class PolyexponentialScheduler:
sigmas
=
k_diffusion_sampling
.
get_sigmas_polyexponential
(
n
=
steps
,
sigma_min
=
sigma_min
,
sigma_max
=
sigma_max
,
rho
=
rho
)
return
(
sigmas
,
)
class
SDTurboScheduler
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"model"
:
(
"MODEL"
,),
"steps"
:
(
"INT"
,
{
"default"
:
1
,
"min"
:
1
,
"max"
:
10
}),
}
}
RETURN_TYPES
=
(
"SIGMAS"
,)
CATEGORY
=
"sampling/custom_sampling/schedulers"
FUNCTION
=
"get_sigmas"
def
get_sigmas
(
self
,
model
,
steps
):
timesteps
=
torch
.
flip
(
torch
.
arange
(
1
,
11
)
*
100
-
1
,
(
0
,))[:
steps
]
sigmas
=
model
.
model
.
model_sampling
.
sigma
(
timesteps
)
sigmas
=
torch
.
cat
([
sigmas
,
sigmas
.
new_zeros
([
1
])])
return
(
sigmas
,
)
class
VPScheduler
:
@
classmethod
def
INPUT_TYPES
(
s
):
...
...
@@ -92,7 +111,7 @@ class VPScheduler:
}
}
RETURN_TYPES
=
(
"SIGMAS"
,)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/schedulers
"
FUNCTION
=
"get_sigmas"
...
...
@@ -109,7 +128,7 @@ class SplitSigmas:
}
}
RETURN_TYPES
=
(
"SIGMAS"
,
"SIGMAS"
)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/sigmas
"
FUNCTION
=
"get_sigmas"
...
...
@@ -118,6 +137,24 @@ class SplitSigmas:
sigmas2
=
sigmas
[
step
:]
return
(
sigmas1
,
sigmas2
)
class
FlipSigmas
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"sigmas"
:
(
"SIGMAS"
,
),
}
}
RETURN_TYPES
=
(
"SIGMAS"
,)
CATEGORY
=
"sampling/custom_sampling/sigmas"
FUNCTION
=
"get_sigmas"
def
get_sigmas
(
self
,
sigmas
):
sigmas
=
sigmas
.
flip
(
0
)
if
sigmas
[
0
]
==
0
:
sigmas
[
0
]
=
0.0001
return
(
sigmas
,)
class
KSamplerSelect
:
@
classmethod
def
INPUT_TYPES
(
s
):
...
...
@@ -126,12 +163,12 @@ class KSamplerSelect:
}
}
RETURN_TYPES
=
(
"SAMPLER"
,)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/samplers
"
FUNCTION
=
"get_sampler"
def
get_sampler
(
self
,
sampler_name
):
sampler
=
comfy
.
samplers
.
sampler_
class
(
sampler_name
)
()
sampler
=
comfy
.
samplers
.
sampler_
object
(
sampler_name
)
return
(
sampler
,
)
class
SamplerDPMPP_2M_SDE
:
...
...
@@ -145,7 +182,7 @@ class SamplerDPMPP_2M_SDE:
}
}
RETURN_TYPES
=
(
"SAMPLER"
,)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/samplers
"
FUNCTION
=
"get_sampler"
...
...
@@ -154,7 +191,7 @@ class SamplerDPMPP_2M_SDE:
sampler_name
=
"dpmpp_2m_sde"
else
:
sampler_name
=
"dpmpp_2m_sde_gpu"
sampler
=
comfy
.
samplers
.
ksampler
(
sampler_name
,
{
"eta"
:
eta
,
"s_noise"
:
s_noise
,
"solver_type"
:
solver_type
})
()
sampler
=
comfy
.
samplers
.
ksampler
(
sampler_name
,
{
"eta"
:
eta
,
"s_noise"
:
s_noise
,
"solver_type"
:
solver_type
})
return
(
sampler
,
)
...
...
@@ -169,7 +206,7 @@ class SamplerDPMPP_SDE:
}
}
RETURN_TYPES
=
(
"SAMPLER"
,)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/samplers
"
FUNCTION
=
"get_sampler"
...
...
@@ -178,7 +215,7 @@ class SamplerDPMPP_SDE:
sampler_name
=
"dpmpp_sde"
else
:
sampler_name
=
"dpmpp_sde_gpu"
sampler
=
comfy
.
samplers
.
ksampler
(
sampler_name
,
{
"eta"
:
eta
,
"s_noise"
:
s_noise
,
"r"
:
r
})
()
sampler
=
comfy
.
samplers
.
ksampler
(
sampler_name
,
{
"eta"
:
eta
,
"s_noise"
:
s_noise
,
"r"
:
r
})
return
(
sampler
,
)
class
SamplerCustom
:
...
...
@@ -234,13 +271,15 @@ class SamplerCustom:
NODE_CLASS_MAPPINGS
=
{
"SamplerCustom"
:
SamplerCustom
,
"BasicScheduler"
:
BasicScheduler
,
"KarrasScheduler"
:
KarrasScheduler
,
"ExponentialScheduler"
:
ExponentialScheduler
,
"PolyexponentialScheduler"
:
PolyexponentialScheduler
,
"VPScheduler"
:
VPScheduler
,
"SDTurboScheduler"
:
SDTurboScheduler
,
"KSamplerSelect"
:
KSamplerSelect
,
"SamplerDPMPP_2M_SDE"
:
SamplerDPMPP_2M_SDE
,
"SamplerDPMPP_SDE"
:
SamplerDPMPP_SDE
,
"BasicScheduler"
:
BasicScheduler
,
"SplitSigmas"
:
SplitSigmas
,
"FlipSigmas"
:
FlipSigmas
,
}
comfy_extras/nodes_images.py
0 → 100644
View file @
c92f3dca
import
nodes
import
folder_paths
from
comfy.cli_args
import
args
from
PIL
import
Image
from
PIL.PngImagePlugin
import
PngInfo
import
numpy
as
np
import
json
import
os
MAX_RESOLUTION
=
nodes
.
MAX_RESOLUTION
class
ImageCrop
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"image"
:
(
"IMAGE"
,),
"width"
:
(
"INT"
,
{
"default"
:
512
,
"min"
:
1
,
"max"
:
MAX_RESOLUTION
,
"step"
:
1
}),
"height"
:
(
"INT"
,
{
"default"
:
512
,
"min"
:
1
,
"max"
:
MAX_RESOLUTION
,
"step"
:
1
}),
"x"
:
(
"INT"
,
{
"default"
:
0
,
"min"
:
0
,
"max"
:
MAX_RESOLUTION
,
"step"
:
1
}),
"y"
:
(
"INT"
,
{
"default"
:
0
,
"min"
:
0
,
"max"
:
MAX_RESOLUTION
,
"step"
:
1
}),
}}
RETURN_TYPES
=
(
"IMAGE"
,)
FUNCTION
=
"crop"
CATEGORY
=
"image/transform"
def
crop
(
self
,
image
,
width
,
height
,
x
,
y
):
x
=
min
(
x
,
image
.
shape
[
2
]
-
1
)
y
=
min
(
y
,
image
.
shape
[
1
]
-
1
)
to_x
=
width
+
x
to_y
=
height
+
y
img
=
image
[:,
y
:
to_y
,
x
:
to_x
,
:]
return
(
img
,)
class
RepeatImageBatch
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"image"
:
(
"IMAGE"
,),
"amount"
:
(
"INT"
,
{
"default"
:
1
,
"min"
:
1
,
"max"
:
64
}),
}}
RETURN_TYPES
=
(
"IMAGE"
,)
FUNCTION
=
"repeat"
CATEGORY
=
"image/batch"
def
repeat
(
self
,
image
,
amount
):
s
=
image
.
repeat
((
amount
,
1
,
1
,
1
))
return
(
s
,)
class
SaveAnimatedWEBP
:
def
__init__
(
self
):
self
.
output_dir
=
folder_paths
.
get_output_directory
()
self
.
type
=
"output"
self
.
prefix_append
=
""
methods
=
{
"default"
:
4
,
"fastest"
:
0
,
"slowest"
:
6
}
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"images"
:
(
"IMAGE"
,
),
"filename_prefix"
:
(
"STRING"
,
{
"default"
:
"ComfyUI"
}),
"fps"
:
(
"FLOAT"
,
{
"default"
:
6.0
,
"min"
:
0.01
,
"max"
:
1000.0
,
"step"
:
0.01
}),
"lossless"
:
(
"BOOLEAN"
,
{
"default"
:
True
}),
"quality"
:
(
"INT"
,
{
"default"
:
80
,
"min"
:
0
,
"max"
:
100
}),
"method"
:
(
list
(
s
.
methods
.
keys
()),),
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
},
"hidden"
:
{
"prompt"
:
"PROMPT"
,
"extra_pnginfo"
:
"EXTRA_PNGINFO"
},
}
RETURN_TYPES
=
()
FUNCTION
=
"save_images"
OUTPUT_NODE
=
True
CATEGORY
=
"_for_testing"
def
save_images
(
self
,
images
,
fps
,
filename_prefix
,
lossless
,
quality
,
method
,
num_frames
=
0
,
prompt
=
None
,
extra_pnginfo
=
None
):
method
=
self
.
methods
.
get
(
method
)
filename_prefix
+=
self
.
prefix_append
full_output_folder
,
filename
,
counter
,
subfolder
,
filename_prefix
=
folder_paths
.
get_save_image_path
(
filename_prefix
,
self
.
output_dir
,
images
[
0
].
shape
[
1
],
images
[
0
].
shape
[
0
])
results
=
list
()
pil_images
=
[]
for
image
in
images
:
i
=
255.
*
image
.
cpu
().
numpy
()
img
=
Image
.
fromarray
(
np
.
clip
(
i
,
0
,
255
).
astype
(
np
.
uint8
))
pil_images
.
append
(
img
)
metadata
=
pil_images
[
0
].
getexif
()
if
not
args
.
disable_metadata
:
if
prompt
is
not
None
:
metadata
[
0x0110
]
=
"prompt:{}"
.
format
(
json
.
dumps
(
prompt
))
if
extra_pnginfo
is
not
None
:
inital_exif
=
0x010f
for
x
in
extra_pnginfo
:
metadata
[
inital_exif
]
=
"{}:{}"
.
format
(
x
,
json
.
dumps
(
extra_pnginfo
[
x
]))
inital_exif
-=
1
if
num_frames
==
0
:
num_frames
=
len
(
pil_images
)
c
=
len
(
pil_images
)
for
i
in
range
(
0
,
c
,
num_frames
):
file
=
f
"
{
filename
}
_
{
counter
:
05
}
_.webp"
pil_images
[
i
].
save
(
os
.
path
.
join
(
full_output_folder
,
file
),
save_all
=
True
,
duration
=
int
(
1000.0
/
fps
),
append_images
=
pil_images
[
i
+
1
:
i
+
num_frames
],
exif
=
metadata
,
lossless
=
lossless
,
quality
=
quality
,
method
=
method
)
results
.
append
({
"filename"
:
file
,
"subfolder"
:
subfolder
,
"type"
:
self
.
type
})
counter
+=
1
animated
=
num_frames
!=
1
return
{
"ui"
:
{
"images"
:
results
,
"animated"
:
(
animated
,)
}
}
class
SaveAnimatedPNG
:
def
__init__
(
self
):
self
.
output_dir
=
folder_paths
.
get_output_directory
()
self
.
type
=
"output"
self
.
prefix_append
=
""
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"images"
:
(
"IMAGE"
,
),
"filename_prefix"
:
(
"STRING"
,
{
"default"
:
"ComfyUI"
}),
"fps"
:
(
"FLOAT"
,
{
"default"
:
6.0
,
"min"
:
0.01
,
"max"
:
1000.0
,
"step"
:
0.01
}),
"compress_level"
:
(
"INT"
,
{
"default"
:
4
,
"min"
:
0
,
"max"
:
9
})
},
"hidden"
:
{
"prompt"
:
"PROMPT"
,
"extra_pnginfo"
:
"EXTRA_PNGINFO"
},
}
RETURN_TYPES
=
()
FUNCTION
=
"save_images"
OUTPUT_NODE
=
True
CATEGORY
=
"_for_testing"
def
save_images
(
self
,
images
,
fps
,
compress_level
,
filename_prefix
=
"ComfyUI"
,
prompt
=
None
,
extra_pnginfo
=
None
):
filename_prefix
+=
self
.
prefix_append
full_output_folder
,
filename
,
counter
,
subfolder
,
filename_prefix
=
folder_paths
.
get_save_image_path
(
filename_prefix
,
self
.
output_dir
,
images
[
0
].
shape
[
1
],
images
[
0
].
shape
[
0
])
results
=
list
()
pil_images
=
[]
for
image
in
images
:
i
=
255.
*
image
.
cpu
().
numpy
()
img
=
Image
.
fromarray
(
np
.
clip
(
i
,
0
,
255
).
astype
(
np
.
uint8
))
pil_images
.
append
(
img
)
metadata
=
None
if
not
args
.
disable_metadata
:
metadata
=
PngInfo
()
if
prompt
is
not
None
:
metadata
.
add
(
b
"comf"
,
"prompt"
.
encode
(
"latin-1"
,
"strict"
)
+
b
"
\0
"
+
json
.
dumps
(
prompt
).
encode
(
"latin-1"
,
"strict"
),
after_idat
=
True
)
if
extra_pnginfo
is
not
None
:
for
x
in
extra_pnginfo
:
metadata
.
add
(
b
"comf"
,
x
.
encode
(
"latin-1"
,
"strict"
)
+
b
"
\0
"
+
json
.
dumps
(
extra_pnginfo
[
x
]).
encode
(
"latin-1"
,
"strict"
),
after_idat
=
True
)
file
=
f
"
{
filename
}
_
{
counter
:
05
}
_.png"
pil_images
[
0
].
save
(
os
.
path
.
join
(
full_output_folder
,
file
),
pnginfo
=
metadata
,
compress_level
=
compress_level
,
save_all
=
True
,
duration
=
int
(
1000.0
/
fps
),
append_images
=
pil_images
[
1
:])
results
.
append
({
"filename"
:
file
,
"subfolder"
:
subfolder
,
"type"
:
self
.
type
})
return
{
"ui"
:
{
"images"
:
results
,
"animated"
:
(
True
,)}
}
NODE_CLASS_MAPPINGS
=
{
"ImageCrop"
:
ImageCrop
,
"RepeatImageBatch"
:
RepeatImageBatch
,
"SaveAnimatedWEBP"
:
SaveAnimatedWEBP
,
"SaveAnimatedPNG"
:
SaveAnimatedPNG
,
}
comfy_extras/nodes_latent.py
View file @
c92f3dca
import
comfy.utils
import
torch
def
reshape_latent_to
(
target_shape
,
latent
):
if
latent
.
shape
[
1
:]
!=
target_shape
[
1
:]:
...
...
@@ -67,8 +68,43 @@ class LatentMultiply:
samples_out
[
"samples"
]
=
s1
*
multiplier
return
(
samples_out
,)
class
LatentInterpolate
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"samples1"
:
(
"LATENT"
,),
"samples2"
:
(
"LATENT"
,),
"ratio"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.01
}),
}}
RETURN_TYPES
=
(
"LATENT"
,)
FUNCTION
=
"op"
CATEGORY
=
"latent/advanced"
def
op
(
self
,
samples1
,
samples2
,
ratio
):
samples_out
=
samples1
.
copy
()
s1
=
samples1
[
"samples"
]
s2
=
samples2
[
"samples"
]
s2
=
reshape_latent_to
(
s1
.
shape
,
s2
)
m1
=
torch
.
linalg
.
vector_norm
(
s1
,
dim
=
(
1
))
m2
=
torch
.
linalg
.
vector_norm
(
s2
,
dim
=
(
1
))
s1
=
torch
.
nan_to_num
(
s1
/
m1
)
s2
=
torch
.
nan_to_num
(
s2
/
m2
)
t
=
(
s1
*
ratio
+
s2
*
(
1.0
-
ratio
))
mt
=
torch
.
linalg
.
vector_norm
(
t
,
dim
=
(
1
))
st
=
torch
.
nan_to_num
(
t
/
mt
)
samples_out
[
"samples"
]
=
st
*
(
m1
*
ratio
+
m2
*
(
1.0
-
ratio
))
return
(
samples_out
,)
NODE_CLASS_MAPPINGS
=
{
"LatentAdd"
:
LatentAdd
,
"LatentSubtract"
:
LatentSubtract
,
"LatentMultiply"
:
LatentMultiply
,
"LatentInterpolate"
:
LatentInterpolate
,
}
comfy_extras/nodes_model_advanced.py
View file @
c92f3dca
...
...
@@ -17,7 +17,9 @@ class LCM(comfy.model_sampling.EPS):
return
c_out
*
x0
+
c_skip
*
model_input
class
ModelSamplingDiscreteLCM
(
torch
.
nn
.
Module
):
class
ModelSamplingDiscreteDistilled
(
torch
.
nn
.
Module
):
original_timesteps
=
50
def
__init__
(
self
):
super
().
__init__
()
self
.
sigma_data
=
1.0
...
...
@@ -29,13 +31,12 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
alphas
=
1.0
-
betas
alphas_cumprod
=
torch
.
cumprod
(
alphas
,
dim
=
0
)
original_timesteps
=
50
self
.
skip_steps
=
timesteps
//
original_timesteps
self
.
skip_steps
=
timesteps
//
self
.
original_timesteps
alphas_cumprod_valid
=
torch
.
zeros
((
original_timesteps
),
dtype
=
torch
.
float32
)
for
x
in
range
(
original_timesteps
):
alphas_cumprod_valid
[
original_timesteps
-
1
-
x
]
=
alphas_cumprod
[
timesteps
-
1
-
x
*
self
.
skip_steps
]
alphas_cumprod_valid
=
torch
.
zeros
((
self
.
original_timesteps
),
dtype
=
torch
.
float32
)
for
x
in
range
(
self
.
original_timesteps
):
alphas_cumprod_valid
[
self
.
original_timesteps
-
1
-
x
]
=
alphas_cumprod
[
timesteps
-
1
-
x
*
self
.
skip_steps
]
sigmas
=
((
1
-
alphas_cumprod_valid
)
/
alphas_cumprod_valid
)
**
0.5
self
.
set_sigmas
(
sigmas
)
...
...
@@ -55,18 +56,23 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
def
timestep
(
self
,
sigma
):
log_sigma
=
sigma
.
log
()
dists
=
log_sigma
.
to
(
self
.
log_sigmas
.
device
)
-
self
.
log_sigmas
[:,
None
]
return
dists
.
abs
().
argmin
(
dim
=
0
).
view
(
sigma
.
shape
)
*
self
.
skip_steps
+
(
self
.
skip_steps
-
1
)
return
(
dists
.
abs
().
argmin
(
dim
=
0
).
view
(
sigma
.
shape
)
*
self
.
skip_steps
+
(
self
.
skip_steps
-
1
)
).
to
(
sigma
.
device
)
def
sigma
(
self
,
timestep
):
t
=
torch
.
clamp
(((
timestep
-
(
self
.
skip_steps
-
1
))
/
self
.
skip_steps
).
float
(),
min
=
0
,
max
=
(
len
(
self
.
sigmas
)
-
1
))
t
=
torch
.
clamp
(((
timestep
.
float
().
to
(
self
.
log_sigmas
.
device
)
-
(
self
.
skip_steps
-
1
))
/
self
.
skip_steps
).
float
(),
min
=
0
,
max
=
(
len
(
self
.
sigmas
)
-
1
))
low_idx
=
t
.
floor
().
long
()
high_idx
=
t
.
ceil
().
long
()
w
=
t
.
frac
()
log_sigma
=
(
1
-
w
)
*
self
.
log_sigmas
[
low_idx
]
+
w
*
self
.
log_sigmas
[
high_idx
]
return
log_sigma
.
exp
()
return
log_sigma
.
exp
()
.
to
(
timestep
.
device
)
def
percent_to_sigma
(
self
,
percent
):
return
self
.
sigma
(
torch
.
tensor
(
percent
*
999.0
))
if
percent
<=
0.0
:
return
999999999.9
if
percent
>=
1.0
:
return
0.0
percent
=
1.0
-
percent
return
self
.
sigma
(
torch
.
tensor
(
percent
*
999.0
)).
item
()
def
rescale_zero_terminal_snr_sigmas
(
sigmas
):
...
...
@@ -111,7 +117,7 @@ class ModelSamplingDiscrete:
sampling_type
=
comfy
.
model_sampling
.
V_PREDICTION
elif
sampling
==
"lcm"
:
sampling_type
=
LCM
sampling_base
=
ModelSamplingDiscrete
LCM
sampling_base
=
ModelSamplingDiscrete
Distilled
class
ModelSamplingAdvanced
(
sampling_base
,
sampling_type
):
pass
...
...
@@ -123,6 +129,36 @@ class ModelSamplingDiscrete:
m
.
add_object_patch
(
"model_sampling"
,
model_sampling
)
return
(
m
,
)
class
ModelSamplingContinuousEDM
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"model"
:
(
"MODEL"
,),
"sampling"
:
([
"v_prediction"
,
"eps"
],),
"sigma_max"
:
(
"FLOAT"
,
{
"default"
:
120.0
,
"min"
:
0.0
,
"max"
:
1000.0
,
"step"
:
0.001
,
"round"
:
False
}),
"sigma_min"
:
(
"FLOAT"
,
{
"default"
:
0.002
,
"min"
:
0.0
,
"max"
:
1000.0
,
"step"
:
0.001
,
"round"
:
False
}),
}}
RETURN_TYPES
=
(
"MODEL"
,)
FUNCTION
=
"patch"
CATEGORY
=
"advanced/model"
def
patch
(
self
,
model
,
sampling
,
sigma_max
,
sigma_min
):
m
=
model
.
clone
()
if
sampling
==
"eps"
:
sampling_type
=
comfy
.
model_sampling
.
EPS
elif
sampling
==
"v_prediction"
:
sampling_type
=
comfy
.
model_sampling
.
V_PREDICTION
class
ModelSamplingAdvanced
(
comfy
.
model_sampling
.
ModelSamplingContinuousEDM
,
sampling_type
):
pass
model_sampling
=
ModelSamplingAdvanced
()
model_sampling
.
set_sigma_range
(
sigma_min
,
sigma_max
)
m
.
add_object_patch
(
"model_sampling"
,
model_sampling
)
return
(
m
,
)
class
RescaleCFG
:
@
classmethod
def
INPUT_TYPES
(
s
):
...
...
@@ -164,5 +200,6 @@ class RescaleCFG:
NODE_CLASS_MAPPINGS
=
{
"ModelSamplingDiscrete"
:
ModelSamplingDiscrete
,
"ModelSamplingContinuousEDM"
:
ModelSamplingContinuousEDM
,
"RescaleCFG"
:
RescaleCFG
,
}
comfy_extras/nodes_model_downscale.py
0 → 100644
View file @
c92f3dca
import
torch
import
comfy.utils
class
PatchModelAddDownscale
:
upscale_methods
=
[
"bicubic"
,
"nearest-exact"
,
"bilinear"
,
"area"
,
"bislerp"
]
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"model"
:
(
"MODEL"
,),
"block_number"
:
(
"INT"
,
{
"default"
:
3
,
"min"
:
1
,
"max"
:
32
,
"step"
:
1
}),
"downscale_factor"
:
(
"FLOAT"
,
{
"default"
:
2.0
,
"min"
:
0.1
,
"max"
:
9.0
,
"step"
:
0.001
}),
"start_percent"
:
(
"FLOAT"
,
{
"default"
:
0.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.001
}),
"end_percent"
:
(
"FLOAT"
,
{
"default"
:
0.35
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.001
}),
"downscale_after_skip"
:
(
"BOOLEAN"
,
{
"default"
:
True
}),
"downscale_method"
:
(
s
.
upscale_methods
,),
"upscale_method"
:
(
s
.
upscale_methods
,),
}}
RETURN_TYPES
=
(
"MODEL"
,)
FUNCTION
=
"patch"
CATEGORY
=
"_for_testing"
def
patch
(
self
,
model
,
block_number
,
downscale_factor
,
start_percent
,
end_percent
,
downscale_after_skip
,
downscale_method
,
upscale_method
):
sigma_start
=
model
.
model
.
model_sampling
.
percent_to_sigma
(
start_percent
)
sigma_end
=
model
.
model
.
model_sampling
.
percent_to_sigma
(
end_percent
)
def
input_block_patch
(
h
,
transformer_options
):
if
transformer_options
[
"block"
][
1
]
==
block_number
:
sigma
=
transformer_options
[
"sigmas"
][
0
].
item
()
if
sigma
<=
sigma_start
and
sigma
>=
sigma_end
:
h
=
comfy
.
utils
.
common_upscale
(
h
,
round
(
h
.
shape
[
-
1
]
*
(
1.0
/
downscale_factor
)),
round
(
h
.
shape
[
-
2
]
*
(
1.0
/
downscale_factor
)),
downscale_method
,
"disabled"
)
return
h
def
output_block_patch
(
h
,
hsp
,
transformer_options
):
if
h
.
shape
[
2
]
!=
hsp
.
shape
[
2
]:
h
=
comfy
.
utils
.
common_upscale
(
h
,
hsp
.
shape
[
-
1
],
hsp
.
shape
[
-
2
],
upscale_method
,
"disabled"
)
return
h
,
hsp
m
=
model
.
clone
()
if
downscale_after_skip
:
m
.
set_model_input_block_patch_after_skip
(
input_block_patch
)
else
:
m
.
set_model_input_block_patch
(
input_block_patch
)
m
.
set_model_output_block_patch
(
output_block_patch
)
return
(
m
,
)
NODE_CLASS_MAPPINGS
=
{
"PatchModelAddDownscale"
:
PatchModelAddDownscale
,
}
NODE_DISPLAY_NAME_MAPPINGS
=
{
# Sampling
"PatchModelAddDownscale"
:
"PatchModelAddDownscale (Kohya Deep Shrink)"
,
}
comfy_extras/nodes_video_model.py
0 → 100644
View file @
c92f3dca
import
nodes
import
torch
import
comfy.utils
import
comfy.sd
import
folder_paths
class
ImageOnlyCheckpointLoader
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"ckpt_name"
:
(
folder_paths
.
get_filename_list
(
"checkpoints"
),
),
}}
RETURN_TYPES
=
(
"MODEL"
,
"CLIP_VISION"
,
"VAE"
)
FUNCTION
=
"load_checkpoint"
CATEGORY
=
"loaders/video_models"
def
load_checkpoint
(
self
,
ckpt_name
,
output_vae
=
True
,
output_clip
=
True
):
ckpt_path
=
folder_paths
.
get_full_path
(
"checkpoints"
,
ckpt_name
)
out
=
comfy
.
sd
.
load_checkpoint_guess_config
(
ckpt_path
,
output_vae
=
True
,
output_clip
=
False
,
output_clipvision
=
True
,
embedding_directory
=
folder_paths
.
get_folder_paths
(
"embeddings"
))
return
(
out
[
0
],
out
[
3
],
out
[
2
])
class
SVD_img2vid_Conditioning
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"clip_vision"
:
(
"CLIP_VISION"
,),
"init_image"
:
(
"IMAGE"
,),
"vae"
:
(
"VAE"
,),
"width"
:
(
"INT"
,
{
"default"
:
1024
,
"min"
:
16
,
"max"
:
nodes
.
MAX_RESOLUTION
,
"step"
:
8
}),
"height"
:
(
"INT"
,
{
"default"
:
576
,
"min"
:
16
,
"max"
:
nodes
.
MAX_RESOLUTION
,
"step"
:
8
}),
"video_frames"
:
(
"INT"
,
{
"default"
:
14
,
"min"
:
1
,
"max"
:
4096
}),
"motion_bucket_id"
:
(
"INT"
,
{
"default"
:
127
,
"min"
:
1
,
"max"
:
1023
}),
"fps"
:
(
"INT"
,
{
"default"
:
6
,
"min"
:
1
,
"max"
:
1024
}),
"augmentation_level"
:
(
"FLOAT"
,
{
"default"
:
0.0
,
"min"
:
0.0
,
"max"
:
10.0
,
"step"
:
0.01
})
}}
RETURN_TYPES
=
(
"CONDITIONING"
,
"CONDITIONING"
,
"LATENT"
)
RETURN_NAMES
=
(
"positive"
,
"negative"
,
"latent"
)
FUNCTION
=
"encode"
CATEGORY
=
"conditioning/video_models"
def
encode
(
self
,
clip_vision
,
init_image
,
vae
,
width
,
height
,
video_frames
,
motion_bucket_id
,
fps
,
augmentation_level
):
output
=
clip_vision
.
encode_image
(
init_image
)
pooled
=
output
.
image_embeds
.
unsqueeze
(
0
)
pixels
=
comfy
.
utils
.
common_upscale
(
init_image
.
movedim
(
-
1
,
1
),
width
,
height
,
"bilinear"
,
"center"
).
movedim
(
1
,
-
1
)
encode_pixels
=
pixels
[:,:,:,:
3
]
if
augmentation_level
>
0
:
encode_pixels
+=
torch
.
randn_like
(
pixels
)
*
augmentation_level
t
=
vae
.
encode
(
encode_pixels
)
positive
=
[[
pooled
,
{
"motion_bucket_id"
:
motion_bucket_id
,
"fps"
:
fps
,
"augmentation_level"
:
augmentation_level
,
"concat_latent_image"
:
t
}]]
negative
=
[[
torch
.
zeros_like
(
pooled
),
{
"motion_bucket_id"
:
motion_bucket_id
,
"fps"
:
fps
,
"augmentation_level"
:
augmentation_level
,
"concat_latent_image"
:
torch
.
zeros_like
(
t
)}]]
latent
=
torch
.
zeros
([
video_frames
,
4
,
height
//
8
,
width
//
8
])
return
(
positive
,
negative
,
{
"samples"
:
latent
})
class
VideoLinearCFGGuidance
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"model"
:
(
"MODEL"
,),
"min_cfg"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
100.0
,
"step"
:
0.5
,
"round"
:
0.01
}),
}}
RETURN_TYPES
=
(
"MODEL"
,)
FUNCTION
=
"patch"
CATEGORY
=
"sampling/video_models"
def
patch
(
self
,
model
,
min_cfg
):
def
linear_cfg
(
args
):
cond
=
args
[
"cond"
]
uncond
=
args
[
"uncond"
]
cond_scale
=
args
[
"cond_scale"
]
scale
=
torch
.
linspace
(
min_cfg
,
cond_scale
,
cond
.
shape
[
0
],
device
=
cond
.
device
).
reshape
((
cond
.
shape
[
0
],
1
,
1
,
1
))
return
uncond
+
scale
*
(
cond
-
uncond
)
m
=
model
.
clone
()
m
.
set_model_sampler_cfg_function
(
linear_cfg
)
return
(
m
,
)
NODE_CLASS_MAPPINGS
=
{
"ImageOnlyCheckpointLoader"
:
ImageOnlyCheckpointLoader
,
"SVD_img2vid_Conditioning"
:
SVD_img2vid_Conditioning
,
"VideoLinearCFGGuidance"
:
VideoLinearCFGGuidance
,
}
NODE_DISPLAY_NAME_MAPPINGS
=
{
"ImageOnlyCheckpointLoader"
:
"Image Only Checkpoint Loader (img2vid model)"
,
}
execution.py
View file @
c92f3dca
...
...
@@ -681,6 +681,7 @@ def validate_prompt(prompt):
return
(
True
,
None
,
list
(
good_outputs
),
node_errors
)
MAXIMUM_HISTORY_SIZE
=
10000
class
PromptQueue
:
def
__init__
(
self
,
server
):
...
...
@@ -699,10 +700,12 @@ class PromptQueue:
self
.
server
.
queue_updated
()
self
.
not_empty
.
notify
()
def
get
(
self
):
def
get
(
self
,
timeout
=
None
):
with
self
.
not_empty
:
while
len
(
self
.
queue
)
==
0
:
self
.
not_empty
.
wait
()
self
.
not_empty
.
wait
(
timeout
=
timeout
)
if
timeout
is
not
None
and
len
(
self
.
queue
)
==
0
:
return
None
item
=
heapq
.
heappop
(
self
.
queue
)
i
=
self
.
task_counter
self
.
currently_running
[
i
]
=
copy
.
deepcopy
(
item
)
...
...
@@ -713,6 +716,8 @@ class PromptQueue:
def
task_done
(
self
,
item_id
,
outputs
):
with
self
.
mutex
:
prompt
=
self
.
currently_running
.
pop
(
item_id
)
if
len
(
self
.
history
)
>
MAXIMUM_HISTORY_SIZE
:
self
.
history
.
pop
(
next
(
iter
(
self
.
history
)))
self
.
history
[
prompt
[
1
]]
=
{
"prompt"
:
prompt
,
"outputs"
:
{}
}
for
o
in
outputs
:
self
.
history
[
prompt
[
1
]][
"outputs"
][
o
]
=
outputs
[
o
]
...
...
@@ -747,10 +752,20 @@ class PromptQueue:
return
True
return
False
def
get_history
(
self
,
prompt_id
=
None
):
def
get_history
(
self
,
prompt_id
=
None
,
max_items
=
None
,
offset
=-
1
):
with
self
.
mutex
:
if
prompt_id
is
None
:
return
copy
.
deepcopy
(
self
.
history
)
out
=
{}
i
=
0
if
offset
<
0
and
max_items
is
not
None
:
offset
=
len
(
self
.
history
)
-
max_items
for
k
in
self
.
history
:
if
i
>=
offset
:
out
[
k
]
=
self
.
history
[
k
]
if
max_items
is
not
None
and
len
(
out
)
>=
max_items
:
break
i
+=
1
return
out
elif
prompt_id
in
self
.
history
:
return
{
prompt_id
:
copy
.
deepcopy
(
self
.
history
[
prompt_id
])}
else
:
...
...
folder_paths.py
View file @
c92f3dca
...
...
@@ -38,7 +38,10 @@ input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "inp
filename_list_cache
=
{}
if
not
os
.
path
.
exists
(
input_directory
):
os
.
makedirs
(
input_directory
)
try
:
os
.
makedirs
(
input_directory
)
except
:
print
(
"Failed to create input directory"
)
def
set_output_directory
(
output_dir
):
global
output_directory
...
...
@@ -228,8 +231,12 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
full_output_folder
=
os
.
path
.
join
(
output_dir
,
subfolder
)
if
os
.
path
.
commonpath
((
output_dir
,
os
.
path
.
abspath
(
full_output_folder
)))
!=
output_dir
:
print
(
"Saving image outside the output folder is not allowed."
)
return
{}
err
=
"**** ERROR: Saving image outside the output folder is not allowed."
+
\
"
\n
full_output_folder: "
+
os
.
path
.
abspath
(
full_output_folder
)
+
\
"
\n
output_dir: "
+
output_dir
+
\
"
\n
commonpath: "
+
os
.
path
.
commonpath
((
output_dir
,
os
.
path
.
abspath
(
full_output_folder
)))
print
(
err
)
raise
Exception
(
err
)
try
:
counter
=
max
(
filter
(
lambda
a
:
a
[
1
][:
-
1
]
==
filename
and
a
[
1
][
-
1
]
==
"_"
,
map
(
map_filename
,
os
.
listdir
(
full_output_folder
))))[
0
]
+
1
...
...
latent_preview.py
View file @
c92f3dca
...
...
@@ -22,10 +22,7 @@ class TAESDPreviewerImpl(LatentPreviewer):
self
.
taesd
=
taesd
def
decode_latent_to_preview
(
self
,
x0
):
x_sample
=
self
.
taesd
.
decoder
(
x0
[:
1
])[
0
].
detach
()
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
x_sample
=
x_sample
.
sub
(
0.5
).
mul
(
2
)
x_sample
=
self
.
taesd
.
decode
(
x0
[:
1
])[
0
].
detach
()
x_sample
=
torch
.
clamp
((
x_sample
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
x_sample
=
255.
*
np
.
moveaxis
(
x_sample
.
cpu
().
numpy
(),
0
,
2
)
x_sample
=
x_sample
.
astype
(
np
.
uint8
)
...
...
main.py
View file @
c92f3dca
...
...
@@ -88,18 +88,37 @@ def cuda_malloc_warning():
def
prompt_worker
(
q
,
server
):
e
=
execution
.
PromptExecutor
(
server
)
last_gc_collect
=
0
need_gc
=
False
gc_collect_interval
=
10.0
while
True
:
item
,
item_id
=
q
.
get
()
execution_start_time
=
time
.
perf_counter
()
prompt_id
=
item
[
1
]
e
.
execute
(
item
[
2
],
prompt_id
,
item
[
3
],
item
[
4
])
q
.
task_done
(
item_id
,
e
.
outputs_ui
)
if
server
.
client_id
is
not
None
:
server
.
send_sync
(
"executing"
,
{
"node"
:
None
,
"prompt_id"
:
prompt_id
},
server
.
client_id
)
print
(
"Prompt executed in {:.2f} seconds"
.
format
(
time
.
perf_counter
()
-
execution_start_time
))
gc
.
collect
()
comfy
.
model_management
.
soft_empty_cache
()
timeout
=
None
if
need_gc
:
timeout
=
max
(
gc_collect_interval
-
(
current_time
-
last_gc_collect
),
0.0
)
queue_item
=
q
.
get
(
timeout
=
timeout
)
if
queue_item
is
not
None
:
item
,
item_id
=
queue_item
execution_start_time
=
time
.
perf_counter
()
prompt_id
=
item
[
1
]
e
.
execute
(
item
[
2
],
prompt_id
,
item
[
3
],
item
[
4
])
need_gc
=
True
q
.
task_done
(
item_id
,
e
.
outputs_ui
)
if
server
.
client_id
is
not
None
:
server
.
send_sync
(
"executing"
,
{
"node"
:
None
,
"prompt_id"
:
prompt_id
},
server
.
client_id
)
current_time
=
time
.
perf_counter
()
execution_time
=
current_time
-
execution_start_time
print
(
"Prompt executed in {:.2f} seconds"
.
format
(
execution_time
))
if
need_gc
:
current_time
=
time
.
perf_counter
()
if
(
current_time
-
last_gc_collect
)
>
gc_collect_interval
:
gc
.
collect
()
comfy
.
model_management
.
soft_empty_cache
()
last_gc_collect
=
current_time
need_gc
=
False
async
def
run
(
server
,
address
=
''
,
port
=
8188
,
verbose
=
True
,
call_on_start
=
None
):
await
asyncio
.
gather
(
server
.
start
(
address
,
port
,
verbose
,
call_on_start
),
server
.
publish_loop
())
...
...
nodes.py
View file @
c92f3dca
...
...
@@ -248,8 +248,8 @@ class ConditioningSetTimestepRange:
c
=
[]
for
t
in
conditioning
:
d
=
t
[
1
].
copy
()
d
[
'start_percent'
]
=
1.0
-
start
d
[
'end_percent'
]
=
1.0
-
end
d
[
'start_percent'
]
=
start
d
[
'end_percent'
]
=
end
n
=
[
t
[
0
],
d
]
c
.
append
(
n
)
return
(
c
,
)
...
...
@@ -572,10 +572,69 @@ class LoraLoader:
model_lora
,
clip_lora
=
comfy
.
sd
.
load_lora_for_models
(
model
,
clip
,
lora
,
strength_model
,
strength_clip
)
return
(
model_lora
,
clip_lora
)
class
LoraLoaderModelOnly
(
LoraLoader
):
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"model"
:
(
"MODEL"
,),
"lora_name"
:
(
folder_paths
.
get_filename_list
(
"loras"
),
),
"strength_model"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
-
20.0
,
"max"
:
20.0
,
"step"
:
0.01
}),
}}
RETURN_TYPES
=
(
"MODEL"
,)
FUNCTION
=
"load_lora_model_only"
def
load_lora_model_only
(
self
,
model
,
lora_name
,
strength_model
):
return
(
self
.
load_lora
(
model
,
None
,
lora_name
,
strength_model
,
0
)[
0
],)
class
VAELoader
:
@
staticmethod
def
vae_list
():
vaes
=
folder_paths
.
get_filename_list
(
"vae"
)
approx_vaes
=
folder_paths
.
get_filename_list
(
"vae_approx"
)
sdxl_taesd_enc
=
False
sdxl_taesd_dec
=
False
sd1_taesd_enc
=
False
sd1_taesd_dec
=
False
for
v
in
approx_vaes
:
if
v
.
startswith
(
"taesd_decoder."
):
sd1_taesd_dec
=
True
elif
v
.
startswith
(
"taesd_encoder."
):
sd1_taesd_enc
=
True
elif
v
.
startswith
(
"taesdxl_decoder."
):
sdxl_taesd_dec
=
True
elif
v
.
startswith
(
"taesdxl_encoder."
):
sdxl_taesd_enc
=
True
if
sd1_taesd_dec
and
sd1_taesd_enc
:
vaes
.
append
(
"taesd"
)
if
sdxl_taesd_dec
and
sdxl_taesd_enc
:
vaes
.
append
(
"taesdxl"
)
return
vaes
@
staticmethod
def
load_taesd
(
name
):
sd
=
{}
approx_vaes
=
folder_paths
.
get_filename_list
(
"vae_approx"
)
encoder
=
next
(
filter
(
lambda
a
:
a
.
startswith
(
"{}_encoder."
.
format
(
name
)),
approx_vaes
))
decoder
=
next
(
filter
(
lambda
a
:
a
.
startswith
(
"{}_decoder."
.
format
(
name
)),
approx_vaes
))
enc
=
comfy
.
utils
.
load_torch_file
(
folder_paths
.
get_full_path
(
"vae_approx"
,
encoder
))
for
k
in
enc
:
sd
[
"taesd_encoder.{}"
.
format
(
k
)]
=
enc
[
k
]
dec
=
comfy
.
utils
.
load_torch_file
(
folder_paths
.
get_full_path
(
"vae_approx"
,
decoder
))
for
k
in
dec
:
sd
[
"taesd_decoder.{}"
.
format
(
k
)]
=
dec
[
k
]
if
name
==
"taesd"
:
sd
[
"vae_scale"
]
=
torch
.
tensor
(
0.18215
)
elif
name
==
"taesdxl"
:
sd
[
"vae_scale"
]
=
torch
.
tensor
(
0.13025
)
return
sd
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"vae_name"
:
(
folder_paths
.
get_filenam
e_list
(
"vae"
),
)}}
return
{
"required"
:
{
"vae_name"
:
(
s
.
va
e_list
(),
)}}
RETURN_TYPES
=
(
"VAE"
,)
FUNCTION
=
"load_vae"
...
...
@@ -583,8 +642,11 @@ class VAELoader:
#TODO: scale factor?
def
load_vae
(
self
,
vae_name
):
vae_path
=
folder_paths
.
get_full_path
(
"vae"
,
vae_name
)
sd
=
comfy
.
utils
.
load_torch_file
(
vae_path
)
if
vae_name
in
[
"taesd"
,
"taesdxl"
]:
sd
=
self
.
load_taesd
(
vae_name
)
else
:
vae_path
=
folder_paths
.
get_full_path
(
"vae"
,
vae_name
)
sd
=
comfy
.
utils
.
load_torch_file
(
vae_path
)
vae
=
comfy
.
sd
.
VAE
(
sd
=
sd
)
return
(
vae
,)
...
...
@@ -685,7 +747,7 @@ class ControlNetApplyAdvanced:
if
prev_cnet
in
cnets
:
c_net
=
cnets
[
prev_cnet
]
else
:
c_net
=
control_net
.
copy
().
set_cond_hint
(
control_hint
,
strength
,
(
1.0
-
start_percent
,
1.0
-
end_percent
))
c_net
=
control_net
.
copy
().
set_cond_hint
(
control_hint
,
strength
,
(
start_percent
,
end_percent
))
c_net
.
set_previous_controlnet
(
prev_cnet
)
cnets
[
prev_cnet
]
=
c_net
...
...
@@ -1275,6 +1337,7 @@ class SaveImage:
self
.
output_dir
=
folder_paths
.
get_output_directory
()
self
.
type
=
"output"
self
.
prefix_append
=
""
self
.
compress_level
=
4
@
classmethod
def
INPUT_TYPES
(
s
):
...
...
@@ -1308,7 +1371,7 @@ class SaveImage:
metadata
.
add_text
(
x
,
json
.
dumps
(
extra_pnginfo
[
x
]))
file
=
f
"
{
filename
}
_
{
counter
:
05
}
_.png"
img
.
save
(
os
.
path
.
join
(
full_output_folder
,
file
),
pnginfo
=
metadata
,
compress_level
=
4
)
img
.
save
(
os
.
path
.
join
(
full_output_folder
,
file
),
pnginfo
=
metadata
,
compress_level
=
self
.
compress_level
)
results
.
append
({
"filename"
:
file
,
"subfolder"
:
subfolder
,
...
...
@@ -1323,6 +1386,7 @@ class PreviewImage(SaveImage):
self
.
output_dir
=
folder_paths
.
get_temp_directory
()
self
.
type
=
"temp"
self
.
prefix_append
=
"_temp_"
+
''
.
join
(
random
.
choice
(
"abcdefghijklmnopqrstupvxyz"
)
for
x
in
range
(
5
))
self
.
compress_level
=
1
@
classmethod
def
INPUT_TYPES
(
s
):
...
...
@@ -1654,6 +1718,7 @@ NODE_CLASS_MAPPINGS = {
"ConditioningZeroOut"
:
ConditioningZeroOut
,
"ConditioningSetTimestepRange"
:
ConditioningSetTimestepRange
,
"LoraLoaderModelOnly"
:
LoraLoaderModelOnly
,
}
NODE_DISPLAY_NAME_MAPPINGS
=
{
...
...
@@ -1759,7 +1824,7 @@ def load_custom_nodes():
node_paths
=
folder_paths
.
get_folder_paths
(
"custom_nodes"
)
node_import_times
=
[]
for
custom_node_path
in
node_paths
:
possible_modules
=
os
.
listdir
(
custom_node_path
)
possible_modules
=
os
.
listdir
(
os
.
path
.
realpath
(
custom_node_path
)
)
if
"__pycache__"
in
possible_modules
:
possible_modules
.
remove
(
"__pycache__"
)
...
...
@@ -1799,6 +1864,9 @@ def init_custom_nodes():
"nodes_custom_sampler.py"
,
"nodes_hypertile.py"
,
"nodes_model_advanced.py"
,
"nodes_model_downscale.py"
,
"nodes_images.py"
,
"nodes_video_model.py"
,
]
for
node_file
in
extras_files
:
...
...
server.py
View file @
c92f3dca
...
...
@@ -431,7 +431,10 @@ class PromptServer():
@
routes
.
get
(
"/history"
)
async
def
get_history
(
request
):
return
web
.
json_response
(
self
.
prompt_queue
.
get_history
())
max_items
=
request
.
rel_url
.
query
.
get
(
"max_items"
,
None
)
if
max_items
is
not
None
:
max_items
=
int
(
max_items
)
return
web
.
json_response
(
self
.
prompt_queue
.
get_history
(
max_items
=
max_items
))
@
routes
.
get
(
"/history/{prompt_id}"
)
async
def
get_history
(
request
):
...
...
@@ -573,7 +576,7 @@ class PromptServer():
bytesIO
=
BytesIO
()
header
=
struct
.
pack
(
">I"
,
type_num
)
bytesIO
.
write
(
header
)
image
.
save
(
bytesIO
,
format
=
image_type
,
quality
=
95
,
compress_level
=
4
)
image
.
save
(
bytesIO
,
format
=
image_type
,
quality
=
95
,
compress_level
=
1
)
preview_bytes
=
bytesIO
.
getvalue
()
await
self
.
send_bytes
(
BinaryEventTypes
.
PREVIEW_IMAGE
,
preview_bytes
,
sid
=
sid
)
...
...
tests-ui/setup.js
View file @
c92f3dca
...
...
@@ -20,6 +20,7 @@ async function setup() {
// Modify the response data to add some checkpoints
const
objectInfo
=
JSON
.
parse
(
data
);
objectInfo
.
CheckpointLoaderSimple
.
input
.
required
.
ckpt_name
[
0
]
=
[
"
model1.safetensors
"
,
"
model2.ckpt
"
];
objectInfo
.
VAELoader
.
input
.
required
.
vae_name
[
0
]
=
[
"
vae1.safetensors
"
,
"
vae2.ckpt
"
];
data
=
JSON
.
stringify
(
objectInfo
,
undefined
,
"
\t
"
);
...
...
tests-ui/tests/extensions.test.js
0 → 100644
View file @
c92f3dca
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const
{
start
}
=
require
(
"
../utils
"
);
const
lg
=
require
(
"
../utils/litegraph
"
);
describe
(
"
extensions
"
,
()
=>
{
beforeEach
(()
=>
{
lg
.
setup
(
global
);
});
afterEach
(()
=>
{
lg
.
teardown
(
global
);
});
it
(
"
calls each extension hook
"
,
async
()
=>
{
const
mockExtension
=
{
name
:
"
TestExtension
"
,
init
:
jest
.
fn
(),
setup
:
jest
.
fn
(),
addCustomNodeDefs
:
jest
.
fn
(),
getCustomWidgets
:
jest
.
fn
(),
beforeRegisterNodeDef
:
jest
.
fn
(),
registerCustomNodes
:
jest
.
fn
(),
loadedGraphNode
:
jest
.
fn
(),
nodeCreated
:
jest
.
fn
(),
beforeConfigureGraph
:
jest
.
fn
(),
afterConfigureGraph
:
jest
.
fn
(),
};
const
{
app
,
ez
,
graph
}
=
await
start
({
async
preSetup
(
app
)
{
app
.
registerExtension
(
mockExtension
);
},
});
// Basic initialisation hooks should be called once, with app
expect
(
mockExtension
.
init
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
init
).
toHaveBeenCalledWith
(
app
);
// Adding custom node defs should be passed the full list of nodes
expect
(
mockExtension
.
addCustomNodeDefs
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
addCustomNodeDefs
.
mock
.
calls
[
0
][
1
]).
toStrictEqual
(
app
);
const
defs
=
mockExtension
.
addCustomNodeDefs
.
mock
.
calls
[
0
][
0
];
expect
(
defs
).
toHaveProperty
(
"
KSampler
"
);
expect
(
defs
).
toHaveProperty
(
"
LoadImage
"
);
// Get custom widgets is called once and should return new widget types
expect
(
mockExtension
.
getCustomWidgets
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
getCustomWidgets
).
toHaveBeenCalledWith
(
app
);
// Before register node def will be called once per node type
const
nodeNames
=
Object
.
keys
(
defs
);
const
nodeCount
=
nodeNames
.
length
;
expect
(
mockExtension
.
beforeRegisterNodeDef
).
toHaveBeenCalledTimes
(
nodeCount
);
for
(
let
i
=
0
;
i
<
nodeCount
;
i
++
)
{
// It should be send the JS class and the original JSON definition
const
nodeClass
=
mockExtension
.
beforeRegisterNodeDef
.
mock
.
calls
[
i
][
0
];
const
nodeDef
=
mockExtension
.
beforeRegisterNodeDef
.
mock
.
calls
[
i
][
1
];
expect
(
nodeClass
.
name
).
toBe
(
"
ComfyNode
"
);
expect
(
nodeClass
.
comfyClass
).
toBe
(
nodeNames
[
i
]);
expect
(
nodeDef
.
name
).
toBe
(
nodeNames
[
i
]);
expect
(
nodeDef
).
toHaveProperty
(
"
input
"
);
expect
(
nodeDef
).
toHaveProperty
(
"
output
"
);
}
// Register custom nodes is called once after registerNode defs to allow adding other frontend nodes
expect
(
mockExtension
.
registerCustomNodes
).
toHaveBeenCalledTimes
(
1
);
// Before configure graph will be called here as the default graph is being loaded
expect
(
mockExtension
.
beforeConfigureGraph
).
toHaveBeenCalledTimes
(
1
);
// it gets sent the graph data that is going to be loaded
const
graphData
=
mockExtension
.
beforeConfigureGraph
.
mock
.
calls
[
0
][
0
];
// A node created is fired for each node constructor that is called
expect
(
mockExtension
.
nodeCreated
).
toHaveBeenCalledTimes
(
graphData
.
nodes
.
length
);
for
(
let
i
=
0
;
i
<
graphData
.
nodes
.
length
;
i
++
)
{
expect
(
mockExtension
.
nodeCreated
.
mock
.
calls
[
i
][
0
].
type
).
toBe
(
graphData
.
nodes
[
i
].
type
);
}
// Each node then calls loadedGraphNode to allow them to be updated
expect
(
mockExtension
.
loadedGraphNode
).
toHaveBeenCalledTimes
(
graphData
.
nodes
.
length
);
for
(
let
i
=
0
;
i
<
graphData
.
nodes
.
length
;
i
++
)
{
expect
(
mockExtension
.
loadedGraphNode
.
mock
.
calls
[
i
][
0
].
type
).
toBe
(
graphData
.
nodes
[
i
].
type
);
}
// After configure is then called once all the setup is done
expect
(
mockExtension
.
afterConfigureGraph
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
setup
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
setup
).
toHaveBeenCalledWith
(
app
);
// Ensure hooks are called in the correct order
const
callOrder
=
[
"
init
"
,
"
addCustomNodeDefs
"
,
"
getCustomWidgets
"
,
"
beforeRegisterNodeDef
"
,
"
registerCustomNodes
"
,
"
beforeConfigureGraph
"
,
"
nodeCreated
"
,
"
loadedGraphNode
"
,
"
afterConfigureGraph
"
,
"
setup
"
,
];
for
(
let
i
=
1
;
i
<
callOrder
.
length
;
i
++
)
{
const
fn1
=
mockExtension
[
callOrder
[
i
-
1
]];
const
fn2
=
mockExtension
[
callOrder
[
i
]];
expect
(
fn1
.
mock
.
invocationCallOrder
[
0
]).
toBeLessThan
(
fn2
.
mock
.
invocationCallOrder
[
0
]);
}
graph
.
clear
();
// Ensure adding a new node calls the correct callback
ez
.
LoadImage
();
expect
(
mockExtension
.
loadedGraphNode
).
toHaveBeenCalledTimes
(
graphData
.
nodes
.
length
);
expect
(
mockExtension
.
nodeCreated
).
toHaveBeenCalledTimes
(
graphData
.
nodes
.
length
+
1
);
expect
(
mockExtension
.
nodeCreated
.
mock
.
lastCall
[
0
].
type
).
toBe
(
"
LoadImage
"
);
// Reload the graph to ensure correct hooks are fired
await
graph
.
reload
();
// These hooks should not be fired again
expect
(
mockExtension
.
init
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
addCustomNodeDefs
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
getCustomWidgets
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
registerCustomNodes
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
beforeRegisterNodeDef
).
toHaveBeenCalledTimes
(
nodeCount
);
expect
(
mockExtension
.
setup
).
toHaveBeenCalledTimes
(
1
);
// These should be called again
expect
(
mockExtension
.
beforeConfigureGraph
).
toHaveBeenCalledTimes
(
2
);
expect
(
mockExtension
.
nodeCreated
).
toHaveBeenCalledTimes
(
graphData
.
nodes
.
length
+
2
);
expect
(
mockExtension
.
loadedGraphNode
).
toHaveBeenCalledTimes
(
graphData
.
nodes
.
length
+
1
);
expect
(
mockExtension
.
afterConfigureGraph
).
toHaveBeenCalledTimes
(
2
);
});
it
(
"
allows custom nodeDefs and widgets to be registered
"
,
async
()
=>
{
const
widgetMock
=
jest
.
fn
((
node
,
inputName
,
inputData
,
app
)
=>
{
expect
(
node
.
constructor
.
comfyClass
).
toBe
(
"
TestNode
"
);
expect
(
inputName
).
toBe
(
"
test_input
"
);
expect
(
inputData
[
0
]).
toBe
(
"
CUSTOMWIDGET
"
);
expect
(
inputData
[
1
]?.
hello
).
toBe
(
"
world
"
);
expect
(
app
).
toStrictEqual
(
app
);
return
{
widget
:
node
.
addWidget
(
"
button
"
,
inputName
,
"
hello
"
,
()
=>
{}),
};
});
// Register our extension that adds a custom node + widget type
const
mockExtension
=
{
name
:
"
TestExtension
"
,
addCustomNodeDefs
:
(
nodeDefs
)
=>
{
nodeDefs
[
"
TestNode
"
]
=
{
output
:
[],
output_name
:
[],
output_is_list
:
[],
name
:
"
TestNode
"
,
display_name
:
"
TestNode
"
,
category
:
"
Test
"
,
input
:
{
required
:
{
test_input
:
[
"
CUSTOMWIDGET
"
,
{
hello
:
"
world
"
}],
},
},
};
},
getCustomWidgets
:
jest
.
fn
(()
=>
{
return
{
CUSTOMWIDGET
:
widgetMock
,
};
}),
};
const
{
graph
,
ez
}
=
await
start
({
async
preSetup
(
app
)
{
app
.
registerExtension
(
mockExtension
);
},
});
expect
(
mockExtension
.
getCustomWidgets
).
toBeCalledTimes
(
1
);
graph
.
clear
();
expect
(
widgetMock
).
toBeCalledTimes
(
0
);
const
node
=
ez
.
TestNode
();
expect
(
widgetMock
).
toBeCalledTimes
(
1
);
// Ensure our custom widget is created
expect
(
node
.
inputs
.
length
).
toBe
(
0
);
expect
(
node
.
widgets
.
length
).
toBe
(
1
);
const
w
=
node
.
widgets
[
0
].
widget
;
expect
(
w
.
name
).
toBe
(
"
test_input
"
);
expect
(
w
.
type
).
toBe
(
"
button
"
);
});
});
tests-ui/tests/groupNode.test.js
0 → 100644
View file @
c92f3dca
This diff is collapsed.
Click to expand it.
tests-ui/tests/widgetInputs.test.js
View file @
c92f3dca
...
...
@@ -14,10 +14,10 @@ const lg = require("../utils/litegraph");
* @param { InstanceType<Ez["EzGraph"]> } graph
* @param { InstanceType<Ez["EzInput"]> } input
* @param { string } widgetType
* @param {
boolean } hasC
ontrolWidget
* @param {
number } c
ontrolWidget
Count
* @returns
*/
async
function
connectPrimitiveAndReload
(
ez
,
graph
,
input
,
widgetType
,
hasC
ontrolWidget
)
{
async
function
connectPrimitiveAndReload
(
ez
,
graph
,
input
,
widgetType
,
c
ontrolWidget
Count
=
0
)
{
// Connect to primitive and ensure its still connected after
let
primitive
=
ez
.
PrimitiveNode
();
primitive
.
outputs
[
0
].
connectTo
(
input
);
...
...
@@ -33,13 +33,17 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasContro
expect
(
valueWidget
.
widget
.
type
).
toBe
(
widgetType
);
// Check if control_after_generate should be added
if
(
hasC
ontrolWidget
)
{
if
(
c
ontrolWidget
Count
)
{
const
controlWidget
=
primitive
.
widgets
.
control_after_generate
;
expect
(
controlWidget
.
widget
.
type
).
toBe
(
"
combo
"
);
if
(
widgetType
===
"
combo
"
)
{
const
filterWidget
=
primitive
.
widgets
.
control_filter_list
;
expect
(
filterWidget
.
widget
.
type
).
toBe
(
"
string
"
);
}
}
// Ensure we dont have other widgets
expect
(
primitive
.
node
.
widgets
).
toHaveLength
(
1
+
+!!
hasC
ontrolWidget
);
expect
(
primitive
.
node
.
widgets
).
toHaveLength
(
1
+
c
ontrolWidget
Count
);
});
return
primitive
;
...
...
@@ -55,8 +59,8 @@ describe("widget inputs", () => {
});
[
{
name
:
"
int
"
,
type
:
"
INT
"
,
widget
:
"
number
"
,
control
:
true
},
{
name
:
"
float
"
,
type
:
"
FLOAT
"
,
widget
:
"
number
"
,
control
:
true
},
{
name
:
"
int
"
,
type
:
"
INT
"
,
widget
:
"
number
"
,
control
:
1
},
{
name
:
"
float
"
,
type
:
"
FLOAT
"
,
widget
:
"
number
"
,
control
:
1
},
{
name
:
"
text
"
,
type
:
"
STRING
"
},
{
name
:
"
customtext
"
,
...
...
@@ -64,7 +68,7 @@ describe("widget inputs", () => {
opt
:
{
multiline
:
true
},
},
{
name
:
"
toggle
"
,
type
:
"
BOOLEAN
"
},
{
name
:
"
combo
"
,
type
:
[
"
a
"
,
"
b
"
,
"
c
"
],
control
:
true
},
{
name
:
"
combo
"
,
type
:
[
"
a
"
,
"
b
"
,
"
c
"
],
control
:
2
},
].
forEach
((
c
)
=>
{
test
(
`widget conversion + primitive works on
${
c
.
name
}
`
,
async
()
=>
{
const
{
ez
,
graph
}
=
await
start
({
...
...
@@ -106,7 +110,7 @@ describe("widget inputs", () => {
n
.
widgets
.
ckpt_name
.
convertToInput
();
expect
(
n
.
inputs
.
length
).
toEqual
(
inputCount
+
1
);
const
primitive
=
await
connectPrimitiveAndReload
(
ez
,
graph
,
n
.
inputs
.
ckpt_name
,
"
combo
"
,
true
);
const
primitive
=
await
connectPrimitiveAndReload
(
ez
,
graph
,
n
.
inputs
.
ckpt_name
,
"
combo
"
,
2
);
// Disconnect & reconnect
primitive
.
outputs
[
0
].
connections
[
0
].
disconnect
();
...
...
@@ -198,8 +202,8 @@ describe("widget inputs", () => {
});
expect
(
dialogShow
).
toBeCalledTimes
(
1
);
expect
(
dialogShow
.
mock
.
calls
[
0
][
0
]).
toContain
(
"
the following node types were not found
"
);
expect
(
dialogShow
.
mock
.
calls
[
0
][
0
]).
toContain
(
"
TestNode
"
);
expect
(
dialogShow
.
mock
.
calls
[
0
][
0
]
.
innerHTML
).
toContain
(
"
the following node types were not found
"
);
expect
(
dialogShow
.
mock
.
calls
[
0
][
0
]
.
innerHTML
).
toContain
(
"
TestNode
"
);
});
test
(
"
defaultInput widgets can be converted back to inputs
"
,
async
()
=>
{
...
...
@@ -226,7 +230,7 @@ describe("widget inputs", () => {
// Reload and ensure it still only has 1 converted widget
if
(
!
assertNotNullOrUndefined
(
input
))
return
;
await
connectPrimitiveAndReload
(
ez
,
graph
,
input
,
"
number
"
,
true
);
await
connectPrimitiveAndReload
(
ez
,
graph
,
input
,
"
number
"
,
1
);
n
=
graph
.
find
(
n
);
expect
(
n
.
widgets
).
toHaveLength
(
1
);
w
=
n
.
widgets
.
example
;
...
...
@@ -258,7 +262,7 @@ describe("widget inputs", () => {
// Reload and ensure it still only has 1 converted widget
if
(
assertNotNullOrUndefined
(
input
))
{
await
connectPrimitiveAndReload
(
ez
,
graph
,
input
,
"
number
"
,
true
);
await
connectPrimitiveAndReload
(
ez
,
graph
,
input
,
"
number
"
,
1
);
n
=
graph
.
find
(
n
);
expect
(
n
.
widgets
).
toHaveLength
(
1
);
expect
(
n
.
widgets
.
example
.
isConvertedToInput
).
toBeTruthy
();
...
...
@@ -316,4 +320,76 @@ describe("widget inputs", () => {
n1
.
outputs
[
0
].
connectTo
(
n2
.
inputs
[
0
]);
expect
(()
=>
n1
.
outputs
[
0
].
connectTo
(
n3
.
inputs
[
0
])).
toThrow
();
});
test
(
"
combo primitive can filter list when control_after_generate called
"
,
async
()
=>
{
const
{
ez
}
=
await
start
({
mockNodeDefs
:
{
...
makeNodeDef
(
"
TestNode1
"
,
{
example
:
[[
"
A
"
,
"
B
"
,
"
C
"
,
"
D
"
,
"
AA
"
,
"
BB
"
,
"
CC
"
,
"
DD
"
,
"
AAA
"
,
"
BBB
"
],
{}]
}),
},
});
const
n1
=
ez
.
TestNode1
();
n1
.
widgets
.
example
.
convertToInput
();
const
p
=
ez
.
PrimitiveNode
()
p
.
outputs
[
0
].
connectTo
(
n1
.
inputs
[
0
]);
const
value
=
p
.
widgets
.
value
;
const
control
=
p
.
widgets
.
control_after_generate
.
widget
;
const
filter
=
p
.
widgets
.
control_filter_list
;
expect
(
p
.
widgets
.
length
).
toBe
(
3
);
control
.
value
=
"
increment
"
;
expect
(
value
.
value
).
toBe
(
"
A
"
);
// Manually trigger after queue when set to increment
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
B
"
);
// Filter to items containing D
filter
.
value
=
"
D
"
;
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
D
"
);
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
DD
"
);
// Check decrement
value
.
value
=
"
BBB
"
;
control
.
value
=
"
decrement
"
;
filter
.
value
=
"
B
"
;
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
BB
"
);
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
B
"
);
// Check regex works
value
.
value
=
"
BBB
"
;
filter
.
value
=
"
/[AB]|^C$/
"
;
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
AAA
"
);
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
BB
"
);
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
AA
"
);
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
C
"
);
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
B
"
);
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
A
"
);
// Check random
control
.
value
=
"
randomize
"
;
filter
.
value
=
"
/D/
"
;
for
(
let
i
=
0
;
i
<
100
;
i
++
)
{
control
[
"
afterQueued
"
]();
expect
(
value
.
value
===
"
D
"
||
value
.
value
===
"
DD
"
).
toBeTruthy
();
}
// Ensure it doesnt apply when fixed
control
.
value
=
"
fixed
"
;
value
.
value
=
"
B
"
;
filter
.
value
=
"
C
"
;
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
B
"
);
});
});
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment