chenpangpang/ComfyUI · Commit 5e06baf1
Authored Feb 16, 2024 by comfyanonymous
Parent: c2c88526

Stable Cascade Stage A.
Showing 2 changed files with 272 additions and 6 deletions:
  comfy/ldm/cascade/stage_a.py  +254 −0
  comfy/sd.py                   +18 −6
comfy/ldm/cascade/stage_a.py (new file, mode 100644)
"""
This file is part of ComfyUI.
Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import
torch
from
torch
import
nn
from
torch.autograd
import
Function
class
vector_quantize
(
Function
):
@
staticmethod
def
forward
(
ctx
,
x
,
codebook
):
with
torch
.
no_grad
():
codebook_sqr
=
torch
.
sum
(
codebook
**
2
,
dim
=
1
)
x_sqr
=
torch
.
sum
(
x
**
2
,
dim
=
1
,
keepdim
=
True
)
dist
=
torch
.
addmm
(
codebook_sqr
+
x_sqr
,
x
,
codebook
.
t
(),
alpha
=-
2.0
,
beta
=
1.0
)
_
,
indices
=
dist
.
min
(
dim
=
1
)
ctx
.
save_for_backward
(
indices
,
codebook
)
ctx
.
mark_non_differentiable
(
indices
)
nn
=
torch
.
index_select
(
codebook
,
0
,
indices
)
return
nn
,
indices
@
staticmethod
def
backward
(
ctx
,
grad_output
,
grad_indices
):
grad_inputs
,
grad_codebook
=
None
,
None
if
ctx
.
needs_input_grad
[
0
]:
grad_inputs
=
grad_output
.
clone
()
if
ctx
.
needs_input_grad
[
1
]:
# Gradient wrt. the codebook
indices
,
codebook
=
ctx
.
saved_tensors
grad_codebook
=
torch
.
zeros_like
(
codebook
)
grad_codebook
.
index_add_
(
0
,
indices
,
grad_output
)
return
(
grad_inputs
,
grad_codebook
)


class VectorQuantize(nn.Module):
    def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
        """
        Takes an input of variable size (as long as the last dimension matches the embedding size).
        Returns one tensor containing the nearest neighbour embeddings to each of the inputs,
        with the same size as the input, vq and commitment components for the loss as a tuple
        in the second output and the indices of the quantized vectors in the third:
        quantized, (vq_loss, commit_loss), indices
        """
        super(VectorQuantize, self).__init__()

        self.codebook = nn.Embedding(k, embedding_size)
        self.codebook.weight.data.uniform_(-1. / k, 1. / k)
        self.vq = vector_quantize.apply

        self.ema_decay = ema_decay
        self.ema_loss = ema_loss
        if ema_loss:
            self.register_buffer('ema_element_count', torch.ones(k))
            self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))

    def _laplace_smoothing(self, x, epsilon):
        n = torch.sum(x)
        return ((x + epsilon) / (n + x.size(0) * epsilon) * n)

    def _updateEMA(self, z_e_x, indices):
        mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
        elem_count = mask.sum(dim=0)
        weight_sum = torch.mm(mask.t(), z_e_x)

        self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1 - self.ema_decay) * elem_count)
        self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
        self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1 - self.ema_decay) * weight_sum)

        self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)

    def idx2vq(self, idx, dim=-1):
        q_idx = self.codebook(idx)
        if dim != -1:
            q_idx = q_idx.movedim(-1, dim)
        return q_idx

    def forward(self, x, get_losses=True, dim=-1):
        if dim != -1:
            x = x.movedim(dim, -1)
        z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
        z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
        vq_loss, commit_loss = None, None
        if self.ema_loss and self.training:
            self._updateEMA(z_e_x.detach(), indices.detach())
        # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
        z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
        if get_losses:
            vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
            commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()

        z_q_x = z_q_x.view(x.shape)
        if dim != -1:
            z_q_x = z_q_x.movedim(-1, dim)
        return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])


class ResBlock(nn.Module):
    def __init__(self, c, c_hidden):
        super().__init__()
        # depthwise/attention
        self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.depthwise = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(c, c, kernel_size=3, groups=c)
        )

        # channelwise
        self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            nn.Linear(c, c_hidden),
            nn.GELU(),
            nn.Linear(c_hidden, c),
        )

        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)

        # Init weights
        def _basic_init(module):
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def _norm(self, x, norm):
        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        mods = self.gammas

        x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
        x = x + self.depthwise(x_temp) * mods[2]

        x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
        x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]

        return x


class StageA(nn.Module):
    def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
                 scale_factor=0.43):  # 0.3764
        super().__init__()
        self.c_latent = c_latent
        self.scale_factor = scale_factor
        c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]

        # Encoder blocks
        self.in_block = nn.Sequential(
            nn.PixelUnshuffle(2),
            nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
        )
        down_blocks = []
        for i in range(levels):
            if i > 0:
                down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            block = ResBlock(c_levels[i], c_levels[i] * 4)
            down_blocks.append(block)
        down_blocks.append(nn.Sequential(
            nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
        ))
        self.down_blocks = nn.Sequential(*down_blocks)
        self.down_blocks[0]

        self.codebook_size = codebook_size
        self.vquantizer = VectorQuantize(c_latent, k=codebook_size)

        # Decoder blocks
        up_blocks = [nn.Sequential(
            nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
        )]
        for i in range(levels):
            for j in range(bottleneck_blocks if i == 0 else 1):
                block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
                up_blocks.append(block)
            if i < levels - 1:
                up_blocks.append(nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i],
                                                    kernel_size=4, stride=2, padding=1))
        self.up_blocks = nn.Sequential(*up_blocks)
        self.out_block = nn.Sequential(
            nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
            nn.PixelShuffle(2),
        )

    def encode(self, x, quantize=False):
        x = self.in_block(x)
        x = self.down_blocks(x)
        if quantize:
            qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
            return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
        else:
            return x / self.scale_factor

    def decode(self, x):
        x = x * self.scale_factor
        x = self.up_blocks(x)
        x = self.out_block(x)
        return x

    def forward(self, x, quantize=False):
        qe, x, _, vq_loss = self.encode(x, quantize)
        x = self.decode(qe)
        return x, vq_loss


class Discriminator(nn.Module):
    def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
        super().__init__()
        d = max(depth - 3, 3)
        layers = [
            nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
        ]
        for i in range(depth - 1):
            c_in = c_hidden // (2 ** max((d - i), 0))
            c_out = c_hidden // (2 ** max((d - 1 - i), 0))
            layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
            layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
        self.encoder = nn.Sequential(*layers)
        self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
        self.logits = nn.Sigmoid()

    def forward(self, x, cond=None):
        x = self.encoder(x)
        if cond is not None:
            cond = cond.view(cond.size(0), cond.size(1), 1, 1,).expand(-1, -1, x.size(-2), x.size(-1))
            x = torch.cat([x, cond], dim=1)
        x = self.shuffle(x)
        x = self.logits(x)
        return x
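
Usage sketch (illustrative, not part of the commit): with quantize=True, StageA.encode returns the scaled quantized latents, the scaled pre-quantization latents, the codebook indices, and a combined VQ/commitment loss, and decode maps latents back to pixels. Shapes below assume the default configuration:

import torch
from comfy.ldm.cascade.stage_a import StageA

model = StageA().eval()  # defaults: levels=2, c_latent=4, codebook_size=8192
x = torch.rand(1, 3, 256, 256)  # pixel input

with torch.no_grad():
    qe, latent, indices, vq_loss = model.encode(x, quantize=True)
    recon = model.decode(qe)

print(qe.shape, latent.shape, indices.shape, recon.shape)
# torch.Size([1, 4, 64, 64]) torch.Size([1, 4, 64, 64]) torch.Size([1, 64, 64]) torch.Size([1, 3, 256, 256])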

comfy/sd.py
@@ -2,6 +2,8 @@ import torch
 from comfy import model_management
 from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
+from .ldm.cascade.stage_a import StageA
+
 import yaml
 import comfy.utils
@@ -156,6 +158,8 @@ class VAE:
         self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
         self.downscale_ratio = 8
         self.latent_channels = 4
+        self.process_input = lambda image: image * 2.0 - 1.0
+        self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)

         if config is None:
             if "decoder.mid.block_1.mix_factor" in sd:
@@ -168,6 +172,14 @@ class VAE:
                                                             decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
             elif "taesd_decoder.1.weight" in sd:
                 self.first_stage_model = comfy.taesd.taesd.TAESD()
+            elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
+                self.first_stage_model = StageA()
+                self.downscale_ratio = 4
+                #TODO
+                #self.memory_used_encode
+                #self.memory_used_decode
+                self.process_input = lambda image: image
+                self.process_output = lambda image: image
             else:
                 #default SD1.x/SD2.x VAE parameters
                 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
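
The downscale_ratio of 4 set in this branch follows from the Stage A architecture above: PixelUnshuffle(2) in in_block halves the spatial resolution once, and the single stride-2 Conv2d in down_blocks halves it again. A quick shape check (a sketch, assuming the new module is importable):

import torch
from comfy.ldm.cascade.stage_a import StageA

latent = StageA().eval().encode(torch.rand(1, 3, 256, 256))  # unquantized path
print(latent.shape)  # torch.Size([1, 4, 64, 64]): 256 / 4 spatially, c_latent=4 channels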
@@ -206,12 +218,12 @@ class VAE:
         steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = comfy.utils.ProgressBar(steps)

-        decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
-        output = torch.clamp((
-            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
-             comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar))
-            / 3.0) / 2.0, min=0.0, max=1.0)
+        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+        output = self.process_output(
+            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
+             comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar))
+            / 3.0)
         return output

     def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@@ -220,7 +232,7 @@ class VAE:
         steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = comfy.utils.ProgressBar(steps)

-        encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
+        encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
         samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
         samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
         samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@@ -238,7 +250,7 @@ class VAE:
             pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
-                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
+                pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)
@@ -261,7 +273,7 @@ class VAE:
             batch_number = max(1, batch_number)
             samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
             for x in range(0, pixel_samples.shape[0], batch_number):
-                pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
+                pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device)
                 samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
         except model_management.OOM_EXCEPTION as e:
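
For the default SD VAE path these refactors are behavior-preserving: applying the new default process_output to the raw decoder output reproduces the inline clamp that was removed. A minimal check with a stand-in tensor (illustrative, independent of ComfyUI):

import torch

decoded = torch.randn(1, 3, 64, 64)  # stand-in for first_stage_model.decode(...) output

old = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)  # inline form removed by this commit

process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)  # default set in VAE.__init__
new = process_output(decoded)

assert torch.equal(old, new)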