Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
04e8798c
Commit
04e8798c
authored
Jun 16, 2024
by
comfyanonymous
Browse files
Improvements to the TAESD3 implementation.
parent
df7db0e0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
9 deletions
+10
-9
comfy/sd.py
comfy/sd.py
+3
-3
comfy/taesd/taesd.py
comfy/taesd/taesd.py
+3
-2
nodes.py
nodes.py
+4
-4
No files found.
comfy/sd.py
View file @
04e8798c
...
@@ -166,7 +166,7 @@ class CLIP:
...
@@ -166,7 +166,7 @@ class CLIP:
return
self
.
patcher
.
get_key_patches
()
return
self
.
patcher
.
get_key_patches
()
class
VAE
:
class
VAE
:
def
__init__
(
self
,
sd
=
None
,
device
=
None
,
config
=
None
,
dtype
=
None
,
latent_channels
=
4
):
def
__init__
(
self
,
sd
=
None
,
device
=
None
,
config
=
None
,
dtype
=
None
):
if
'decoder.up_blocks.0.resnets.0.norm1.weight'
in
sd
.
keys
():
#diffusers format
if
'decoder.up_blocks.0.resnets.0.norm1.weight'
in
sd
.
keys
():
#diffusers format
sd
=
diffusers_convert
.
convert_vae_state_dict
(
sd
)
sd
=
diffusers_convert
.
convert_vae_state_dict
(
sd
)
...
@@ -174,7 +174,7 @@ class VAE:
...
@@ -174,7 +174,7 @@ class VAE:
self
.
memory_used_decode
=
lambda
shape
,
dtype
:
(
2178
*
shape
[
2
]
*
shape
[
3
]
*
64
)
*
model_management
.
dtype_size
(
dtype
)
self
.
memory_used_decode
=
lambda
shape
,
dtype
:
(
2178
*
shape
[
2
]
*
shape
[
3
]
*
64
)
*
model_management
.
dtype_size
(
dtype
)
self
.
downscale_ratio
=
8
self
.
downscale_ratio
=
8
self
.
upscale_ratio
=
8
self
.
upscale_ratio
=
8
self
.
latent_channels
=
latent_channels
self
.
latent_channels
=
4
self
.
output_channels
=
3
self
.
output_channels
=
3
self
.
process_input
=
lambda
image
:
image
*
2.0
-
1.0
self
.
process_input
=
lambda
image
:
image
*
2.0
-
1.0
self
.
process_output
=
lambda
image
:
torch
.
clamp
((
image
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
self
.
process_output
=
lambda
image
:
torch
.
clamp
((
image
+
1.0
)
/
2.0
,
min
=
0.0
,
max
=
1.0
)
...
@@ -189,7 +189,7 @@ class VAE:
...
@@ -189,7 +189,7 @@ class VAE:
encoder_config
=
{
'target'
:
"comfy.ldm.modules.diffusionmodules.model.Encoder"
,
'params'
:
encoder_config
},
encoder_config
=
{
'target'
:
"comfy.ldm.modules.diffusionmodules.model.Encoder"
,
'params'
:
encoder_config
},
decoder_config
=
{
'target'
:
"comfy.ldm.modules.temporal_ae.VideoDecoder"
,
'params'
:
decoder_config
})
decoder_config
=
{
'target'
:
"comfy.ldm.modules.temporal_ae.VideoDecoder"
,
'params'
:
decoder_config
})
elif
"taesd_decoder.1.weight"
in
sd
:
elif
"taesd_decoder.1.weight"
in
sd
:
self
.
first_stage_model
=
comfy
.
taesd
.
taesd
.
TAESD
(
latent_channels
=
s
elf
.
latent_channels
)
self
.
first_stage_model
=
comfy
.
taesd
.
taesd
.
TAESD
(
latent_channels
=
s
d
[
"taesd_decoder.1.weight"
].
shape
[
1
]
)
elif
"vquantizer.codebook.weight"
in
sd
:
#VQGan: stage a of stable cascade
elif
"vquantizer.codebook.weight"
in
sd
:
#VQGan: stage a of stable cascade
self
.
first_stage_model
=
StageA
()
self
.
first_stage_model
=
StageA
()
self
.
downscale_ratio
=
4
self
.
downscale_ratio
=
4
...
...
comfy/taesd/taesd.py
View file @
04e8798c
...
@@ -54,6 +54,7 @@ class TAESD(nn.Module):
...
@@ -54,6 +54,7 @@ class TAESD(nn.Module):
self
.
taesd_encoder
=
Encoder
(
latent_channels
=
latent_channels
)
self
.
taesd_encoder
=
Encoder
(
latent_channels
=
latent_channels
)
self
.
taesd_decoder
=
Decoder
(
latent_channels
=
latent_channels
)
self
.
taesd_decoder
=
Decoder
(
latent_channels
=
latent_channels
)
self
.
vae_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
1.0
))
self
.
vae_scale
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
1.0
))
self
.
vae_shift
=
torch
.
nn
.
Parameter
(
torch
.
tensor
(
0.0
))
if
encoder_path
is
not
None
:
if
encoder_path
is
not
None
:
self
.
taesd_encoder
.
load_state_dict
(
comfy
.
utils
.
load_torch_file
(
encoder_path
,
safe_load
=
True
))
self
.
taesd_encoder
.
load_state_dict
(
comfy
.
utils
.
load_torch_file
(
encoder_path
,
safe_load
=
True
))
if
decoder_path
is
not
None
:
if
decoder_path
is
not
None
:
...
@@ -70,9 +71,9 @@ class TAESD(nn.Module):
...
@@ -70,9 +71,9 @@ class TAESD(nn.Module):
return
x
.
sub
(
TAESD
.
latent_shift
).
mul
(
2
*
TAESD
.
latent_magnitude
)
return
x
.
sub
(
TAESD
.
latent_shift
).
mul
(
2
*
TAESD
.
latent_magnitude
)
def
decode
(
self
,
x
):
def
decode
(
self
,
x
):
x_sample
=
self
.
taesd_decoder
(
x
*
self
.
vae_scale
)
x_sample
=
self
.
taesd_decoder
(
(
x
-
self
.
vae_shift
)
*
self
.
vae_scale
)
x_sample
=
x_sample
.
sub
(
0.5
).
mul
(
2
)
x_sample
=
x_sample
.
sub
(
0.5
).
mul
(
2
)
return
x_sample
return
x_sample
def
encode
(
self
,
x
):
def
encode
(
self
,
x
):
return
self
.
taesd_encoder
(
x
*
0.5
+
0.5
)
/
self
.
vae_scale
return
(
self
.
taesd_encoder
(
x
*
0.5
+
0.5
)
/
self
.
vae_scale
)
+
self
.
vae_shift
nodes.py
View file @
04e8798c
...
@@ -676,10 +676,13 @@ class VAELoader:
...
@@ -676,10 +676,13 @@ class VAELoader:
if
name
==
"taesd"
:
if
name
==
"taesd"
:
sd
[
"vae_scale"
]
=
torch
.
tensor
(
0.18215
)
sd
[
"vae_scale"
]
=
torch
.
tensor
(
0.18215
)
sd
[
"vae_shift"
]
=
torch
.
tensor
(
0.0
)
elif
name
==
"taesdxl"
:
elif
name
==
"taesdxl"
:
sd
[
"vae_scale"
]
=
torch
.
tensor
(
0.13025
)
sd
[
"vae_scale"
]
=
torch
.
tensor
(
0.13025
)
sd
[
"vae_shift"
]
=
torch
.
tensor
(
0.0
)
elif
name
==
"taesd3"
:
elif
name
==
"taesd3"
:
sd
[
"vae_scale"
]
=
torch
.
tensor
(
1.5305
)
sd
[
"vae_scale"
]
=
torch
.
tensor
(
1.5305
)
sd
[
"vae_shift"
]
=
torch
.
tensor
(
0.0609
)
return
sd
return
sd
@
classmethod
@
classmethod
...
@@ -697,10 +700,7 @@ class VAELoader:
...
@@ -697,10 +700,7 @@ class VAELoader:
else
:
else
:
vae_path
=
folder_paths
.
get_full_path
(
"vae"
,
vae_name
)
vae_path
=
folder_paths
.
get_full_path
(
"vae"
,
vae_name
)
sd
=
comfy
.
utils
.
load_torch_file
(
vae_path
)
sd
=
comfy
.
utils
.
load_torch_file
(
vae_path
)
vae
=
comfy
.
sd
.
VAE
(
sd
=
sd
)
latent_channels
=
16
if
vae_name
==
'taesd3'
else
4
vae
=
comfy
.
sd
.
VAE
(
sd
=
sd
,
latent_channels
=
latent_channels
)
return
(
vae
,)
return
(
vae
,)
class
ControlNetLoader
:
class
ControlNetLoader
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment