Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
df7db0e0
Unverified
Commit
df7db0e0
authored
Jun 16, 2024
by
Dr.Lt.Data
Committed by
GitHub
Jun 16, 2024
Browse files
support TAESD3 (#3738)
parent
bb1969ca
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
28 additions
and
13 deletions
+28
-13
comfy/latent_formats.py
comfy/latent_formats.py
+1
-0
comfy/sd.py
comfy/sd.py
+3
-3
comfy/taesd/taesd.py
comfy/taesd/taesd.py
+8
-7
latent_preview.py
latent_preview.py
+1
-1
nodes.py
nodes.py
+15
-2
No files found.
comfy/latent_formats.py
View file @
df7db0e0
...
@@ -129,6 +129,7 @@ class SD3(LatentFormat):
...
@@ -129,6 +129,7 @@ class SD3(LatentFormat):
[
-
0.0749
,
-
0.0634
,
-
0.0456
],
[
-
0.0749
,
-
0.0634
,
-
0.0456
],
[
-
0.1418
,
-
0.1457
,
-
0.1259
]
[
-
0.1418
,
-
0.1457
,
-
0.1259
]
]
]
self
.
taesd_decoder_name
=
"taesd3_decoder"
def
process_in
(
self
,
latent
):
def
process_in
(
self
,
latent
):
return
(
latent
-
self
.
shift_factor
)
*
self
.
scale_factor
return
(
latent
-
self
.
shift_factor
)
*
self
.
scale_factor
...
...
comfy/sd.py
View file @
df7db0e0
...
@@ -166,7 +166,7 @@ class CLIP:
...
@@ -166,7 +166,7 @@ class CLIP:
return
self
.
patcher
.
get_key_patches
()
return
self
.
patcher
.
get_key_patches
()
class
VAE
:
class
VAE
:
def
__init__
(
self
,
sd
=
None
,
device
=
None
,
config
=
None
,
dtype
=
None
):
def
__init__
(
self
,
sd
=
None
,
device
=
None
,
config
=
None
,
dtype
=
None
,
latent_channels
=
4
):
if
'decoder.up_blocks.0.resnets.0.norm1.weight'
in
sd
.
keys
():
#diffusers format
if
'decoder.up_blocks.0.resnets.0.norm1.weight'
in
sd
.
keys
():
#diffusers format
sd
=
diffusers_convert
.
convert_vae_state_dict
(
sd
)
sd
=
diffusers_convert
.
convert_vae_state_dict
(
sd
)
...
@@ -174,7 +174,7 @@ class VAE:
...
@@ -174,7 +174,7 @@ class VAE:
self
.
memory_used_decode
=
lambda
shape
,
dtype
:
(
2178
*
shape
[
2
]
*
shape
[
3
]
*
64
)
*
model_management
.
dtype_size
(
dtype
)
self
.
memory_used_decode
=
lambda
shape
,
dtype
:
(
2178
*
shape
[
2
]
*
shape
[
3
]
*
64
)
*
model_management
.
dtype_size
(
dtype
)
self
.
downscale_ratio
=
8
self
.
downscale_ratio
=
8
self
.
upscale_ratio
=
8
self
.
upscale_ratio
=
8
self
.
latent_channels
=
4
self
.
latent_channels
=
latent_channels
self
.
output_channels
=
3
self
.
output_channels
=
3
self
.
process_input
=
lambda
image
:
image
*
2.0
-
1.0
self
.
process_input
=
lambda
image
:
image
*
2.0
-
1.0
self
.
process_output
=
lambda
image
:
torch
.
clamp
((
image
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
self
.
process_output
=
lambda
image
:
torch
.
clamp
((
image
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
...
@@ -189,7 +189,7 @@ class VAE:
...
@@ -189,7 +189,7 @@ class VAE:
encoder_config
=
{
'target'
:
"comfy.ldm.modules.diffusionmodules.model.Encoder"
,
'params'
:
encoder_config
},
encoder_config
=
{
'target'
:
"comfy.ldm.modules.diffusionmodules.model.Encoder"
,
'params'
:
encoder_config
},
decoder_config
=
{
'target'
:
"comfy.ldm.modules.temporal_ae.VideoDecoder"
,
'params'
:
decoder_config
})
decoder_config
=
{
'target'
:
"comfy.ldm.modules.temporal_ae.VideoDecoder"
,
'params'
:
decoder_config
})
elif
"taesd_decoder.1.weight"
in
sd
:
elif
"taesd_decoder.1.weight"
in
sd
:
self
.
first_stage_model
=
comfy
.
taesd
.
taesd
.
TAESD
()
self
.
first_stage_model
=
comfy
.
taesd
.
taesd
.
TAESD
(
latent_channels
=
self
.
latent_channels
)
elif
"vquantizer.codebook.weight"
in
sd
:
#VQGan: stage a of stable cascade
elif
"vquantizer.codebook.weight"
in
sd
:
#VQGan: stage a of stable cascade
self
.
first_stage_model
=
StageA
()
self
.
first_stage_model
=
StageA
()
self
.
downscale_ratio
=
4
self
.
downscale_ratio
=
4
...
...
comfy/taesd/taesd.py
View file @
df7db0e0
...
@@ -25,18 +25,19 @@ class Block(nn.Module):
...
@@ -25,18 +25,19 @@ class Block(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
return
self
.
fuse
(
self
.
conv
(
x
)
+
self
.
skip
(
x
))
return
self
.
fuse
(
self
.
conv
(
x
)
+
self
.
skip
(
x
))
def
Encoder
():
def
Encoder
(
latent_channels
=
4
):
return
nn
.
Sequential
(
return
nn
.
Sequential
(
conv
(
3
,
64
),
Block
(
64
,
64
),
conv
(
3
,
64
),
Block
(
64
,
64
),
conv
(
64
,
64
,
stride
=
2
,
bias
=
False
),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
conv
(
64
,
64
,
stride
=
2
,
bias
=
False
),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
conv
(
64
,
64
,
stride
=
2
,
bias
=
False
),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
conv
(
64
,
64
,
stride
=
2
,
bias
=
False
),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
conv
(
64
,
64
,
stride
=
2
,
bias
=
False
),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
conv
(
64
,
64
,
stride
=
2
,
bias
=
False
),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
conv
(
64
,
4
),
conv
(
64
,
latent_channels
),
)
)
def
Decoder
():
def
Decoder
(
latent_channels
=
4
):
return
nn
.
Sequential
(
return
nn
.
Sequential
(
Clamp
(),
conv
(
4
,
64
),
nn
.
ReLU
(),
Clamp
(),
conv
(
latent_channels
,
64
),
nn
.
ReLU
(),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
nn
.
Upsample
(
scale_factor
=
2
),
conv
(
64
,
64
,
bias
=
False
),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
nn
.
Upsample
(
scale_factor
=
2
),
conv
(
64
,
64
,
bias
=
False
),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
nn
.
Upsample
(
scale_factor
=
2
),
conv
(
64
,
64
,
bias
=
False
),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
nn
.
Upsample
(
scale_factor
=
2
),
conv
(
64
,
64
,
bias
=
False
),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
nn
.
Upsample
(
scale_factor
=
2
),
conv
(
64
,
64
,
bias
=
False
),
Block
(
64
,
64
),
Block
(
64
,
64
),
Block
(
64
,
64
),
nn
.
Upsample
(
scale_factor
=
2
),
conv
(
64
,
64
,
bias
=
False
),
...
@@ -47,11 +48,11 @@ class TAESD(nn.Module):
...
@@ -47,11 +48,11 @@ class TAESD(nn.Module):
latent_magnitude
=
3
latent_magnitude
=
3
latent_shift
=
0.5
latent_shift
=
0.5
def
__init__
(
self
,
encoder_path
=
None
,
decoder_path
=
None
):
def
__init__
(
self
,
encoder_path
=
None
,
decoder_path
=
None
,
latent_channels
=
4
):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super
().
__init__
()
super
().
__init__
()
self
.
taesd_encoder
=
Encoder
()
self
.
taesd_encoder
=
Encoder
(
latent_channels
=
latent_channels
)
self
.
taesd_decoder
=
Decoder
()
self
.
taesd_decoder
=
Decoder
(
latent_channels
=
latent_channels
)
self
.
vae_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
1.0
))
self
.
vae_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
1.0
))
if
encoder_path
is
not
None
:
if
encoder_path
is
not
None
:
self
.
taesd_encoder
.
load_state_dict
(
comfy
.
utils
.
load_torch_file
(
encoder_path
,
safe_load
=
True
))
self
.
taesd_encoder
.
load_state_dict
(
comfy
.
utils
.
load_torch_file
(
encoder_path
,
safe_load
=
True
))
...
...
latent_preview.py
View file @
df7db0e0
...
@@ -64,7 +64,7 @@ def get_previewer(device, latent_format):
...
@@ -64,7 +64,7 @@ def get_previewer(device, latent_format):
if
method
==
LatentPreviewMethod
.
TAESD
:
if
method
==
LatentPreviewMethod
.
TAESD
:
if
taesd_decoder_path
:
if
taesd_decoder_path
:
taesd
=
TAESD
(
None
,
taesd_decoder_path
).
to
(
device
)
taesd
=
TAESD
(
None
,
taesd_decoder_path
,
latent_channels
=
latent_format
.
latent_channels
).
to
(
device
)
previewer
=
TAESDPreviewerImpl
(
taesd
)
previewer
=
TAESDPreviewerImpl
(
taesd
)
else
:
else
:
logging
.
warning
(
"Warning: TAESD previews enabled, but could not find models/vae_approx/{}"
.
format
(
latent_format
.
taesd_decoder_name
))
logging
.
warning
(
"Warning: TAESD previews enabled, but could not find models/vae_approx/{}"
.
format
(
latent_format
.
taesd_decoder_name
))
...
...
nodes.py
View file @
df7db0e0
...
@@ -634,6 +634,8 @@ class VAELoader:
...
@@ -634,6 +634,8 @@ class VAELoader:
sdxl_taesd_dec
=
False
sdxl_taesd_dec
=
False
sd1_taesd_enc
=
False
sd1_taesd_enc
=
False
sd1_taesd_dec
=
False
sd1_taesd_dec
=
False
sd3_taesd_enc
=
False
sd3_taesd_dec
=
False
for
v
in
approx_vaes
:
for
v
in
approx_vaes
:
if
v
.
startswith
(
"taesd_decoder."
):
if
v
.
startswith
(
"taesd_decoder."
):
...
@@ -644,10 +646,16 @@ class VAELoader:
...
@@ -644,10 +646,16 @@ class VAELoader:
sdxl_taesd_dec
=
True
sdxl_taesd_dec
=
True
elif
v
.
startswith
(
"taesdxl_encoder."
):
elif
v
.
startswith
(
"taesdxl_encoder."
):
sdxl_taesd_enc
=
True
sdxl_taesd_enc
=
True
elif
v
.
startswith
(
"taesd3_decoder."
):
sd3_taesd_dec
=
True
elif
v
.
startswith
(
"taesd3_encoder."
):
sd3_taesd_enc
=
True
if
sd1_taesd_dec
and
sd1_taesd_enc
:
if
sd1_taesd_dec
and
sd1_taesd_enc
:
vaes
.
append
(
"taesd"
)
vaes
.
append
(
"taesd"
)
if
sdxl_taesd_dec
and
sdxl_taesd_enc
:
if
sdxl_taesd_dec
and
sdxl_taesd_enc
:
vaes
.
append
(
"taesdxl"
)
vaes
.
append
(
"taesdxl"
)
if
sd3_taesd_dec
and
sd3_taesd_enc
:
vaes
.
append
(
"taesd3"
)
return
vaes
return
vaes
@
staticmethod
@
staticmethod
...
@@ -670,6 +678,8 @@ class VAELoader:
...
@@ -670,6 +678,8 @@ class VAELoader:
sd
[
"vae_scale"
]
=
torch
.
tensor
(
0.18215
)
sd
[
"vae_scale"
]
=
torch
.
tensor
(
0.18215
)
elif
name
==
"taesdxl"
:
elif
name
==
"taesdxl"
:
sd
[
"vae_scale"
]
=
torch
.
tensor
(
0.13025
)
sd
[
"vae_scale"
]
=
torch
.
tensor
(
0.13025
)
elif
name
==
"taesd3"
:
sd
[
"vae_scale"
]
=
torch
.
tensor
(
1.5305
)
return
sd
return
sd
@
classmethod
@
classmethod
...
@@ -682,12 +692,15 @@ class VAELoader:
...
@@ -682,12 +692,15 @@ class VAELoader:
#TODO: scale factor?
#TODO: scale factor?
def
load_vae
(
self
,
vae_name
):
def
load_vae
(
self
,
vae_name
):
if
vae_name
in
[
"taesd"
,
"taesdxl"
]:
if
vae_name
in
[
"taesd"
,
"taesdxl"
,
"taesd3"
]:
sd
=
self
.
load_taesd
(
vae_name
)
sd
=
self
.
load_taesd
(
vae_name
)
else
:
else
:
vae_path
=
folder_paths
.
get_full_path
(
"vae"
,
vae_name
)
vae_path
=
folder_paths
.
get_full_path
(
"vae"
,
vae_name
)
sd
=
comfy
.
utils
.
load_torch_file
(
vae_path
)
sd
=
comfy
.
utils
.
load_torch_file
(
vae_path
)
vae
=
comfy
.
sd
.
VAE
(
sd
=
sd
)
latent_channels
=
16
if
vae_name
==
'taesd3'
else
4
vae
=
comfy
.
sd
.
VAE
(
sd
=
sd
,
latent_channels
=
latent_channels
)
return
(
vae
,)
return
(
vae
,)
class
ControlNetLoader
:
class
ControlNetLoader
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment