Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
c92f3dca
Unverified
Commit
c92f3dca
authored
Dec 02, 2023
by
Jairo Correa
Committed by
GitHub
Dec 02, 2023
Browse files
Merge branch 'master' into image-cache
parents
006b24cc
2995a247
Changes
57
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1766 additions
and
84 deletions
+1766
-84
comfy/supported_models.py
comfy/supported_models.py
+38
-2
comfy/supported_models_base.py
comfy/supported_models_base.py
+7
-1
comfy/taesd/taesd.py
comfy/taesd/taesd.py
+14
-5
comfy/utils.py
comfy/utils.py
+10
-8
comfy_extras/nodes_custom_sampler.py
comfy_extras/nodes_custom_sampler.py
+52
-13
comfy_extras/nodes_images.py
comfy_extras/nodes_images.py
+175
-0
comfy_extras/nodes_latent.py
comfy_extras/nodes_latent.py
+36
-0
comfy_extras/nodes_model_advanced.py
comfy_extras/nodes_model_advanced.py
+48
-11
comfy_extras/nodes_model_downscale.py
comfy_extras/nodes_model_downscale.py
+53
-0
comfy_extras/nodes_video_model.py
comfy_extras/nodes_video_model.py
+89
-0
execution.py
execution.py
+19
-4
folder_paths.py
folder_paths.py
+10
-3
latent_preview.py
latent_preview.py
+1
-4
main.py
main.py
+30
-11
nodes.py
nodes.py
+76
-8
server.py
server.py
+5
-2
tests-ui/setup.js
tests-ui/setup.js
+1
-0
tests-ui/tests/extensions.test.js
tests-ui/tests/extensions.test.js
+196
-0
tests-ui/tests/groupNode.test.js
tests-ui/tests/groupNode.test.js
+818
-0
tests-ui/tests/widgetInputs.test.js
tests-ui/tests/widgetInputs.test.js
+88
-12
No files found.
comfy/supported_models.py
View file @
c92f3dca
...
@@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE):
...
@@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE):
"model_channels"
:
320
,
"model_channels"
:
320
,
"use_linear_in_transformer"
:
False
,
"use_linear_in_transformer"
:
False
,
"adm_in_channels"
:
None
,
"adm_in_channels"
:
None
,
"use_temporal_attention"
:
False
,
}
}
unet_extra_config
=
{
unet_extra_config
=
{
...
@@ -56,6 +57,7 @@ class SD20(supported_models_base.BASE):
...
@@ -56,6 +57,7 @@ class SD20(supported_models_base.BASE):
"model_channels"
:
320
,
"model_channels"
:
320
,
"use_linear_in_transformer"
:
True
,
"use_linear_in_transformer"
:
True
,
"adm_in_channels"
:
None
,
"adm_in_channels"
:
None
,
"use_temporal_attention"
:
False
,
}
}
latent_format
=
latent_formats
.
SD15
latent_format
=
latent_formats
.
SD15
...
@@ -69,6 +71,10 @@ class SD20(supported_models_base.BASE):
...
@@ -69,6 +71,10 @@ class SD20(supported_models_base.BASE):
return
model_base
.
ModelType
.
EPS
return
model_base
.
ModelType
.
EPS
def
process_clip_state_dict
(
self
,
state_dict
):
def
process_clip_state_dict
(
self
,
state_dict
):
replace_prefix
=
{}
replace_prefix
[
"conditioner.embedders.0.model."
]
=
"cond_stage_model.model."
#SD2 in sgm format
state_dict
=
utils
.
state_dict_prefix_replace
(
state_dict
,
replace_prefix
)
state_dict
=
utils
.
transformers_convert
(
state_dict
,
"cond_stage_model.model."
,
"cond_stage_model.clip_h.transformer.text_model."
,
24
)
state_dict
=
utils
.
transformers_convert
(
state_dict
,
"cond_stage_model.model."
,
"cond_stage_model.clip_h.transformer.text_model."
,
24
)
return
state_dict
return
state_dict
...
@@ -88,6 +94,7 @@ class SD21UnclipL(SD20):
...
@@ -88,6 +94,7 @@ class SD21UnclipL(SD20):
"model_channels"
:
320
,
"model_channels"
:
320
,
"use_linear_in_transformer"
:
True
,
"use_linear_in_transformer"
:
True
,
"adm_in_channels"
:
1536
,
"adm_in_channels"
:
1536
,
"use_temporal_attention"
:
False
,
}
}
clip_vision_prefix
=
"embedder.model.visual."
clip_vision_prefix
=
"embedder.model.visual."
...
@@ -100,6 +107,7 @@ class SD21UnclipH(SD20):
...
@@ -100,6 +107,7 @@ class SD21UnclipH(SD20):
"model_channels"
:
320
,
"model_channels"
:
320
,
"use_linear_in_transformer"
:
True
,
"use_linear_in_transformer"
:
True
,
"adm_in_channels"
:
2048
,
"adm_in_channels"
:
2048
,
"use_temporal_attention"
:
False
,
}
}
clip_vision_prefix
=
"embedder.model.visual."
clip_vision_prefix
=
"embedder.model.visual."
...
@@ -112,6 +120,7 @@ class SDXLRefiner(supported_models_base.BASE):
...
@@ -112,6 +120,7 @@ class SDXLRefiner(supported_models_base.BASE):
"context_dim"
:
1280
,
"context_dim"
:
1280
,
"adm_in_channels"
:
2560
,
"adm_in_channels"
:
2560
,
"transformer_depth"
:
[
0
,
0
,
4
,
4
,
4
,
4
,
0
,
0
],
"transformer_depth"
:
[
0
,
0
,
4
,
4
,
4
,
4
,
0
,
0
],
"use_temporal_attention"
:
False
,
}
}
latent_format
=
latent_formats
.
SDXL
latent_format
=
latent_formats
.
SDXL
...
@@ -148,7 +157,8 @@ class SDXL(supported_models_base.BASE):
...
@@ -148,7 +157,8 @@ class SDXL(supported_models_base.BASE):
"use_linear_in_transformer"
:
True
,
"use_linear_in_transformer"
:
True
,
"transformer_depth"
:
[
0
,
0
,
2
,
2
,
10
,
10
],
"transformer_depth"
:
[
0
,
0
,
2
,
2
,
10
,
10
],
"context_dim"
:
2048
,
"context_dim"
:
2048
,
"adm_in_channels"
:
2816
"adm_in_channels"
:
2816
,
"use_temporal_attention"
:
False
,
}
}
latent_format
=
latent_formats
.
SDXL
latent_format
=
latent_formats
.
SDXL
...
@@ -203,8 +213,34 @@ class SSD1B(SDXL):
...
@@ -203,8 +213,34 @@ class SSD1B(SDXL):
"use_linear_in_transformer"
:
True
,
"use_linear_in_transformer"
:
True
,
"transformer_depth"
:
[
0
,
0
,
2
,
2
,
4
,
4
],
"transformer_depth"
:
[
0
,
0
,
2
,
2
,
4
,
4
],
"context_dim"
:
2048
,
"context_dim"
:
2048
,
"adm_in_channels"
:
2816
"adm_in_channels"
:
2816
,
"use_temporal_attention"
:
False
,
}
class
SVD_img2vid
(
supported_models_base
.
BASE
):
unet_config
=
{
"model_channels"
:
320
,
"in_channels"
:
8
,
"use_linear_in_transformer"
:
True
,
"transformer_depth"
:
[
1
,
1
,
1
,
1
,
1
,
1
,
0
,
0
],
"context_dim"
:
1024
,
"adm_in_channels"
:
768
,
"use_temporal_attention"
:
True
,
"use_temporal_resblock"
:
True
}
}
clip_vision_prefix
=
"conditioner.embedders.0.open_clip.model.visual."
latent_format
=
latent_formats
.
SD15
sampling_settings
=
{
"sigma_max"
:
700.0
,
"sigma_min"
:
0.002
}
def
get_model
(
self
,
state_dict
,
prefix
=
""
,
device
=
None
):
out
=
model_base
.
SVD_img2vid
(
self
,
device
=
device
)
return
out
def
clip_target
(
self
):
return
None
models
=
[
SD15
,
SD20
,
SD21UnclipL
,
SD21UnclipH
,
SDXLRefiner
,
SDXL
,
SSD1B
]
models
=
[
SD15
,
SD20
,
SD21UnclipL
,
SD21UnclipH
,
SDXLRefiner
,
SDXL
,
SSD1B
]
models
+=
[
SVD_img2vid
]
comfy/supported_models_base.py
View file @
c92f3dca
...
@@ -19,7 +19,7 @@ class BASE:
...
@@ -19,7 +19,7 @@ class BASE:
clip_prefix
=
[]
clip_prefix
=
[]
clip_vision_prefix
=
None
clip_vision_prefix
=
None
noise_aug_config
=
None
noise_aug_config
=
None
beta_schedule
=
"linear"
sampling_settings
=
{}
latent_format
=
latent_formats
.
LatentFormat
latent_format
=
latent_formats
.
LatentFormat
@
classmethod
@
classmethod
...
@@ -53,6 +53,12 @@ class BASE:
...
@@ -53,6 +53,12 @@ class BASE:
def
process_clip_state_dict
(
self
,
state_dict
):
def
process_clip_state_dict
(
self
,
state_dict
):
return
state_dict
return
state_dict
def
process_unet_state_dict
(
self
,
state_dict
):
return
state_dict
def
process_vae_state_dict
(
self
,
state_dict
):
return
state_dict
def
process_clip_state_dict_for_saving
(
self
,
state_dict
):
def
process_clip_state_dict_for_saving
(
self
,
state_dict
):
replace_prefix
=
{
""
:
"cond_stage_model."
}
replace_prefix
=
{
""
:
"cond_stage_model."
}
return
utils
.
state_dict_prefix_replace
(
state_dict
,
replace_prefix
)
return
utils
.
state_dict_prefix_replace
(
state_dict
,
replace_prefix
)
...
...
comfy/taesd/taesd.py
View file @
c92f3dca
...
@@ -46,15 +46,16 @@ class TAESD(nn.Module):
...
@@ -46,15 +46,16 @@ class TAESD(nn.Module):
latent_magnitude
=
3
latent_magnitude
=
3
latent_shift
=
0.5
latent_shift
=
0.5
def
__init__
(
self
,
encoder_path
=
"taesd_encoder.pth"
,
decoder_path
=
"taesd_decoder.pth"
):
def
__init__
(
self
,
encoder_path
=
None
,
decoder_path
=
None
):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super
().
__init__
()
super
().
__init__
()
self
.
encoder
=
Encoder
()
self
.
taesd_encoder
=
Encoder
()
self
.
decoder
=
Decoder
()
self
.
taesd_decoder
=
Decoder
()
self
.
vae_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
1.0
))
if
encoder_path
is
not
None
:
if
encoder_path
is
not
None
:
self
.
encoder
.
load_state_dict
(
comfy
.
utils
.
load_torch_file
(
encoder_path
,
safe_load
=
True
))
self
.
taesd_
encoder
.
load_state_dict
(
comfy
.
utils
.
load_torch_file
(
encoder_path
,
safe_load
=
True
))
if
decoder_path
is
not
None
:
if
decoder_path
is
not
None
:
self
.
decoder
.
load_state_dict
(
comfy
.
utils
.
load_torch_file
(
decoder_path
,
safe_load
=
True
))
self
.
taesd_
decoder
.
load_state_dict
(
comfy
.
utils
.
load_torch_file
(
decoder_path
,
safe_load
=
True
))
@
staticmethod
@
staticmethod
def
scale_latents
(
x
):
def
scale_latents
(
x
):
...
@@ -65,3 +66,11 @@ class TAESD(nn.Module):
...
@@ -65,3 +66,11 @@ class TAESD(nn.Module):
def
unscale_latents
(
x
):
def
unscale_latents
(
x
):
"""[0, 1] -> raw latents"""
"""[0, 1] -> raw latents"""
return
x
.
sub
(
TAESD
.
latent_shift
).
mul
(
2
*
TAESD
.
latent_magnitude
)
return
x
.
sub
(
TAESD
.
latent_shift
).
mul
(
2
*
TAESD
.
latent_magnitude
)
def
decode
(
self
,
x
):
x_sample
=
self
.
taesd_decoder
(
x
*
self
.
vae_scale
)
x_sample
=
x_sample
.
sub
(
0.5
).
mul
(
2
)
return
x_sample
def
encode
(
self
,
x
):
return
self
.
taesd_encoder
(
x
*
0.5
+
0.5
)
/
self
.
vae_scale
comfy/utils.py
View file @
c92f3dca
...
@@ -258,7 +258,7 @@ def set_attr(obj, attr, value):
...
@@ -258,7 +258,7 @@ def set_attr(obj, attr, value):
for
name
in
attrs
[:
-
1
]:
for
name
in
attrs
[:
-
1
]:
obj
=
getattr
(
obj
,
name
)
obj
=
getattr
(
obj
,
name
)
prev
=
getattr
(
obj
,
attrs
[
-
1
])
prev
=
getattr
(
obj
,
attrs
[
-
1
])
setattr
(
obj
,
attrs
[
-
1
],
torch
.
nn
.
Parameter
(
value
))
setattr
(
obj
,
attrs
[
-
1
],
torch
.
nn
.
Parameter
(
value
,
requires_grad
=
False
))
del
prev
del
prev
def
copy_to_param
(
obj
,
attr
,
value
):
def
copy_to_param
(
obj
,
attr
,
value
):
...
@@ -307,23 +307,25 @@ def bislerp(samples, width, height):
...
@@ -307,23 +307,25 @@ def bislerp(samples, width, height):
res
[
dot
<
1e-5
-
1
]
=
(
b1
*
(
1.0
-
r
)
+
b2
*
r
)[
dot
<
1e-5
-
1
]
res
[
dot
<
1e-5
-
1
]
=
(
b1
*
(
1.0
-
r
)
+
b2
*
r
)[
dot
<
1e-5
-
1
]
return
res
return
res
def
generate_bilinear_data
(
length_old
,
length_new
):
def
generate_bilinear_data
(
length_old
,
length_new
,
device
):
coords_1
=
torch
.
arange
(
length_old
).
reshape
((
1
,
1
,
1
,
-
1
))
.
to
(
torch
.
float32
)
coords_1
=
torch
.
arange
(
length_old
,
dtype
=
torch
.
float32
,
device
=
device
).
reshape
((
1
,
1
,
1
,
-
1
))
coords_1
=
torch
.
nn
.
functional
.
interpolate
(
coords_1
,
size
=
(
1
,
length_new
),
mode
=
"bilinear"
)
coords_1
=
torch
.
nn
.
functional
.
interpolate
(
coords_1
,
size
=
(
1
,
length_new
),
mode
=
"bilinear"
)
ratios
=
coords_1
-
coords_1
.
floor
()
ratios
=
coords_1
-
coords_1
.
floor
()
coords_1
=
coords_1
.
to
(
torch
.
int64
)
coords_1
=
coords_1
.
to
(
torch
.
int64
)
coords_2
=
torch
.
arange
(
length_old
).
reshape
((
1
,
1
,
1
,
-
1
))
.
to
(
torch
.
float32
)
+
1
coords_2
=
torch
.
arange
(
length_old
,
dtype
=
torch
.
float32
,
device
=
device
).
reshape
((
1
,
1
,
1
,
-
1
))
+
1
coords_2
[:,:,:,
-
1
]
-=
1
coords_2
[:,:,:,
-
1
]
-=
1
coords_2
=
torch
.
nn
.
functional
.
interpolate
(
coords_2
,
size
=
(
1
,
length_new
),
mode
=
"bilinear"
)
coords_2
=
torch
.
nn
.
functional
.
interpolate
(
coords_2
,
size
=
(
1
,
length_new
),
mode
=
"bilinear"
)
coords_2
=
coords_2
.
to
(
torch
.
int64
)
coords_2
=
coords_2
.
to
(
torch
.
int64
)
return
ratios
,
coords_1
,
coords_2
return
ratios
,
coords_1
,
coords_2
orig_dtype
=
samples
.
dtype
samples
=
samples
.
float
()
n
,
c
,
h
,
w
=
samples
.
shape
n
,
c
,
h
,
w
=
samples
.
shape
h_new
,
w_new
=
(
height
,
width
)
h_new
,
w_new
=
(
height
,
width
)
#linear w
#linear w
ratios
,
coords_1
,
coords_2
=
generate_bilinear_data
(
w
,
w_new
)
ratios
,
coords_1
,
coords_2
=
generate_bilinear_data
(
w
,
w_new
,
samples
.
device
)
coords_1
=
coords_1
.
expand
((
n
,
c
,
h
,
-
1
))
coords_1
=
coords_1
.
expand
((
n
,
c
,
h
,
-
1
))
coords_2
=
coords_2
.
expand
((
n
,
c
,
h
,
-
1
))
coords_2
=
coords_2
.
expand
((
n
,
c
,
h
,
-
1
))
ratios
=
ratios
.
expand
((
n
,
1
,
h
,
-
1
))
ratios
=
ratios
.
expand
((
n
,
1
,
h
,
-
1
))
...
@@ -336,7 +338,7 @@ def bislerp(samples, width, height):
...
@@ -336,7 +338,7 @@ def bislerp(samples, width, height):
result
=
result
.
reshape
(
n
,
h
,
w_new
,
c
).
movedim
(
-
1
,
1
)
result
=
result
.
reshape
(
n
,
h
,
w_new
,
c
).
movedim
(
-
1
,
1
)
#linear h
#linear h
ratios
,
coords_1
,
coords_2
=
generate_bilinear_data
(
h
,
h_new
)
ratios
,
coords_1
,
coords_2
=
generate_bilinear_data
(
h
,
h_new
,
samples
.
device
)
coords_1
=
coords_1
.
reshape
((
1
,
1
,
-
1
,
1
)).
expand
((
n
,
c
,
-
1
,
w_new
))
coords_1
=
coords_1
.
reshape
((
1
,
1
,
-
1
,
1
)).
expand
((
n
,
c
,
-
1
,
w_new
))
coords_2
=
coords_2
.
reshape
((
1
,
1
,
-
1
,
1
)).
expand
((
n
,
c
,
-
1
,
w_new
))
coords_2
=
coords_2
.
reshape
((
1
,
1
,
-
1
,
1
)).
expand
((
n
,
c
,
-
1
,
w_new
))
ratios
=
ratios
.
reshape
((
1
,
1
,
-
1
,
1
)).
expand
((
n
,
1
,
-
1
,
w_new
))
ratios
=
ratios
.
reshape
((
1
,
1
,
-
1
,
1
)).
expand
((
n
,
1
,
-
1
,
w_new
))
...
@@ -347,7 +349,7 @@ def bislerp(samples, width, height):
...
@@ -347,7 +349,7 @@ def bislerp(samples, width, height):
result
=
slerp
(
pass_1
,
pass_2
,
ratios
)
result
=
slerp
(
pass_1
,
pass_2
,
ratios
)
result
=
result
.
reshape
(
n
,
h_new
,
w_new
,
c
).
movedim
(
-
1
,
1
)
result
=
result
.
reshape
(
n
,
h_new
,
w_new
,
c
).
movedim
(
-
1
,
1
)
return
result
return
result
.
to
(
orig_dtype
)
def
lanczos
(
samples
,
width
,
height
):
def
lanczos
(
samples
,
width
,
height
):
images
=
[
Image
.
fromarray
(
np
.
clip
(
255.
*
image
.
movedim
(
0
,
-
1
).
cpu
().
numpy
(),
0
,
255
).
astype
(
np
.
uint8
))
for
image
in
samples
]
images
=
[
Image
.
fromarray
(
np
.
clip
(
255.
*
image
.
movedim
(
0
,
-
1
).
cpu
().
numpy
(),
0
,
255
).
astype
(
np
.
uint8
))
for
image
in
samples
]
...
...
comfy_extras/nodes_custom_sampler.py
View file @
c92f3dca
...
@@ -16,7 +16,7 @@ class BasicScheduler:
...
@@ -16,7 +16,7 @@ class BasicScheduler:
}
}
}
}
RETURN_TYPES
=
(
"SIGMAS"
,)
RETURN_TYPES
=
(
"SIGMAS"
,)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/schedulers
"
FUNCTION
=
"get_sigmas"
FUNCTION
=
"get_sigmas"
...
@@ -36,7 +36,7 @@ class KarrasScheduler:
...
@@ -36,7 +36,7 @@ class KarrasScheduler:
}
}
}
}
RETURN_TYPES
=
(
"SIGMAS"
,)
RETURN_TYPES
=
(
"SIGMAS"
,)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/schedulers
"
FUNCTION
=
"get_sigmas"
FUNCTION
=
"get_sigmas"
...
@@ -54,7 +54,7 @@ class ExponentialScheduler:
...
@@ -54,7 +54,7 @@ class ExponentialScheduler:
}
}
}
}
RETURN_TYPES
=
(
"SIGMAS"
,)
RETURN_TYPES
=
(
"SIGMAS"
,)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/schedulers
"
FUNCTION
=
"get_sigmas"
FUNCTION
=
"get_sigmas"
...
@@ -73,7 +73,7 @@ class PolyexponentialScheduler:
...
@@ -73,7 +73,7 @@ class PolyexponentialScheduler:
}
}
}
}
RETURN_TYPES
=
(
"SIGMAS"
,)
RETURN_TYPES
=
(
"SIGMAS"
,)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/schedulers
"
FUNCTION
=
"get_sigmas"
FUNCTION
=
"get_sigmas"
...
@@ -81,6 +81,25 @@ class PolyexponentialScheduler:
...
@@ -81,6 +81,25 @@ class PolyexponentialScheduler:
sigmas
=
k_diffusion_sampling
.
get_sigmas_polyexponential
(
n
=
steps
,
sigma_min
=
sigma_min
,
sigma_max
=
sigma_max
,
rho
=
rho
)
sigmas
=
k_diffusion_sampling
.
get_sigmas_polyexponential
(
n
=
steps
,
sigma_min
=
sigma_min
,
sigma_max
=
sigma_max
,
rho
=
rho
)
return
(
sigmas
,
)
return
(
sigmas
,
)
class
SDTurboScheduler
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"model"
:
(
"MODEL"
,),
"steps"
:
(
"INT"
,
{
"default"
:
1
,
"min"
:
1
,
"max"
:
10
}),
}
}
RETURN_TYPES
=
(
"SIGMAS"
,)
CATEGORY
=
"sampling/custom_sampling/schedulers"
FUNCTION
=
"get_sigmas"
def
get_sigmas
(
self
,
model
,
steps
):
timesteps
=
torch
.
flip
(
torch
.
arange
(
1
,
11
)
*
100
-
1
,
(
0
,))[:
steps
]
sigmas
=
model
.
model
.
model_sampling
.
sigma
(
timesteps
)
sigmas
=
torch
.
cat
([
sigmas
,
sigmas
.
new_zeros
([
1
])])
return
(
sigmas
,
)
class
VPScheduler
:
class
VPScheduler
:
@
classmethod
@
classmethod
def
INPUT_TYPES
(
s
):
def
INPUT_TYPES
(
s
):
...
@@ -92,7 +111,7 @@ class VPScheduler:
...
@@ -92,7 +111,7 @@ class VPScheduler:
}
}
}
}
RETURN_TYPES
=
(
"SIGMAS"
,)
RETURN_TYPES
=
(
"SIGMAS"
,)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/schedulers
"
FUNCTION
=
"get_sigmas"
FUNCTION
=
"get_sigmas"
...
@@ -109,7 +128,7 @@ class SplitSigmas:
...
@@ -109,7 +128,7 @@ class SplitSigmas:
}
}
}
}
RETURN_TYPES
=
(
"SIGMAS"
,
"SIGMAS"
)
RETURN_TYPES
=
(
"SIGMAS"
,
"SIGMAS"
)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/sigmas
"
FUNCTION
=
"get_sigmas"
FUNCTION
=
"get_sigmas"
...
@@ -118,6 +137,24 @@ class SplitSigmas:
...
@@ -118,6 +137,24 @@ class SplitSigmas:
sigmas2
=
sigmas
[
step
:]
sigmas2
=
sigmas
[
step
:]
return
(
sigmas1
,
sigmas2
)
return
(
sigmas1
,
sigmas2
)
class
FlipSigmas
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"sigmas"
:
(
"SIGMAS"
,
),
}
}
RETURN_TYPES
=
(
"SIGMAS"
,)
CATEGORY
=
"sampling/custom_sampling/sigmas"
FUNCTION
=
"get_sigmas"
def
get_sigmas
(
self
,
sigmas
):
sigmas
=
sigmas
.
flip
(
0
)
if
sigmas
[
0
]
==
0
:
sigmas
[
0
]
=
0.0001
return
(
sigmas
,)
class
KSamplerSelect
:
class
KSamplerSelect
:
@
classmethod
@
classmethod
def
INPUT_TYPES
(
s
):
def
INPUT_TYPES
(
s
):
...
@@ -126,12 +163,12 @@ class KSamplerSelect:
...
@@ -126,12 +163,12 @@ class KSamplerSelect:
}
}
}
}
RETURN_TYPES
=
(
"SAMPLER"
,)
RETURN_TYPES
=
(
"SAMPLER"
,)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/samplers
"
FUNCTION
=
"get_sampler"
FUNCTION
=
"get_sampler"
def
get_sampler
(
self
,
sampler_name
):
def
get_sampler
(
self
,
sampler_name
):
sampler
=
comfy
.
samplers
.
sampler_
class
(
sampler_name
)
()
sampler
=
comfy
.
samplers
.
sampler_
object
(
sampler_name
)
return
(
sampler
,
)
return
(
sampler
,
)
class
SamplerDPMPP_2M_SDE
:
class
SamplerDPMPP_2M_SDE
:
...
@@ -145,7 +182,7 @@ class SamplerDPMPP_2M_SDE:
...
@@ -145,7 +182,7 @@ class SamplerDPMPP_2M_SDE:
}
}
}
}
RETURN_TYPES
=
(
"SAMPLER"
,)
RETURN_TYPES
=
(
"SAMPLER"
,)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/samplers
"
FUNCTION
=
"get_sampler"
FUNCTION
=
"get_sampler"
...
@@ -154,7 +191,7 @@ class SamplerDPMPP_2M_SDE:
...
@@ -154,7 +191,7 @@ class SamplerDPMPP_2M_SDE:
sampler_name
=
"dpmpp_2m_sde"
sampler_name
=
"dpmpp_2m_sde"
else
:
else
:
sampler_name
=
"dpmpp_2m_sde_gpu"
sampler_name
=
"dpmpp_2m_sde_gpu"
sampler
=
comfy
.
samplers
.
ksampler
(
sampler_name
,
{
"eta"
:
eta
,
"s_noise"
:
s_noise
,
"solver_type"
:
solver_type
})
()
sampler
=
comfy
.
samplers
.
ksampler
(
sampler_name
,
{
"eta"
:
eta
,
"s_noise"
:
s_noise
,
"solver_type"
:
solver_type
})
return
(
sampler
,
)
return
(
sampler
,
)
...
@@ -169,7 +206,7 @@ class SamplerDPMPP_SDE:
...
@@ -169,7 +206,7 @@ class SamplerDPMPP_SDE:
}
}
}
}
RETURN_TYPES
=
(
"SAMPLER"
,)
RETURN_TYPES
=
(
"SAMPLER"
,)
CATEGORY
=
"sampling/custom_sampling"
CATEGORY
=
"sampling/custom_sampling
/samplers
"
FUNCTION
=
"get_sampler"
FUNCTION
=
"get_sampler"
...
@@ -178,7 +215,7 @@ class SamplerDPMPP_SDE:
...
@@ -178,7 +215,7 @@ class SamplerDPMPP_SDE:
sampler_name
=
"dpmpp_sde"
sampler_name
=
"dpmpp_sde"
else
:
else
:
sampler_name
=
"dpmpp_sde_gpu"
sampler_name
=
"dpmpp_sde_gpu"
sampler
=
comfy
.
samplers
.
ksampler
(
sampler_name
,
{
"eta"
:
eta
,
"s_noise"
:
s_noise
,
"r"
:
r
})
()
sampler
=
comfy
.
samplers
.
ksampler
(
sampler_name
,
{
"eta"
:
eta
,
"s_noise"
:
s_noise
,
"r"
:
r
})
return
(
sampler
,
)
return
(
sampler
,
)
class
SamplerCustom
:
class
SamplerCustom
:
...
@@ -234,13 +271,15 @@ class SamplerCustom:
...
@@ -234,13 +271,15 @@ class SamplerCustom:
NODE_CLASS_MAPPINGS
=
{
NODE_CLASS_MAPPINGS
=
{
"SamplerCustom"
:
SamplerCustom
,
"SamplerCustom"
:
SamplerCustom
,
"BasicScheduler"
:
BasicScheduler
,
"KarrasScheduler"
:
KarrasScheduler
,
"KarrasScheduler"
:
KarrasScheduler
,
"ExponentialScheduler"
:
ExponentialScheduler
,
"ExponentialScheduler"
:
ExponentialScheduler
,
"PolyexponentialScheduler"
:
PolyexponentialScheduler
,
"PolyexponentialScheduler"
:
PolyexponentialScheduler
,
"VPScheduler"
:
VPScheduler
,
"VPScheduler"
:
VPScheduler
,
"SDTurboScheduler"
:
SDTurboScheduler
,
"KSamplerSelect"
:
KSamplerSelect
,
"KSamplerSelect"
:
KSamplerSelect
,
"SamplerDPMPP_2M_SDE"
:
SamplerDPMPP_2M_SDE
,
"SamplerDPMPP_2M_SDE"
:
SamplerDPMPP_2M_SDE
,
"SamplerDPMPP_SDE"
:
SamplerDPMPP_SDE
,
"SamplerDPMPP_SDE"
:
SamplerDPMPP_SDE
,
"BasicScheduler"
:
BasicScheduler
,
"SplitSigmas"
:
SplitSigmas
,
"SplitSigmas"
:
SplitSigmas
,
"FlipSigmas"
:
FlipSigmas
,
}
}
comfy_extras/nodes_images.py
0 → 100644
View file @
c92f3dca
import
nodes
import
folder_paths
from
comfy.cli_args
import
args
from
PIL
import
Image
from
PIL.PngImagePlugin
import
PngInfo
import
numpy
as
np
import
json
import
os
MAX_RESOLUTION
=
nodes
.
MAX_RESOLUTION
class
ImageCrop
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"image"
:
(
"IMAGE"
,),
"width"
:
(
"INT"
,
{
"default"
:
512
,
"min"
:
1
,
"max"
:
MAX_RESOLUTION
,
"step"
:
1
}),
"height"
:
(
"INT"
,
{
"default"
:
512
,
"min"
:
1
,
"max"
:
MAX_RESOLUTION
,
"step"
:
1
}),
"x"
:
(
"INT"
,
{
"default"
:
0
,
"min"
:
0
,
"max"
:
MAX_RESOLUTION
,
"step"
:
1
}),
"y"
:
(
"INT"
,
{
"default"
:
0
,
"min"
:
0
,
"max"
:
MAX_RESOLUTION
,
"step"
:
1
}),
}}
RETURN_TYPES
=
(
"IMAGE"
,)
FUNCTION
=
"crop"
CATEGORY
=
"image/transform"
def
crop
(
self
,
image
,
width
,
height
,
x
,
y
):
x
=
min
(
x
,
image
.
shape
[
2
]
-
1
)
y
=
min
(
y
,
image
.
shape
[
1
]
-
1
)
to_x
=
width
+
x
to_y
=
height
+
y
img
=
image
[:,
y
:
to_y
,
x
:
to_x
,
:]
return
(
img
,)
class
RepeatImageBatch
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"image"
:
(
"IMAGE"
,),
"amount"
:
(
"INT"
,
{
"default"
:
1
,
"min"
:
1
,
"max"
:
64
}),
}}
RETURN_TYPES
=
(
"IMAGE"
,)
FUNCTION
=
"repeat"
CATEGORY
=
"image/batch"
def
repeat
(
self
,
image
,
amount
):
s
=
image
.
repeat
((
amount
,
1
,
1
,
1
))
return
(
s
,)
class
SaveAnimatedWEBP
:
def
__init__
(
self
):
self
.
output_dir
=
folder_paths
.
get_output_directory
()
self
.
type
=
"output"
self
.
prefix_append
=
""
methods
=
{
"default"
:
4
,
"fastest"
:
0
,
"slowest"
:
6
}
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"images"
:
(
"IMAGE"
,
),
"filename_prefix"
:
(
"STRING"
,
{
"default"
:
"ComfyUI"
}),
"fps"
:
(
"FLOAT"
,
{
"default"
:
6.0
,
"min"
:
0.01
,
"max"
:
1000.0
,
"step"
:
0.01
}),
"lossless"
:
(
"BOOLEAN"
,
{
"default"
:
True
}),
"quality"
:
(
"INT"
,
{
"default"
:
80
,
"min"
:
0
,
"max"
:
100
}),
"method"
:
(
list
(
s
.
methods
.
keys
()),),
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
},
"hidden"
:
{
"prompt"
:
"PROMPT"
,
"extra_pnginfo"
:
"EXTRA_PNGINFO"
},
}
RETURN_TYPES
=
()
FUNCTION
=
"save_images"
OUTPUT_NODE
=
True
CATEGORY
=
"_for_testing"
def
save_images
(
self
,
images
,
fps
,
filename_prefix
,
lossless
,
quality
,
method
,
num_frames
=
0
,
prompt
=
None
,
extra_pnginfo
=
None
):
method
=
self
.
methods
.
get
(
method
)
filename_prefix
+=
self
.
prefix_append
full_output_folder
,
filename
,
counter
,
subfolder
,
filename_prefix
=
folder_paths
.
get_save_image_path
(
filename_prefix
,
self
.
output_dir
,
images
[
0
].
shape
[
1
],
images
[
0
].
shape
[
0
])
results
=
list
()
pil_images
=
[]
for
image
in
images
:
i
=
255.
*
image
.
cpu
().
numpy
()
img
=
Image
.
fromarray
(
np
.
clip
(
i
,
0
,
255
).
astype
(
np
.
uint8
))
pil_images
.
append
(
img
)
metadata
=
pil_images
[
0
].
getexif
()
if
not
args
.
disable_metadata
:
if
prompt
is
not
None
:
metadata
[
0x0110
]
=
"prompt:{}"
.
format
(
json
.
dumps
(
prompt
))
if
extra_pnginfo
is
not
None
:
inital_exif
=
0x010f
for
x
in
extra_pnginfo
:
metadata
[
inital_exif
]
=
"{}:{}"
.
format
(
x
,
json
.
dumps
(
extra_pnginfo
[
x
]))
inital_exif
-=
1
if
num_frames
==
0
:
num_frames
=
len
(
pil_images
)
c
=
len
(
pil_images
)
for
i
in
range
(
0
,
c
,
num_frames
):
file
=
f
"
{
filename
}
_
{
counter
:
05
}
_.webp"
pil_images
[
i
].
save
(
os
.
path
.
join
(
full_output_folder
,
file
),
save_all
=
True
,
duration
=
int
(
1000.0
/
fps
),
append_images
=
pil_images
[
i
+
1
:
i
+
num_frames
],
exif
=
metadata
,
lossless
=
lossless
,
quality
=
quality
,
method
=
method
)
results
.
append
({
"filename"
:
file
,
"subfolder"
:
subfolder
,
"type"
:
self
.
type
})
counter
+=
1
animated
=
num_frames
!=
1
return
{
"ui"
:
{
"images"
:
results
,
"animated"
:
(
animated
,)
}
}
class
SaveAnimatedPNG
:
def
__init__
(
self
):
self
.
output_dir
=
folder_paths
.
get_output_directory
()
self
.
type
=
"output"
self
.
prefix_append
=
""
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"images"
:
(
"IMAGE"
,
),
"filename_prefix"
:
(
"STRING"
,
{
"default"
:
"ComfyUI"
}),
"fps"
:
(
"FLOAT"
,
{
"default"
:
6.0
,
"min"
:
0.01
,
"max"
:
1000.0
,
"step"
:
0.01
}),
"compress_level"
:
(
"INT"
,
{
"default"
:
4
,
"min"
:
0
,
"max"
:
9
})
},
"hidden"
:
{
"prompt"
:
"PROMPT"
,
"extra_pnginfo"
:
"EXTRA_PNGINFO"
},
}
RETURN_TYPES
=
()
FUNCTION
=
"save_images"
OUTPUT_NODE
=
True
CATEGORY
=
"_for_testing"
def
save_images
(
self
,
images
,
fps
,
compress_level
,
filename_prefix
=
"ComfyUI"
,
prompt
=
None
,
extra_pnginfo
=
None
):
filename_prefix
+=
self
.
prefix_append
full_output_folder
,
filename
,
counter
,
subfolder
,
filename_prefix
=
folder_paths
.
get_save_image_path
(
filename_prefix
,
self
.
output_dir
,
images
[
0
].
shape
[
1
],
images
[
0
].
shape
[
0
])
results
=
list
()
pil_images
=
[]
for
image
in
images
:
i
=
255.
*
image
.
cpu
().
numpy
()
img
=
Image
.
fromarray
(
np
.
clip
(
i
,
0
,
255
).
astype
(
np
.
uint8
))
pil_images
.
append
(
img
)
metadata
=
None
if
not
args
.
disable_metadata
:
metadata
=
PngInfo
()
if
prompt
is
not
None
:
metadata
.
add
(
b
"comf"
,
"prompt"
.
encode
(
"latin-1"
,
"strict"
)
+
b
"
\0
"
+
json
.
dumps
(
prompt
).
encode
(
"latin-1"
,
"strict"
),
after_idat
=
True
)
if
extra_pnginfo
is
not
None
:
for
x
in
extra_pnginfo
:
metadata
.
add
(
b
"comf"
,
x
.
encode
(
"latin-1"
,
"strict"
)
+
b
"
\0
"
+
json
.
dumps
(
extra_pnginfo
[
x
]).
encode
(
"latin-1"
,
"strict"
),
after_idat
=
True
)
file
=
f
"
{
filename
}
_
{
counter
:
05
}
_.png"
pil_images
[
0
].
save
(
os
.
path
.
join
(
full_output_folder
,
file
),
pnginfo
=
metadata
,
compress_level
=
compress_level
,
save_all
=
True
,
duration
=
int
(
1000.0
/
fps
),
append_images
=
pil_images
[
1
:])
results
.
append
({
"filename"
:
file
,
"subfolder"
:
subfolder
,
"type"
:
self
.
type
})
return
{
"ui"
:
{
"images"
:
results
,
"animated"
:
(
True
,)}
}
NODE_CLASS_MAPPINGS
=
{
"ImageCrop"
:
ImageCrop
,
"RepeatImageBatch"
:
RepeatImageBatch
,
"SaveAnimatedWEBP"
:
SaveAnimatedWEBP
,
"SaveAnimatedPNG"
:
SaveAnimatedPNG
,
}
comfy_extras/nodes_latent.py
View file @
c92f3dca
import
comfy.utils
import
comfy.utils
import
torch
def
reshape_latent_to
(
target_shape
,
latent
):
def
reshape_latent_to
(
target_shape
,
latent
):
if
latent
.
shape
[
1
:]
!=
target_shape
[
1
:]:
if
latent
.
shape
[
1
:]
!=
target_shape
[
1
:]:
...
@@ -67,8 +68,43 @@ class LatentMultiply:
...
@@ -67,8 +68,43 @@ class LatentMultiply:
samples_out
[
"samples"
]
=
s1
*
multiplier
samples_out
[
"samples"
]
=
s1
*
multiplier
return
(
samples_out
,)
return
(
samples_out
,)
class
LatentInterpolate
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"samples1"
:
(
"LATENT"
,),
"samples2"
:
(
"LATENT"
,),
"ratio"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.01
}),
}}
RETURN_TYPES
=
(
"LATENT"
,)
FUNCTION
=
"op"
CATEGORY
=
"latent/advanced"
def
op
(
self
,
samples1
,
samples2
,
ratio
):
samples_out
=
samples1
.
copy
()
s1
=
samples1
[
"samples"
]
s2
=
samples2
[
"samples"
]
s2
=
reshape_latent_to
(
s1
.
shape
,
s2
)
m1
=
torch
.
linalg
.
vector_norm
(
s1
,
dim
=
(
1
))
m2
=
torch
.
linalg
.
vector_norm
(
s2
,
dim
=
(
1
))
s1
=
torch
.
nan_to_num
(
s1
/
m1
)
s2
=
torch
.
nan_to_num
(
s2
/
m2
)
t
=
(
s1
*
ratio
+
s2
*
(
1.0
-
ratio
))
mt
=
torch
.
linalg
.
vector_norm
(
t
,
dim
=
(
1
))
st
=
torch
.
nan_to_num
(
t
/
mt
)
samples_out
[
"samples"
]
=
st
*
(
m1
*
ratio
+
m2
*
(
1.0
-
ratio
))
return
(
samples_out
,)
NODE_CLASS_MAPPINGS
=
{
NODE_CLASS_MAPPINGS
=
{
"LatentAdd"
:
LatentAdd
,
"LatentAdd"
:
LatentAdd
,
"LatentSubtract"
:
LatentSubtract
,
"LatentSubtract"
:
LatentSubtract
,
"LatentMultiply"
:
LatentMultiply
,
"LatentMultiply"
:
LatentMultiply
,
"LatentInterpolate"
:
LatentInterpolate
,
}
}
comfy_extras/nodes_model_advanced.py
View file @
c92f3dca
...
@@ -17,7 +17,9 @@ class LCM(comfy.model_sampling.EPS):
...
@@ -17,7 +17,9 @@ class LCM(comfy.model_sampling.EPS):
return
c_out
*
x0
+
c_skip
*
model_input
return
c_out
*
x0
+
c_skip
*
model_input
class
ModelSamplingDiscreteLCM
(
torch
.
nn
.
Module
):
class
ModelSamplingDiscreteDistilled
(
torch
.
nn
.
Module
):
original_timesteps
=
50
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
self
.
sigma_data
=
1.0
self
.
sigma_data
=
1.0
...
@@ -29,13 +31,12 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
...
@@ -29,13 +31,12 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
alphas
=
1.0
-
betas
alphas
=
1.0
-
betas
alphas_cumprod
=
torch
.
cumprod
(
alphas
,
dim
=
0
)
alphas_cumprod
=
torch
.
cumprod
(
alphas
,
dim
=
0
)
original_timesteps
=
50
self
.
skip_steps
=
timesteps
//
self
.
original_timesteps
self
.
skip_steps
=
timesteps
//
original_timesteps
alphas_cumprod_valid
=
torch
.
zeros
((
original_timesteps
),
dtype
=
torch
.
float32
)
alphas_cumprod_valid
=
torch
.
zeros
((
self
.
original_timesteps
),
dtype
=
torch
.
float32
)
for
x
in
range
(
original_timesteps
):
for
x
in
range
(
self
.
original_timesteps
):
alphas_cumprod_valid
[
original_timesteps
-
1
-
x
]
=
alphas_cumprod
[
timesteps
-
1
-
x
*
self
.
skip_steps
]
alphas_cumprod_valid
[
self
.
original_timesteps
-
1
-
x
]
=
alphas_cumprod
[
timesteps
-
1
-
x
*
self
.
skip_steps
]
sigmas
=
((
1
-
alphas_cumprod_valid
)
/
alphas_cumprod_valid
)
**
0.5
sigmas
=
((
1
-
alphas_cumprod_valid
)
/
alphas_cumprod_valid
)
**
0.5
self
.
set_sigmas
(
sigmas
)
self
.
set_sigmas
(
sigmas
)
...
@@ -55,18 +56,23 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
...
@@ -55,18 +56,23 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
def
timestep
(
self
,
sigma
):
def
timestep
(
self
,
sigma
):
log_sigma
=
sigma
.
log
()
log_sigma
=
sigma
.
log
()
dists
=
log_sigma
.
to
(
self
.
log_sigmas
.
device
)
-
self
.
log_sigmas
[:,
None
]
dists
=
log_sigma
.
to
(
self
.
log_sigmas
.
device
)
-
self
.
log_sigmas
[:,
None
]
return
dists
.
abs
().
argmin
(
dim
=
0
).
view
(
sigma
.
shape
)
*
self
.
skip_steps
+
(
self
.
skip_steps
-
1
)
return
(
dists
.
abs
().
argmin
(
dim
=
0
).
view
(
sigma
.
shape
)
*
self
.
skip_steps
+
(
self
.
skip_steps
-
1
)
).
to
(
sigma
.
device
)
def
sigma
(
self
,
timestep
):
def
sigma
(
self
,
timestep
):
t
=
torch
.
clamp
(((
timestep
-
(
self
.
skip_steps
-
1
))
/
self
.
skip_steps
).
float
(),
min
=
0
,
max
=
(
len
(
self
.
sigmas
)
-
1
))
t
=
torch
.
clamp
(((
timestep
.
float
().
to
(
self
.
log_sigmas
.
device
)
-
(
self
.
skip_steps
-
1
))
/
self
.
skip_steps
).
float
(),
min
=
0
,
max
=
(
len
(
self
.
sigmas
)
-
1
))
low_idx
=
t
.
floor
().
long
()
low_idx
=
t
.
floor
().
long
()
high_idx
=
t
.
ceil
().
long
()
high_idx
=
t
.
ceil
().
long
()
w
=
t
.
frac
()
w
=
t
.
frac
()
log_sigma
=
(
1
-
w
)
*
self
.
log_sigmas
[
low_idx
]
+
w
*
self
.
log_sigmas
[
high_idx
]
log_sigma
=
(
1
-
w
)
*
self
.
log_sigmas
[
low_idx
]
+
w
*
self
.
log_sigmas
[
high_idx
]
return
log_sigma
.
exp
()
return
log_sigma
.
exp
()
.
to
(
timestep
.
device
)
def
percent_to_sigma
(
self
,
percent
):
def
percent_to_sigma
(
self
,
percent
):
return
self
.
sigma
(
torch
.
tensor
(
percent
*
999.0
))
if
percent
<=
0.0
:
return
999999999.9
if
percent
>=
1.0
:
return
0.0
percent
=
1.0
-
percent
return
self
.
sigma
(
torch
.
tensor
(
percent
*
999.0
)).
item
()
def
rescale_zero_terminal_snr_sigmas
(
sigmas
):
def
rescale_zero_terminal_snr_sigmas
(
sigmas
):
...
@@ -111,7 +117,7 @@ class ModelSamplingDiscrete:
...
@@ -111,7 +117,7 @@ class ModelSamplingDiscrete:
sampling_type
=
comfy
.
model_sampling
.
V_PREDICTION
sampling_type
=
comfy
.
model_sampling
.
V_PREDICTION
elif
sampling
==
"lcm"
:
elif
sampling
==
"lcm"
:
sampling_type
=
LCM
sampling_type
=
LCM
sampling_base
=
ModelSamplingDiscrete
LCM
sampling_base
=
ModelSamplingDiscrete
Distilled
class
ModelSamplingAdvanced
(
sampling_base
,
sampling_type
):
class
ModelSamplingAdvanced
(
sampling_base
,
sampling_type
):
pass
pass
...
@@ -123,6 +129,36 @@ class ModelSamplingDiscrete:
...
@@ -123,6 +129,36 @@ class ModelSamplingDiscrete:
m
.
add_object_patch
(
"model_sampling"
,
model_sampling
)
m
.
add_object_patch
(
"model_sampling"
,
model_sampling
)
return
(
m
,
)
return
(
m
,
)
class
ModelSamplingContinuousEDM
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"model"
:
(
"MODEL"
,),
"sampling"
:
([
"v_prediction"
,
"eps"
],),
"sigma_max"
:
(
"FLOAT"
,
{
"default"
:
120.0
,
"min"
:
0.0
,
"max"
:
1000.0
,
"step"
:
0.001
,
"round"
:
False
}),
"sigma_min"
:
(
"FLOAT"
,
{
"default"
:
0.002
,
"min"
:
0.0
,
"max"
:
1000.0
,
"step"
:
0.001
,
"round"
:
False
}),
}}
RETURN_TYPES
=
(
"MODEL"
,)
FUNCTION
=
"patch"
CATEGORY
=
"advanced/model"
def
patch
(
self
,
model
,
sampling
,
sigma_max
,
sigma_min
):
m
=
model
.
clone
()
if
sampling
==
"eps"
:
sampling_type
=
comfy
.
model_sampling
.
EPS
elif
sampling
==
"v_prediction"
:
sampling_type
=
comfy
.
model_sampling
.
V_PREDICTION
class
ModelSamplingAdvanced
(
comfy
.
model_sampling
.
ModelSamplingContinuousEDM
,
sampling_type
):
pass
model_sampling
=
ModelSamplingAdvanced
()
model_sampling
.
set_sigma_range
(
sigma_min
,
sigma_max
)
m
.
add_object_patch
(
"model_sampling"
,
model_sampling
)
return
(
m
,
)
class
RescaleCFG
:
class
RescaleCFG
:
@
classmethod
@
classmethod
def
INPUT_TYPES
(
s
):
def
INPUT_TYPES
(
s
):
...
@@ -164,5 +200,6 @@ class RescaleCFG:
...
@@ -164,5 +200,6 @@ class RescaleCFG:
NODE_CLASS_MAPPINGS
=
{
NODE_CLASS_MAPPINGS
=
{
"ModelSamplingDiscrete"
:
ModelSamplingDiscrete
,
"ModelSamplingDiscrete"
:
ModelSamplingDiscrete
,
"ModelSamplingContinuousEDM"
:
ModelSamplingContinuousEDM
,
"RescaleCFG"
:
RescaleCFG
,
"RescaleCFG"
:
RescaleCFG
,
}
}
comfy_extras/nodes_model_downscale.py
0 → 100644
View file @
c92f3dca
import
torch
import
comfy.utils
class
PatchModelAddDownscale
:
upscale_methods
=
[
"bicubic"
,
"nearest-exact"
,
"bilinear"
,
"area"
,
"bislerp"
]
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"model"
:
(
"MODEL"
,),
"block_number"
:
(
"INT"
,
{
"default"
:
3
,
"min"
:
1
,
"max"
:
32
,
"step"
:
1
}),
"downscale_factor"
:
(
"FLOAT"
,
{
"default"
:
2.0
,
"min"
:
0.1
,
"max"
:
9.0
,
"step"
:
0.001
}),
"start_percent"
:
(
"FLOAT"
,
{
"default"
:
0.0
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.001
}),
"end_percent"
:
(
"FLOAT"
,
{
"default"
:
0.35
,
"min"
:
0.0
,
"max"
:
1.0
,
"step"
:
0.001
}),
"downscale_after_skip"
:
(
"BOOLEAN"
,
{
"default"
:
True
}),
"downscale_method"
:
(
s
.
upscale_methods
,),
"upscale_method"
:
(
s
.
upscale_methods
,),
}}
RETURN_TYPES
=
(
"MODEL"
,)
FUNCTION
=
"patch"
CATEGORY
=
"_for_testing"
def
patch
(
self
,
model
,
block_number
,
downscale_factor
,
start_percent
,
end_percent
,
downscale_after_skip
,
downscale_method
,
upscale_method
):
sigma_start
=
model
.
model
.
model_sampling
.
percent_to_sigma
(
start_percent
)
sigma_end
=
model
.
model
.
model_sampling
.
percent_to_sigma
(
end_percent
)
def
input_block_patch
(
h
,
transformer_options
):
if
transformer_options
[
"block"
][
1
]
==
block_number
:
sigma
=
transformer_options
[
"sigmas"
][
0
].
item
()
if
sigma
<=
sigma_start
and
sigma
>=
sigma_end
:
h
=
comfy
.
utils
.
common_upscale
(
h
,
round
(
h
.
shape
[
-
1
]
*
(
1.0
/
downscale_factor
)),
round
(
h
.
shape
[
-
2
]
*
(
1.0
/
downscale_factor
)),
downscale_method
,
"disabled"
)
return
h
def
output_block_patch
(
h
,
hsp
,
transformer_options
):
if
h
.
shape
[
2
]
!=
hsp
.
shape
[
2
]:
h
=
comfy
.
utils
.
common_upscale
(
h
,
hsp
.
shape
[
-
1
],
hsp
.
shape
[
-
2
],
upscale_method
,
"disabled"
)
return
h
,
hsp
m
=
model
.
clone
()
if
downscale_after_skip
:
m
.
set_model_input_block_patch_after_skip
(
input_block_patch
)
else
:
m
.
set_model_input_block_patch
(
input_block_patch
)
m
.
set_model_output_block_patch
(
output_block_patch
)
return
(
m
,
)
NODE_CLASS_MAPPINGS
=
{
"PatchModelAddDownscale"
:
PatchModelAddDownscale
,
}
NODE_DISPLAY_NAME_MAPPINGS
=
{
# Sampling
"PatchModelAddDownscale"
:
"PatchModelAddDownscale (Kohya Deep Shrink)"
,
}
comfy_extras/nodes_video_model.py
0 → 100644
View file @
c92f3dca
import
nodes
import
torch
import
comfy.utils
import
comfy.sd
import
folder_paths
class
ImageOnlyCheckpointLoader
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"ckpt_name"
:
(
folder_paths
.
get_filename_list
(
"checkpoints"
),
),
}}
RETURN_TYPES
=
(
"MODEL"
,
"CLIP_VISION"
,
"VAE"
)
FUNCTION
=
"load_checkpoint"
CATEGORY
=
"loaders/video_models"
def
load_checkpoint
(
self
,
ckpt_name
,
output_vae
=
True
,
output_clip
=
True
):
ckpt_path
=
folder_paths
.
get_full_path
(
"checkpoints"
,
ckpt_name
)
out
=
comfy
.
sd
.
load_checkpoint_guess_config
(
ckpt_path
,
output_vae
=
True
,
output_clip
=
False
,
output_clipvision
=
True
,
embedding_directory
=
folder_paths
.
get_folder_paths
(
"embeddings"
))
return
(
out
[
0
],
out
[
3
],
out
[
2
])
class
SVD_img2vid_Conditioning
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"clip_vision"
:
(
"CLIP_VISION"
,),
"init_image"
:
(
"IMAGE"
,),
"vae"
:
(
"VAE"
,),
"width"
:
(
"INT"
,
{
"default"
:
1024
,
"min"
:
16
,
"max"
:
nodes
.
MAX_RESOLUTION
,
"step"
:
8
}),
"height"
:
(
"INT"
,
{
"default"
:
576
,
"min"
:
16
,
"max"
:
nodes
.
MAX_RESOLUTION
,
"step"
:
8
}),
"video_frames"
:
(
"INT"
,
{
"default"
:
14
,
"min"
:
1
,
"max"
:
4096
}),
"motion_bucket_id"
:
(
"INT"
,
{
"default"
:
127
,
"min"
:
1
,
"max"
:
1023
}),
"fps"
:
(
"INT"
,
{
"default"
:
6
,
"min"
:
1
,
"max"
:
1024
}),
"augmentation_level"
:
(
"FLOAT"
,
{
"default"
:
0.0
,
"min"
:
0.0
,
"max"
:
10.0
,
"step"
:
0.01
})
}}
RETURN_TYPES
=
(
"CONDITIONING"
,
"CONDITIONING"
,
"LATENT"
)
RETURN_NAMES
=
(
"positive"
,
"negative"
,
"latent"
)
FUNCTION
=
"encode"
CATEGORY
=
"conditioning/video_models"
def
encode
(
self
,
clip_vision
,
init_image
,
vae
,
width
,
height
,
video_frames
,
motion_bucket_id
,
fps
,
augmentation_level
):
output
=
clip_vision
.
encode_image
(
init_image
)
pooled
=
output
.
image_embeds
.
unsqueeze
(
0
)
pixels
=
comfy
.
utils
.
common_upscale
(
init_image
.
movedim
(
-
1
,
1
),
width
,
height
,
"bilinear"
,
"center"
).
movedim
(
1
,
-
1
)
encode_pixels
=
pixels
[:,:,:,:
3
]
if
augmentation_level
>
0
:
encode_pixels
+=
torch
.
randn_like
(
pixels
)
*
augmentation_level
t
=
vae
.
encode
(
encode_pixels
)
positive
=
[[
pooled
,
{
"motion_bucket_id"
:
motion_bucket_id
,
"fps"
:
fps
,
"augmentation_level"
:
augmentation_level
,
"concat_latent_image"
:
t
}]]
negative
=
[[
torch
.
zeros_like
(
pooled
),
{
"motion_bucket_id"
:
motion_bucket_id
,
"fps"
:
fps
,
"augmentation_level"
:
augmentation_level
,
"concat_latent_image"
:
torch
.
zeros_like
(
t
)}]]
latent
=
torch
.
zeros
([
video_frames
,
4
,
height
//
8
,
width
//
8
])
return
(
positive
,
negative
,
{
"samples"
:
latent
})
class
VideoLinearCFGGuidance
:
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"model"
:
(
"MODEL"
,),
"min_cfg"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
0.0
,
"max"
:
100.0
,
"step"
:
0.5
,
"round"
:
0.01
}),
}}
RETURN_TYPES
=
(
"MODEL"
,)
FUNCTION
=
"patch"
CATEGORY
=
"sampling/video_models"
def
patch
(
self
,
model
,
min_cfg
):
def
linear_cfg
(
args
):
cond
=
args
[
"cond"
]
uncond
=
args
[
"uncond"
]
cond_scale
=
args
[
"cond_scale"
]
scale
=
torch
.
linspace
(
min_cfg
,
cond_scale
,
cond
.
shape
[
0
],
device
=
cond
.
device
).
reshape
((
cond
.
shape
[
0
],
1
,
1
,
1
))
return
uncond
+
scale
*
(
cond
-
uncond
)
m
=
model
.
clone
()
m
.
set_model_sampler_cfg_function
(
linear_cfg
)
return
(
m
,
)
NODE_CLASS_MAPPINGS
=
{
"ImageOnlyCheckpointLoader"
:
ImageOnlyCheckpointLoader
,
"SVD_img2vid_Conditioning"
:
SVD_img2vid_Conditioning
,
"VideoLinearCFGGuidance"
:
VideoLinearCFGGuidance
,
}
NODE_DISPLAY_NAME_MAPPINGS
=
{
"ImageOnlyCheckpointLoader"
:
"Image Only Checkpoint Loader (img2vid model)"
,
}
execution.py
View file @
c92f3dca
...
@@ -681,6 +681,7 @@ def validate_prompt(prompt):
...
@@ -681,6 +681,7 @@ def validate_prompt(prompt):
return
(
True
,
None
,
list
(
good_outputs
),
node_errors
)
return
(
True
,
None
,
list
(
good_outputs
),
node_errors
)
MAXIMUM_HISTORY_SIZE
=
10000
class
PromptQueue
:
class
PromptQueue
:
def
__init__
(
self
,
server
):
def
__init__
(
self
,
server
):
...
@@ -699,10 +700,12 @@ class PromptQueue:
...
@@ -699,10 +700,12 @@ class PromptQueue:
self
.
server
.
queue_updated
()
self
.
server
.
queue_updated
()
self
.
not_empty
.
notify
()
self
.
not_empty
.
notify
()
def
get
(
self
):
def
get
(
self
,
timeout
=
None
):
with
self
.
not_empty
:
with
self
.
not_empty
:
while
len
(
self
.
queue
)
==
0
:
while
len
(
self
.
queue
)
==
0
:
self
.
not_empty
.
wait
()
self
.
not_empty
.
wait
(
timeout
=
timeout
)
if
timeout
is
not
None
and
len
(
self
.
queue
)
==
0
:
return
None
item
=
heapq
.
heappop
(
self
.
queue
)
item
=
heapq
.
heappop
(
self
.
queue
)
i
=
self
.
task_counter
i
=
self
.
task_counter
self
.
currently_running
[
i
]
=
copy
.
deepcopy
(
item
)
self
.
currently_running
[
i
]
=
copy
.
deepcopy
(
item
)
...
@@ -713,6 +716,8 @@ class PromptQueue:
...
@@ -713,6 +716,8 @@ class PromptQueue:
def
task_done
(
self
,
item_id
,
outputs
):
def
task_done
(
self
,
item_id
,
outputs
):
with
self
.
mutex
:
with
self
.
mutex
:
prompt
=
self
.
currently_running
.
pop
(
item_id
)
prompt
=
self
.
currently_running
.
pop
(
item_id
)
if
len
(
self
.
history
)
>
MAXIMUM_HISTORY_SIZE
:
self
.
history
.
pop
(
next
(
iter
(
self
.
history
)))
self
.
history
[
prompt
[
1
]]
=
{
"prompt"
:
prompt
,
"outputs"
:
{}
}
self
.
history
[
prompt
[
1
]]
=
{
"prompt"
:
prompt
,
"outputs"
:
{}
}
for
o
in
outputs
:
for
o
in
outputs
:
self
.
history
[
prompt
[
1
]][
"outputs"
][
o
]
=
outputs
[
o
]
self
.
history
[
prompt
[
1
]][
"outputs"
][
o
]
=
outputs
[
o
]
...
@@ -747,10 +752,20 @@ class PromptQueue:
...
@@ -747,10 +752,20 @@ class PromptQueue:
return
True
return
True
return
False
return
False
def
get_history
(
self
,
prompt_id
=
None
):
def
get_history
(
self
,
prompt_id
=
None
,
max_items
=
None
,
offset
=-
1
):
with
self
.
mutex
:
with
self
.
mutex
:
if
prompt_id
is
None
:
if
prompt_id
is
None
:
return
copy
.
deepcopy
(
self
.
history
)
out
=
{}
i
=
0
if
offset
<
0
and
max_items
is
not
None
:
offset
=
len
(
self
.
history
)
-
max_items
for
k
in
self
.
history
:
if
i
>=
offset
:
out
[
k
]
=
self
.
history
[
k
]
if
max_items
is
not
None
and
len
(
out
)
>=
max_items
:
break
i
+=
1
return
out
elif
prompt_id
in
self
.
history
:
elif
prompt_id
in
self
.
history
:
return
{
prompt_id
:
copy
.
deepcopy
(
self
.
history
[
prompt_id
])}
return
{
prompt_id
:
copy
.
deepcopy
(
self
.
history
[
prompt_id
])}
else
:
else
:
...
...
folder_paths.py
View file @
c92f3dca
...
@@ -38,7 +38,10 @@ input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "inp
...
@@ -38,7 +38,10 @@ input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "inp
filename_list_cache
=
{}
filename_list_cache
=
{}
if
not
os
.
path
.
exists
(
input_directory
):
if
not
os
.
path
.
exists
(
input_directory
):
os
.
makedirs
(
input_directory
)
try
:
os
.
makedirs
(
input_directory
)
except
:
print
(
"Failed to create input directory"
)
def
set_output_directory
(
output_dir
):
def
set_output_directory
(
output_dir
):
global
output_directory
global
output_directory
...
@@ -228,8 +231,12 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
...
@@ -228,8 +231,12 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
full_output_folder
=
os
.
path
.
join
(
output_dir
,
subfolder
)
full_output_folder
=
os
.
path
.
join
(
output_dir
,
subfolder
)
if
os
.
path
.
commonpath
((
output_dir
,
os
.
path
.
abspath
(
full_output_folder
)))
!=
output_dir
:
if
os
.
path
.
commonpath
((
output_dir
,
os
.
path
.
abspath
(
full_output_folder
)))
!=
output_dir
:
print
(
"Saving image outside the output folder is not allowed."
)
err
=
"**** ERROR: Saving image outside the output folder is not allowed."
+
\
return
{}
"
\n
full_output_folder: "
+
os
.
path
.
abspath
(
full_output_folder
)
+
\
"
\n
output_dir: "
+
output_dir
+
\
"
\n
commonpath: "
+
os
.
path
.
commonpath
((
output_dir
,
os
.
path
.
abspath
(
full_output_folder
)))
print
(
err
)
raise
Exception
(
err
)
try
:
try
:
counter
=
max
(
filter
(
lambda
a
:
a
[
1
][:
-
1
]
==
filename
and
a
[
1
][
-
1
]
==
"_"
,
map
(
map_filename
,
os
.
listdir
(
full_output_folder
))))[
0
]
+
1
counter
=
max
(
filter
(
lambda
a
:
a
[
1
][:
-
1
]
==
filename
and
a
[
1
][
-
1
]
==
"_"
,
map
(
map_filename
,
os
.
listdir
(
full_output_folder
))))[
0
]
+
1
...
...
latent_preview.py
View file @
c92f3dca
...
@@ -22,10 +22,7 @@ class TAESDPreviewerImpl(LatentPreviewer):
...
@@ -22,10 +22,7 @@ class TAESDPreviewerImpl(LatentPreviewer):
self
.
taesd
=
taesd
self
.
taesd
=
taesd
def
decode_latent_to_preview
(
self
,
x0
):
def
decode_latent_to_preview
(
self
,
x0
):
x_sample
=
self
.
taesd
.
decoder
(
x0
[:
1
])[
0
].
detach
()
x_sample
=
self
.
taesd
.
decode
(
x0
[:
1
])[
0
].
detach
()
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
x_sample
=
x_sample
.
sub
(
0.5
).
mul
(
2
)
x_sample
=
torch
.
clamp
((
x_sample
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
x_sample
=
torch
.
clamp
((
x_sample
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
x_sample
=
255.
*
np
.
moveaxis
(
x_sample
.
cpu
().
numpy
(),
0
,
2
)
x_sample
=
255.
*
np
.
moveaxis
(
x_sample
.
cpu
().
numpy
(),
0
,
2
)
x_sample
=
x_sample
.
astype
(
np
.
uint8
)
x_sample
=
x_sample
.
astype
(
np
.
uint8
)
...
...
main.py
View file @
c92f3dca
...
@@ -88,18 +88,37 @@ def cuda_malloc_warning():
...
@@ -88,18 +88,37 @@ def cuda_malloc_warning():
def
prompt_worker
(
q
,
server
):
def
prompt_worker
(
q
,
server
):
e
=
execution
.
PromptExecutor
(
server
)
e
=
execution
.
PromptExecutor
(
server
)
last_gc_collect
=
0
need_gc
=
False
gc_collect_interval
=
10.0
while
True
:
while
True
:
item
,
item_id
=
q
.
get
()
timeout
=
None
execution_start_time
=
time
.
perf_counter
()
if
need_gc
:
prompt_id
=
item
[
1
]
timeout
=
max
(
gc_collect_interval
-
(
current_time
-
last_gc_collect
),
0.0
)
e
.
execute
(
item
[
2
],
prompt_id
,
item
[
3
],
item
[
4
])
q
.
task_done
(
item_id
,
e
.
outputs_ui
)
queue_item
=
q
.
get
(
timeout
=
timeout
)
if
server
.
client_id
is
not
None
:
if
queue_item
is
not
None
:
server
.
send_sync
(
"executing"
,
{
"node"
:
None
,
"prompt_id"
:
prompt_id
},
server
.
client_id
)
item
,
item_id
=
queue_item
execution_start_time
=
time
.
perf_counter
()
print
(
"Prompt executed in {:.2f} seconds"
.
format
(
time
.
perf_counter
()
-
execution_start_time
))
prompt_id
=
item
[
1
]
gc
.
collect
()
e
.
execute
(
item
[
2
],
prompt_id
,
item
[
3
],
item
[
4
])
comfy
.
model_management
.
soft_empty_cache
()
need_gc
=
True
q
.
task_done
(
item_id
,
e
.
outputs_ui
)
if
server
.
client_id
is
not
None
:
server
.
send_sync
(
"executing"
,
{
"node"
:
None
,
"prompt_id"
:
prompt_id
},
server
.
client_id
)
current_time
=
time
.
perf_counter
()
execution_time
=
current_time
-
execution_start_time
print
(
"Prompt executed in {:.2f} seconds"
.
format
(
execution_time
))
if
need_gc
:
current_time
=
time
.
perf_counter
()
if
(
current_time
-
last_gc_collect
)
>
gc_collect_interval
:
gc
.
collect
()
comfy
.
model_management
.
soft_empty_cache
()
last_gc_collect
=
current_time
need_gc
=
False
async
def
run
(
server
,
address
=
''
,
port
=
8188
,
verbose
=
True
,
call_on_start
=
None
):
async
def
run
(
server
,
address
=
''
,
port
=
8188
,
verbose
=
True
,
call_on_start
=
None
):
await
asyncio
.
gather
(
server
.
start
(
address
,
port
,
verbose
,
call_on_start
),
server
.
publish_loop
())
await
asyncio
.
gather
(
server
.
start
(
address
,
port
,
verbose
,
call_on_start
),
server
.
publish_loop
())
...
...
nodes.py
View file @
c92f3dca
...
@@ -248,8 +248,8 @@ class ConditioningSetTimestepRange:
...
@@ -248,8 +248,8 @@ class ConditioningSetTimestepRange:
c
=
[]
c
=
[]
for
t
in
conditioning
:
for
t
in
conditioning
:
d
=
t
[
1
].
copy
()
d
=
t
[
1
].
copy
()
d
[
'start_percent'
]
=
1.0
-
start
d
[
'start_percent'
]
=
start
d
[
'end_percent'
]
=
1.0
-
end
d
[
'end_percent'
]
=
end
n
=
[
t
[
0
],
d
]
n
=
[
t
[
0
],
d
]
c
.
append
(
n
)
c
.
append
(
n
)
return
(
c
,
)
return
(
c
,
)
...
@@ -572,10 +572,69 @@ class LoraLoader:
...
@@ -572,10 +572,69 @@ class LoraLoader:
model_lora
,
clip_lora
=
comfy
.
sd
.
load_lora_for_models
(
model
,
clip
,
lora
,
strength_model
,
strength_clip
)
model_lora
,
clip_lora
=
comfy
.
sd
.
load_lora_for_models
(
model
,
clip
,
lora
,
strength_model
,
strength_clip
)
return
(
model_lora
,
clip_lora
)
return
(
model_lora
,
clip_lora
)
class
LoraLoaderModelOnly
(
LoraLoader
):
@
classmethod
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"model"
:
(
"MODEL"
,),
"lora_name"
:
(
folder_paths
.
get_filename_list
(
"loras"
),
),
"strength_model"
:
(
"FLOAT"
,
{
"default"
:
1.0
,
"min"
:
-
20.0
,
"max"
:
20.0
,
"step"
:
0.01
}),
}}
RETURN_TYPES
=
(
"MODEL"
,)
FUNCTION
=
"load_lora_model_only"
def
load_lora_model_only
(
self
,
model
,
lora_name
,
strength_model
):
return
(
self
.
load_lora
(
model
,
None
,
lora_name
,
strength_model
,
0
)[
0
],)
class
VAELoader
:
class
VAELoader
:
@
staticmethod
def
vae_list
():
vaes
=
folder_paths
.
get_filename_list
(
"vae"
)
approx_vaes
=
folder_paths
.
get_filename_list
(
"vae_approx"
)
sdxl_taesd_enc
=
False
sdxl_taesd_dec
=
False
sd1_taesd_enc
=
False
sd1_taesd_dec
=
False
for
v
in
approx_vaes
:
if
v
.
startswith
(
"taesd_decoder."
):
sd1_taesd_dec
=
True
elif
v
.
startswith
(
"taesd_encoder."
):
sd1_taesd_enc
=
True
elif
v
.
startswith
(
"taesdxl_decoder."
):
sdxl_taesd_dec
=
True
elif
v
.
startswith
(
"taesdxl_encoder."
):
sdxl_taesd_enc
=
True
if
sd1_taesd_dec
and
sd1_taesd_enc
:
vaes
.
append
(
"taesd"
)
if
sdxl_taesd_dec
and
sdxl_taesd_enc
:
vaes
.
append
(
"taesdxl"
)
return
vaes
@
staticmethod
def
load_taesd
(
name
):
sd
=
{}
approx_vaes
=
folder_paths
.
get_filename_list
(
"vae_approx"
)
encoder
=
next
(
filter
(
lambda
a
:
a
.
startswith
(
"{}_encoder."
.
format
(
name
)),
approx_vaes
))
decoder
=
next
(
filter
(
lambda
a
:
a
.
startswith
(
"{}_decoder."
.
format
(
name
)),
approx_vaes
))
enc
=
comfy
.
utils
.
load_torch_file
(
folder_paths
.
get_full_path
(
"vae_approx"
,
encoder
))
for
k
in
enc
:
sd
[
"taesd_encoder.{}"
.
format
(
k
)]
=
enc
[
k
]
dec
=
comfy
.
utils
.
load_torch_file
(
folder_paths
.
get_full_path
(
"vae_approx"
,
decoder
))
for
k
in
dec
:
sd
[
"taesd_decoder.{}"
.
format
(
k
)]
=
dec
[
k
]
if
name
==
"taesd"
:
sd
[
"vae_scale"
]
=
torch
.
tensor
(
0.18215
)
elif
name
==
"taesdxl"
:
sd
[
"vae_scale"
]
=
torch
.
tensor
(
0.13025
)
return
sd
@
classmethod
@
classmethod
def
INPUT_TYPES
(
s
):
def
INPUT_TYPES
(
s
):
return
{
"required"
:
{
"vae_name"
:
(
folder_paths
.
get_filenam
e_list
(
"vae"
),
)}}
return
{
"required"
:
{
"vae_name"
:
(
s
.
va
e_list
(),
)}}
RETURN_TYPES
=
(
"VAE"
,)
RETURN_TYPES
=
(
"VAE"
,)
FUNCTION
=
"load_vae"
FUNCTION
=
"load_vae"
...
@@ -583,8 +642,11 @@ class VAELoader:
...
@@ -583,8 +642,11 @@ class VAELoader:
#TODO: scale factor?
#TODO: scale factor?
def
load_vae
(
self
,
vae_name
):
def
load_vae
(
self
,
vae_name
):
vae_path
=
folder_paths
.
get_full_path
(
"vae"
,
vae_name
)
if
vae_name
in
[
"taesd"
,
"taesdxl"
]:
sd
=
comfy
.
utils
.
load_torch_file
(
vae_path
)
sd
=
self
.
load_taesd
(
vae_name
)
else
:
vae_path
=
folder_paths
.
get_full_path
(
"vae"
,
vae_name
)
sd
=
comfy
.
utils
.
load_torch_file
(
vae_path
)
vae
=
comfy
.
sd
.
VAE
(
sd
=
sd
)
vae
=
comfy
.
sd
.
VAE
(
sd
=
sd
)
return
(
vae
,)
return
(
vae
,)
...
@@ -685,7 +747,7 @@ class ControlNetApplyAdvanced:
...
@@ -685,7 +747,7 @@ class ControlNetApplyAdvanced:
if
prev_cnet
in
cnets
:
if
prev_cnet
in
cnets
:
c_net
=
cnets
[
prev_cnet
]
c_net
=
cnets
[
prev_cnet
]
else
:
else
:
c_net
=
control_net
.
copy
().
set_cond_hint
(
control_hint
,
strength
,
(
1.0
-
start_percent
,
1.0
-
end_percent
))
c_net
=
control_net
.
copy
().
set_cond_hint
(
control_hint
,
strength
,
(
start_percent
,
end_percent
))
c_net
.
set_previous_controlnet
(
prev_cnet
)
c_net
.
set_previous_controlnet
(
prev_cnet
)
cnets
[
prev_cnet
]
=
c_net
cnets
[
prev_cnet
]
=
c_net
...
@@ -1275,6 +1337,7 @@ class SaveImage:
...
@@ -1275,6 +1337,7 @@ class SaveImage:
self
.
output_dir
=
folder_paths
.
get_output_directory
()
self
.
output_dir
=
folder_paths
.
get_output_directory
()
self
.
type
=
"output"
self
.
type
=
"output"
self
.
prefix_append
=
""
self
.
prefix_append
=
""
self
.
compress_level
=
4
@
classmethod
@
classmethod
def
INPUT_TYPES
(
s
):
def
INPUT_TYPES
(
s
):
...
@@ -1308,7 +1371,7 @@ class SaveImage:
...
@@ -1308,7 +1371,7 @@ class SaveImage:
metadata
.
add_text
(
x
,
json
.
dumps
(
extra_pnginfo
[
x
]))
metadata
.
add_text
(
x
,
json
.
dumps
(
extra_pnginfo
[
x
]))
file
=
f
"
{
filename
}
_
{
counter
:
05
}
_.png"
file
=
f
"
{
filename
}
_
{
counter
:
05
}
_.png"
img
.
save
(
os
.
path
.
join
(
full_output_folder
,
file
),
pnginfo
=
metadata
,
compress_level
=
4
)
img
.
save
(
os
.
path
.
join
(
full_output_folder
,
file
),
pnginfo
=
metadata
,
compress_level
=
self
.
compress_level
)
results
.
append
({
results
.
append
({
"filename"
:
file
,
"filename"
:
file
,
"subfolder"
:
subfolder
,
"subfolder"
:
subfolder
,
...
@@ -1323,6 +1386,7 @@ class PreviewImage(SaveImage):
...
@@ -1323,6 +1386,7 @@ class PreviewImage(SaveImage):
self
.
output_dir
=
folder_paths
.
get_temp_directory
()
self
.
output_dir
=
folder_paths
.
get_temp_directory
()
self
.
type
=
"temp"
self
.
type
=
"temp"
self
.
prefix_append
=
"_temp_"
+
''
.
join
(
random
.
choice
(
"abcdefghijklmnopqrstupvxyz"
)
for
x
in
range
(
5
))
self
.
prefix_append
=
"_temp_"
+
''
.
join
(
random
.
choice
(
"abcdefghijklmnopqrstupvxyz"
)
for
x
in
range
(
5
))
self
.
compress_level
=
1
@
classmethod
@
classmethod
def
INPUT_TYPES
(
s
):
def
INPUT_TYPES
(
s
):
...
@@ -1654,6 +1718,7 @@ NODE_CLASS_MAPPINGS = {
...
@@ -1654,6 +1718,7 @@ NODE_CLASS_MAPPINGS = {
"ConditioningZeroOut"
:
ConditioningZeroOut
,
"ConditioningZeroOut"
:
ConditioningZeroOut
,
"ConditioningSetTimestepRange"
:
ConditioningSetTimestepRange
,
"ConditioningSetTimestepRange"
:
ConditioningSetTimestepRange
,
"LoraLoaderModelOnly"
:
LoraLoaderModelOnly
,
}
}
NODE_DISPLAY_NAME_MAPPINGS
=
{
NODE_DISPLAY_NAME_MAPPINGS
=
{
...
@@ -1759,7 +1824,7 @@ def load_custom_nodes():
...
@@ -1759,7 +1824,7 @@ def load_custom_nodes():
node_paths
=
folder_paths
.
get_folder_paths
(
"custom_nodes"
)
node_paths
=
folder_paths
.
get_folder_paths
(
"custom_nodes"
)
node_import_times
=
[]
node_import_times
=
[]
for
custom_node_path
in
node_paths
:
for
custom_node_path
in
node_paths
:
possible_modules
=
os
.
listdir
(
custom_node_path
)
possible_modules
=
os
.
listdir
(
os
.
path
.
realpath
(
custom_node_path
)
)
if
"__pycache__"
in
possible_modules
:
if
"__pycache__"
in
possible_modules
:
possible_modules
.
remove
(
"__pycache__"
)
possible_modules
.
remove
(
"__pycache__"
)
...
@@ -1799,6 +1864,9 @@ def init_custom_nodes():
...
@@ -1799,6 +1864,9 @@ def init_custom_nodes():
"nodes_custom_sampler.py"
,
"nodes_custom_sampler.py"
,
"nodes_hypertile.py"
,
"nodes_hypertile.py"
,
"nodes_model_advanced.py"
,
"nodes_model_advanced.py"
,
"nodes_model_downscale.py"
,
"nodes_images.py"
,
"nodes_video_model.py"
,
]
]
for
node_file
in
extras_files
:
for
node_file
in
extras_files
:
...
...
server.py
View file @
c92f3dca
...
@@ -431,7 +431,10 @@ class PromptServer():
...
@@ -431,7 +431,10 @@ class PromptServer():
@
routes
.
get
(
"/history"
)
@
routes
.
get
(
"/history"
)
async
def
get_history
(
request
):
async
def
get_history
(
request
):
return
web
.
json_response
(
self
.
prompt_queue
.
get_history
())
max_items
=
request
.
rel_url
.
query
.
get
(
"max_items"
,
None
)
if
max_items
is
not
None
:
max_items
=
int
(
max_items
)
return
web
.
json_response
(
self
.
prompt_queue
.
get_history
(
max_items
=
max_items
))
@
routes
.
get
(
"/history/{prompt_id}"
)
@
routes
.
get
(
"/history/{prompt_id}"
)
async
def
get_history
(
request
):
async
def
get_history
(
request
):
...
@@ -573,7 +576,7 @@ class PromptServer():
...
@@ -573,7 +576,7 @@ class PromptServer():
bytesIO
=
BytesIO
()
bytesIO
=
BytesIO
()
header
=
struct
.
pack
(
">I"
,
type_num
)
header
=
struct
.
pack
(
">I"
,
type_num
)
bytesIO
.
write
(
header
)
bytesIO
.
write
(
header
)
image
.
save
(
bytesIO
,
format
=
image_type
,
quality
=
95
,
compress_level
=
4
)
image
.
save
(
bytesIO
,
format
=
image_type
,
quality
=
95
,
compress_level
=
1
)
preview_bytes
=
bytesIO
.
getvalue
()
preview_bytes
=
bytesIO
.
getvalue
()
await
self
.
send_bytes
(
BinaryEventTypes
.
PREVIEW_IMAGE
,
preview_bytes
,
sid
=
sid
)
await
self
.
send_bytes
(
BinaryEventTypes
.
PREVIEW_IMAGE
,
preview_bytes
,
sid
=
sid
)
...
...
tests-ui/setup.js
View file @
c92f3dca
...
@@ -20,6 +20,7 @@ async function setup() {
...
@@ -20,6 +20,7 @@ async function setup() {
// Modify the response data to add some checkpoints
// Modify the response data to add some checkpoints
const
objectInfo
=
JSON
.
parse
(
data
);
const
objectInfo
=
JSON
.
parse
(
data
);
objectInfo
.
CheckpointLoaderSimple
.
input
.
required
.
ckpt_name
[
0
]
=
[
"
model1.safetensors
"
,
"
model2.ckpt
"
];
objectInfo
.
CheckpointLoaderSimple
.
input
.
required
.
ckpt_name
[
0
]
=
[
"
model1.safetensors
"
,
"
model2.ckpt
"
];
objectInfo
.
VAELoader
.
input
.
required
.
vae_name
[
0
]
=
[
"
vae1.safetensors
"
,
"
vae2.ckpt
"
];
data
=
JSON
.
stringify
(
objectInfo
,
undefined
,
"
\t
"
);
data
=
JSON
.
stringify
(
objectInfo
,
undefined
,
"
\t
"
);
...
...
tests-ui/tests/extensions.test.js
0 → 100644
View file @
c92f3dca
// @ts-check
/// <reference path="../node_modules/@types/jest/index.d.ts" />
const
{
start
}
=
require
(
"
../utils
"
);
const
lg
=
require
(
"
../utils/litegraph
"
);
describe
(
"
extensions
"
,
()
=>
{
beforeEach
(()
=>
{
lg
.
setup
(
global
);
});
afterEach
(()
=>
{
lg
.
teardown
(
global
);
});
it
(
"
calls each extension hook
"
,
async
()
=>
{
const
mockExtension
=
{
name
:
"
TestExtension
"
,
init
:
jest
.
fn
(),
setup
:
jest
.
fn
(),
addCustomNodeDefs
:
jest
.
fn
(),
getCustomWidgets
:
jest
.
fn
(),
beforeRegisterNodeDef
:
jest
.
fn
(),
registerCustomNodes
:
jest
.
fn
(),
loadedGraphNode
:
jest
.
fn
(),
nodeCreated
:
jest
.
fn
(),
beforeConfigureGraph
:
jest
.
fn
(),
afterConfigureGraph
:
jest
.
fn
(),
};
const
{
app
,
ez
,
graph
}
=
await
start
({
async
preSetup
(
app
)
{
app
.
registerExtension
(
mockExtension
);
},
});
// Basic initialisation hooks should be called once, with app
expect
(
mockExtension
.
init
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
init
).
toHaveBeenCalledWith
(
app
);
// Adding custom node defs should be passed the full list of nodes
expect
(
mockExtension
.
addCustomNodeDefs
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
addCustomNodeDefs
.
mock
.
calls
[
0
][
1
]).
toStrictEqual
(
app
);
const
defs
=
mockExtension
.
addCustomNodeDefs
.
mock
.
calls
[
0
][
0
];
expect
(
defs
).
toHaveProperty
(
"
KSampler
"
);
expect
(
defs
).
toHaveProperty
(
"
LoadImage
"
);
// Get custom widgets is called once and should return new widget types
expect
(
mockExtension
.
getCustomWidgets
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
getCustomWidgets
).
toHaveBeenCalledWith
(
app
);
// Before register node def will be called once per node type
const
nodeNames
=
Object
.
keys
(
defs
);
const
nodeCount
=
nodeNames
.
length
;
expect
(
mockExtension
.
beforeRegisterNodeDef
).
toHaveBeenCalledTimes
(
nodeCount
);
for
(
let
i
=
0
;
i
<
nodeCount
;
i
++
)
{
// It should be send the JS class and the original JSON definition
const
nodeClass
=
mockExtension
.
beforeRegisterNodeDef
.
mock
.
calls
[
i
][
0
];
const
nodeDef
=
mockExtension
.
beforeRegisterNodeDef
.
mock
.
calls
[
i
][
1
];
expect
(
nodeClass
.
name
).
toBe
(
"
ComfyNode
"
);
expect
(
nodeClass
.
comfyClass
).
toBe
(
nodeNames
[
i
]);
expect
(
nodeDef
.
name
).
toBe
(
nodeNames
[
i
]);
expect
(
nodeDef
).
toHaveProperty
(
"
input
"
);
expect
(
nodeDef
).
toHaveProperty
(
"
output
"
);
}
// Register custom nodes is called once after registerNode defs to allow adding other frontend nodes
expect
(
mockExtension
.
registerCustomNodes
).
toHaveBeenCalledTimes
(
1
);
// Before configure graph will be called here as the default graph is being loaded
expect
(
mockExtension
.
beforeConfigureGraph
).
toHaveBeenCalledTimes
(
1
);
// it gets sent the graph data that is going to be loaded
const
graphData
=
mockExtension
.
beforeConfigureGraph
.
mock
.
calls
[
0
][
0
];
// A node created is fired for each node constructor that is called
expect
(
mockExtension
.
nodeCreated
).
toHaveBeenCalledTimes
(
graphData
.
nodes
.
length
);
for
(
let
i
=
0
;
i
<
graphData
.
nodes
.
length
;
i
++
)
{
expect
(
mockExtension
.
nodeCreated
.
mock
.
calls
[
i
][
0
].
type
).
toBe
(
graphData
.
nodes
[
i
].
type
);
}
// Each node then calls loadedGraphNode to allow them to be updated
expect
(
mockExtension
.
loadedGraphNode
).
toHaveBeenCalledTimes
(
graphData
.
nodes
.
length
);
for
(
let
i
=
0
;
i
<
graphData
.
nodes
.
length
;
i
++
)
{
expect
(
mockExtension
.
loadedGraphNode
.
mock
.
calls
[
i
][
0
].
type
).
toBe
(
graphData
.
nodes
[
i
].
type
);
}
// After configure is then called once all the setup is done
expect
(
mockExtension
.
afterConfigureGraph
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
setup
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
setup
).
toHaveBeenCalledWith
(
app
);
// Ensure hooks are called in the correct order
const
callOrder
=
[
"
init
"
,
"
addCustomNodeDefs
"
,
"
getCustomWidgets
"
,
"
beforeRegisterNodeDef
"
,
"
registerCustomNodes
"
,
"
beforeConfigureGraph
"
,
"
nodeCreated
"
,
"
loadedGraphNode
"
,
"
afterConfigureGraph
"
,
"
setup
"
,
];
for
(
let
i
=
1
;
i
<
callOrder
.
length
;
i
++
)
{
const
fn1
=
mockExtension
[
callOrder
[
i
-
1
]];
const
fn2
=
mockExtension
[
callOrder
[
i
]];
expect
(
fn1
.
mock
.
invocationCallOrder
[
0
]).
toBeLessThan
(
fn2
.
mock
.
invocationCallOrder
[
0
]);
}
graph
.
clear
();
// Ensure adding a new node calls the correct callback
ez
.
LoadImage
();
expect
(
mockExtension
.
loadedGraphNode
).
toHaveBeenCalledTimes
(
graphData
.
nodes
.
length
);
expect
(
mockExtension
.
nodeCreated
).
toHaveBeenCalledTimes
(
graphData
.
nodes
.
length
+
1
);
expect
(
mockExtension
.
nodeCreated
.
mock
.
lastCall
[
0
].
type
).
toBe
(
"
LoadImage
"
);
// Reload the graph to ensure correct hooks are fired
await
graph
.
reload
();
// These hooks should not be fired again
expect
(
mockExtension
.
init
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
addCustomNodeDefs
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
getCustomWidgets
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
registerCustomNodes
).
toHaveBeenCalledTimes
(
1
);
expect
(
mockExtension
.
beforeRegisterNodeDef
).
toHaveBeenCalledTimes
(
nodeCount
);
expect
(
mockExtension
.
setup
).
toHaveBeenCalledTimes
(
1
);
// These should be called again
expect
(
mockExtension
.
beforeConfigureGraph
).
toHaveBeenCalledTimes
(
2
);
expect
(
mockExtension
.
nodeCreated
).
toHaveBeenCalledTimes
(
graphData
.
nodes
.
length
+
2
);
expect
(
mockExtension
.
loadedGraphNode
).
toHaveBeenCalledTimes
(
graphData
.
nodes
.
length
+
1
);
expect
(
mockExtension
.
afterConfigureGraph
).
toHaveBeenCalledTimes
(
2
);
});
it
(
"
allows custom nodeDefs and widgets to be registered
"
,
async
()
=>
{
const
widgetMock
=
jest
.
fn
((
node
,
inputName
,
inputData
,
app
)
=>
{
expect
(
node
.
constructor
.
comfyClass
).
toBe
(
"
TestNode
"
);
expect
(
inputName
).
toBe
(
"
test_input
"
);
expect
(
inputData
[
0
]).
toBe
(
"
CUSTOMWIDGET
"
);
expect
(
inputData
[
1
]?.
hello
).
toBe
(
"
world
"
);
expect
(
app
).
toStrictEqual
(
app
);
return
{
widget
:
node
.
addWidget
(
"
button
"
,
inputName
,
"
hello
"
,
()
=>
{}),
};
});
// Register our extension that adds a custom node + widget type
const
mockExtension
=
{
name
:
"
TestExtension
"
,
addCustomNodeDefs
:
(
nodeDefs
)
=>
{
nodeDefs
[
"
TestNode
"
]
=
{
output
:
[],
output_name
:
[],
output_is_list
:
[],
name
:
"
TestNode
"
,
display_name
:
"
TestNode
"
,
category
:
"
Test
"
,
input
:
{
required
:
{
test_input
:
[
"
CUSTOMWIDGET
"
,
{
hello
:
"
world
"
}],
},
},
};
},
getCustomWidgets
:
jest
.
fn
(()
=>
{
return
{
CUSTOMWIDGET
:
widgetMock
,
};
}),
};
const
{
graph
,
ez
}
=
await
start
({
async
preSetup
(
app
)
{
app
.
registerExtension
(
mockExtension
);
},
});
expect
(
mockExtension
.
getCustomWidgets
).
toBeCalledTimes
(
1
);
graph
.
clear
();
expect
(
widgetMock
).
toBeCalledTimes
(
0
);
const
node
=
ez
.
TestNode
();
expect
(
widgetMock
).
toBeCalledTimes
(
1
);
// Ensure our custom widget is created
expect
(
node
.
inputs
.
length
).
toBe
(
0
);
expect
(
node
.
widgets
.
length
).
toBe
(
1
);
const
w
=
node
.
widgets
[
0
].
widget
;
expect
(
w
.
name
).
toBe
(
"
test_input
"
);
expect
(
w
.
type
).
toBe
(
"
button
"
);
});
});
tests-ui/tests/groupNode.test.js
0 → 100644
View file @
c92f3dca
This diff is collapsed.
Click to expand it.
tests-ui/tests/widgetInputs.test.js
View file @
c92f3dca
...
@@ -14,10 +14,10 @@ const lg = require("../utils/litegraph");
...
@@ -14,10 +14,10 @@ const lg = require("../utils/litegraph");
* @param { InstanceType<Ez["EzGraph"]> } graph
* @param { InstanceType<Ez["EzGraph"]> } graph
* @param { InstanceType<Ez["EzInput"]> } input
* @param { InstanceType<Ez["EzInput"]> } input
* @param { string } widgetType
* @param { string } widgetType
* @param {
boolean } hasC
ontrolWidget
* @param {
number } c
ontrolWidget
Count
* @returns
* @returns
*/
*/
async
function
connectPrimitiveAndReload
(
ez
,
graph
,
input
,
widgetType
,
hasC
ontrolWidget
)
{
async
function
connectPrimitiveAndReload
(
ez
,
graph
,
input
,
widgetType
,
c
ontrolWidget
Count
=
0
)
{
// Connect to primitive and ensure its still connected after
// Connect to primitive and ensure its still connected after
let
primitive
=
ez
.
PrimitiveNode
();
let
primitive
=
ez
.
PrimitiveNode
();
primitive
.
outputs
[
0
].
connectTo
(
input
);
primitive
.
outputs
[
0
].
connectTo
(
input
);
...
@@ -33,13 +33,17 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasContro
...
@@ -33,13 +33,17 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasContro
expect
(
valueWidget
.
widget
.
type
).
toBe
(
widgetType
);
expect
(
valueWidget
.
widget
.
type
).
toBe
(
widgetType
);
// Check if control_after_generate should be added
// Check if control_after_generate should be added
if
(
hasC
ontrolWidget
)
{
if
(
c
ontrolWidget
Count
)
{
const
controlWidget
=
primitive
.
widgets
.
control_after_generate
;
const
controlWidget
=
primitive
.
widgets
.
control_after_generate
;
expect
(
controlWidget
.
widget
.
type
).
toBe
(
"
combo
"
);
expect
(
controlWidget
.
widget
.
type
).
toBe
(
"
combo
"
);
if
(
widgetType
===
"
combo
"
)
{
const
filterWidget
=
primitive
.
widgets
.
control_filter_list
;
expect
(
filterWidget
.
widget
.
type
).
toBe
(
"
string
"
);
}
}
}
// Ensure we dont have other widgets
// Ensure we dont have other widgets
expect
(
primitive
.
node
.
widgets
).
toHaveLength
(
1
+
+!!
hasC
ontrolWidget
);
expect
(
primitive
.
node
.
widgets
).
toHaveLength
(
1
+
c
ontrolWidget
Count
);
});
});
return
primitive
;
return
primitive
;
...
@@ -55,8 +59,8 @@ describe("widget inputs", () => {
...
@@ -55,8 +59,8 @@ describe("widget inputs", () => {
});
});
[
[
{
name
:
"
int
"
,
type
:
"
INT
"
,
widget
:
"
number
"
,
control
:
true
},
{
name
:
"
int
"
,
type
:
"
INT
"
,
widget
:
"
number
"
,
control
:
1
},
{
name
:
"
float
"
,
type
:
"
FLOAT
"
,
widget
:
"
number
"
,
control
:
true
},
{
name
:
"
float
"
,
type
:
"
FLOAT
"
,
widget
:
"
number
"
,
control
:
1
},
{
name
:
"
text
"
,
type
:
"
STRING
"
},
{
name
:
"
text
"
,
type
:
"
STRING
"
},
{
{
name
:
"
customtext
"
,
name
:
"
customtext
"
,
...
@@ -64,7 +68,7 @@ describe("widget inputs", () => {
...
@@ -64,7 +68,7 @@ describe("widget inputs", () => {
opt
:
{
multiline
:
true
},
opt
:
{
multiline
:
true
},
},
},
{
name
:
"
toggle
"
,
type
:
"
BOOLEAN
"
},
{
name
:
"
toggle
"
,
type
:
"
BOOLEAN
"
},
{
name
:
"
combo
"
,
type
:
[
"
a
"
,
"
b
"
,
"
c
"
],
control
:
true
},
{
name
:
"
combo
"
,
type
:
[
"
a
"
,
"
b
"
,
"
c
"
],
control
:
2
},
].
forEach
((
c
)
=>
{
].
forEach
((
c
)
=>
{
test
(
`widget conversion + primitive works on
${
c
.
name
}
`
,
async
()
=>
{
test
(
`widget conversion + primitive works on
${
c
.
name
}
`
,
async
()
=>
{
const
{
ez
,
graph
}
=
await
start
({
const
{
ez
,
graph
}
=
await
start
({
...
@@ -106,7 +110,7 @@ describe("widget inputs", () => {
...
@@ -106,7 +110,7 @@ describe("widget inputs", () => {
n
.
widgets
.
ckpt_name
.
convertToInput
();
n
.
widgets
.
ckpt_name
.
convertToInput
();
expect
(
n
.
inputs
.
length
).
toEqual
(
inputCount
+
1
);
expect
(
n
.
inputs
.
length
).
toEqual
(
inputCount
+
1
);
const
primitive
=
await
connectPrimitiveAndReload
(
ez
,
graph
,
n
.
inputs
.
ckpt_name
,
"
combo
"
,
true
);
const
primitive
=
await
connectPrimitiveAndReload
(
ez
,
graph
,
n
.
inputs
.
ckpt_name
,
"
combo
"
,
2
);
// Disconnect & reconnect
// Disconnect & reconnect
primitive
.
outputs
[
0
].
connections
[
0
].
disconnect
();
primitive
.
outputs
[
0
].
connections
[
0
].
disconnect
();
...
@@ -198,8 +202,8 @@ describe("widget inputs", () => {
...
@@ -198,8 +202,8 @@ describe("widget inputs", () => {
});
});
expect
(
dialogShow
).
toBeCalledTimes
(
1
);
expect
(
dialogShow
).
toBeCalledTimes
(
1
);
expect
(
dialogShow
.
mock
.
calls
[
0
][
0
]).
toContain
(
"
the following node types were not found
"
);
expect
(
dialogShow
.
mock
.
calls
[
0
][
0
]
.
innerHTML
).
toContain
(
"
the following node types were not found
"
);
expect
(
dialogShow
.
mock
.
calls
[
0
][
0
]).
toContain
(
"
TestNode
"
);
expect
(
dialogShow
.
mock
.
calls
[
0
][
0
]
.
innerHTML
).
toContain
(
"
TestNode
"
);
});
});
test
(
"
defaultInput widgets can be converted back to inputs
"
,
async
()
=>
{
test
(
"
defaultInput widgets can be converted back to inputs
"
,
async
()
=>
{
...
@@ -226,7 +230,7 @@ describe("widget inputs", () => {
...
@@ -226,7 +230,7 @@ describe("widget inputs", () => {
// Reload and ensure it still only has 1 converted widget
// Reload and ensure it still only has 1 converted widget
if
(
!
assertNotNullOrUndefined
(
input
))
return
;
if
(
!
assertNotNullOrUndefined
(
input
))
return
;
await
connectPrimitiveAndReload
(
ez
,
graph
,
input
,
"
number
"
,
true
);
await
connectPrimitiveAndReload
(
ez
,
graph
,
input
,
"
number
"
,
1
);
n
=
graph
.
find
(
n
);
n
=
graph
.
find
(
n
);
expect
(
n
.
widgets
).
toHaveLength
(
1
);
expect
(
n
.
widgets
).
toHaveLength
(
1
);
w
=
n
.
widgets
.
example
;
w
=
n
.
widgets
.
example
;
...
@@ -258,7 +262,7 @@ describe("widget inputs", () => {
...
@@ -258,7 +262,7 @@ describe("widget inputs", () => {
// Reload and ensure it still only has 1 converted widget
// Reload and ensure it still only has 1 converted widget
if
(
assertNotNullOrUndefined
(
input
))
{
if
(
assertNotNullOrUndefined
(
input
))
{
await
connectPrimitiveAndReload
(
ez
,
graph
,
input
,
"
number
"
,
true
);
await
connectPrimitiveAndReload
(
ez
,
graph
,
input
,
"
number
"
,
1
);
n
=
graph
.
find
(
n
);
n
=
graph
.
find
(
n
);
expect
(
n
.
widgets
).
toHaveLength
(
1
);
expect
(
n
.
widgets
).
toHaveLength
(
1
);
expect
(
n
.
widgets
.
example
.
isConvertedToInput
).
toBeTruthy
();
expect
(
n
.
widgets
.
example
.
isConvertedToInput
).
toBeTruthy
();
...
@@ -316,4 +320,76 @@ describe("widget inputs", () => {
...
@@ -316,4 +320,76 @@ describe("widget inputs", () => {
n1
.
outputs
[
0
].
connectTo
(
n2
.
inputs
[
0
]);
n1
.
outputs
[
0
].
connectTo
(
n2
.
inputs
[
0
]);
expect
(()
=>
n1
.
outputs
[
0
].
connectTo
(
n3
.
inputs
[
0
])).
toThrow
();
expect
(()
=>
n1
.
outputs
[
0
].
connectTo
(
n3
.
inputs
[
0
])).
toThrow
();
});
});
test
(
"
combo primitive can filter list when control_after_generate called
"
,
async
()
=>
{
const
{
ez
}
=
await
start
({
mockNodeDefs
:
{
...
makeNodeDef
(
"
TestNode1
"
,
{
example
:
[[
"
A
"
,
"
B
"
,
"
C
"
,
"
D
"
,
"
AA
"
,
"
BB
"
,
"
CC
"
,
"
DD
"
,
"
AAA
"
,
"
BBB
"
],
{}]
}),
},
});
const
n1
=
ez
.
TestNode1
();
n1
.
widgets
.
example
.
convertToInput
();
const
p
=
ez
.
PrimitiveNode
()
p
.
outputs
[
0
].
connectTo
(
n1
.
inputs
[
0
]);
const
value
=
p
.
widgets
.
value
;
const
control
=
p
.
widgets
.
control_after_generate
.
widget
;
const
filter
=
p
.
widgets
.
control_filter_list
;
expect
(
p
.
widgets
.
length
).
toBe
(
3
);
control
.
value
=
"
increment
"
;
expect
(
value
.
value
).
toBe
(
"
A
"
);
// Manually trigger after queue when set to increment
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
B
"
);
// Filter to items containing D
filter
.
value
=
"
D
"
;
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
D
"
);
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
DD
"
);
// Check decrement
value
.
value
=
"
BBB
"
;
control
.
value
=
"
decrement
"
;
filter
.
value
=
"
B
"
;
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
BB
"
);
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
B
"
);
// Check regex works
value
.
value
=
"
BBB
"
;
filter
.
value
=
"
/[AB]|^C$/
"
;
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
AAA
"
);
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
BB
"
);
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
AA
"
);
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
C
"
);
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
B
"
);
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
A
"
);
// Check random
control
.
value
=
"
randomize
"
;
filter
.
value
=
"
/D/
"
;
for
(
let
i
=
0
;
i
<
100
;
i
++
)
{
control
[
"
afterQueued
"
]();
expect
(
value
.
value
===
"
D
"
||
value
.
value
===
"
DD
"
).
toBeTruthy
();
}
// Ensure it doesnt apply when fixed
control
.
value
=
"
fixed
"
;
value
.
value
=
"
B
"
;
filter
.
value
=
"
C
"
;
control
[
"
afterQueued
"
]();
expect
(
value
.
value
).
toBe
(
"
B
"
);
});
});
});
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment