ComfyUI commit d44a2de4, authored Oct 17, 2023 by comfyanonymous
Parent: f8caa24b

Make VAE code closer to sgm.

Showing 6 changed files with 235 additions and 222 deletions (+235 -222).
comfy/diffusers_load.py                       +2   -1
comfy/ldm/models/autoencoder.py             +187 -183
comfy/ldm/modules/diffusionmodules/model.py  +17  -14
comfy/sd.py                                  +19  -20
comfy/utils.py                                +8   -3
nodes.py                                      +2   -1
comfy/diffusers_load.py (+2 -1)

@@ -31,6 +31,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
     vae = None
     if output_vae:
-        vae = comfy.sd.VAE(ckpt_path=vae_path)
+        sd = comfy.utils.load_torch_file(vae_path)
+        vae = comfy.sd.VAE(sd=sd)
     return (unet, clip, vae)
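The change above inverts the loading flow: the caller now reads the checkpoint into a state dict itself and hands the weights to comfy.sd.VAE. A minimal sketch of the new convention (the wrapper function is hypothetical, not part of the commit):

import comfy.sd
import comfy.utils

def load_standalone_vae(vae_path):
    # load_torch_file handles both safetensors and pickle checkpoints
    sd = comfy.utils.load_torch_file(vae_path)
    # VAE now consumes a state dict directly; it no longer opens files itself
    return comfy.sd.VAE(sd=sd)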
comfy/ldm/models/autoencoder.py (+187 -183)

@@ -2,67 +2,66 @@ import torch
 # import pytorch_lightning as pl
 import torch.nn.functional as F
 from contextlib import contextmanager
+from typing import Any, Dict, List, Optional, Tuple, Union

-from comfy.ldm.modules.diffusionmodules.model import Encoder, Decoder
 from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution

 from comfy.ldm.util import instantiate_from_config
 from comfy.ldm.modules.ema import LitEma

-# class AutoencoderKL(pl.LightningModule):
-class AutoencoderKL(torch.nn.Module):
-    def __init__(self,
-                 ddconfig,
-                 lossconfig,
-                 embed_dim,
-                 ckpt_path=None,
-                 ignore_keys=[],
-                 image_key="image",
-                 colorize_nlabels=None,
-                 monitor=None,
-                 ema_decay=None,
-                 learn_logvar=False):
-        super().__init__()
-        self.learn_logvar = learn_logvar
-        self.image_key = image_key
-        self.encoder = Encoder(**ddconfig)
-        self.decoder = Decoder(**ddconfig)
-        self.loss = instantiate_from_config(lossconfig)
-        assert ddconfig["double_z"]
-        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
-        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
-        self.embed_dim = embed_dim
-        if colorize_nlabels is not None:
-            assert type(colorize_nlabels)==int
-            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
-        if monitor is not None:
-            self.monitor = monitor
-
-        self.use_ema = ema_decay is not None
-        if self.use_ema:
-            self.ema_decay = ema_decay
-            assert 0. < ema_decay < 1.
-            self.model_ema = LitEma(self, decay=ema_decay)
-            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
-
-        if ckpt_path is not None:
-            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
-
-    def init_from_ckpt(self, path, ignore_keys=list()):
-        if path.lower().endswith(".safetensors"):
-            import safetensors.torch
-            sd = safetensors.torch.load_file(path, device="cpu")
-        else:
-            sd = torch.load(path, map_location="cpu")["state_dict"]
-        keys = list(sd.keys())
-        for k in keys:
-            for ik in ignore_keys:
-                if k.startswith(ik):
-                    print("Deleting key {} from state_dict.".format(k))
-                    del sd[k]
-        self.load_state_dict(sd, strict=False)
-        print(f"Restored from {path}")
+class DiagonalGaussianRegularizer(torch.nn.Module):
+    def __init__(self, sample: bool = True):
+        super().__init__()
+        self.sample = sample
+
+    def get_trainable_parameters(self) -> Any:
+        yield from ()
+
+    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+        log = dict()
+        posterior = DiagonalGaussianDistribution(z)
+        if self.sample:
+            z = posterior.sample()
+        else:
+            z = posterior.mode()
+        kl_loss = posterior.kl()
+        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+        log["kl_loss"] = kl_loss
+        return z, log
+
+
+class AbstractAutoencoder(torch.nn.Module):
+    """
+    This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
+    unCLIP models, etc. Hence, it is fairly general, and specific features
+    (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
+    """
+
+    def __init__(
+        self,
+        ema_decay: Union[None, float] = None,
+        monitor: Union[None, str] = None,
+        input_key: str = "jpg",
+        **kwargs,
+    ):
+        super().__init__()
+
+        self.input_key = input_key
+        self.use_ema = ema_decay is not None
+        if monitor is not None:
+            self.monitor = monitor
+
+        if self.use_ema:
+            self.model_ema = LitEma(self, decay=ema_decay)
+            logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+    def get_input(self, batch) -> Any:
+        raise NotImplementedError()
+
+    def on_train_batch_end(self, *args, **kwargs):
+        # for EMA computation
+        if self.use_ema:
+            self.model_ema(self)

     @contextmanager
     def ema_scope(self, context=None):
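The DiagonalGaussianRegularizer introduced above factors latent sampling and the KL term out of the autoencoder itself, matching sgm's regularizer interface. A small sketch of what its forward pass does, with toy shapes (not code from the commit):

import torch
from comfy.ldm.models.autoencoder import DiagonalGaussianRegularizer

# encoder "moments" carry mean and logvar stacked channel-wise (2 * z_channels)
moments = torch.randn(1, 8, 32, 32)
reg = DiagonalGaussianRegularizer(sample=True)  # sample=False would take the mode
z, log = reg(moments)
print(z.shape)         # torch.Size([1, 4, 32, 32]), a sampled latent
print(log["kl_loss"])  # scalar KL divergence averaged over the batch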
@@ -70,154 +69,159 @@ class AutoencoderKL(torch.nn.Module):
             self.model_ema.store(self.parameters())
             self.model_ema.copy_to(self)
             if context is not None:
-                print(f"{context}: Switched to EMA weights")
+                logpy.info(f"{context}: Switched to EMA weights")
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.parameters())
                 if context is not None:
-                    print(f"{context}: Restored training weights")
+                    logpy.info(f"{context}: Restored training weights")

-    def on_train_batch_end(self, *args, **kwargs):
-        if self.use_ema:
-            self.model_ema(self)
+    def encode(self, *args, **kwargs) -> torch.Tensor:
+        raise NotImplementedError("encode()-method of abstract base class called")

-    def encode(self, x):
-        h = self.encoder(x)
-        moments = self.quant_conv(h)
-        posterior = DiagonalGaussianDistribution(moments)
-        return posterior
+    def decode(self, *args, **kwargs) -> torch.Tensor:
+        raise NotImplementedError("decode()-method of abstract base class called")
+
+    def instantiate_optimizer_from_config(self, params, lr, cfg):
+        logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
+        return get_obj_from_str(cfg["target"])(
+            params, lr=lr, **cfg.get("params", dict())
+        )

-    def decode(self, z):
-        z = self.post_quant_conv(z)
-        dec = self.decoder(z)
-        return dec
+    def configure_optimizers(self) -> Any:
+        raise NotImplementedError()

-    def forward(self, input, sample_posterior=True):
-        posterior = self.encode(input)
-        if sample_posterior:
-            z = posterior.sample()
-        else:
-            z = posterior.mode()
-        dec = self.decode(z)
-        return dec, posterior

-    def get_input(self, batch, k):
-        x = batch[k]
-        if len(x.shape) == 3:
-            x = x[..., None]
-        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
-        return x
+class AutoencodingEngine(AbstractAutoencoder):
+    """
+    Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
+    (we also restore them explicitly as special cases for legacy reasons).
+    Regularizations such as KL or VQ are moved to the regularizer class.
+    """
+
+    def __init__(
+        self,
+        *args,
+        encoder_config: Dict,
+        decoder_config: Dict,
+        regularizer_config: Dict,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
+        self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
+        self.regularization: AbstractRegularizer = instantiate_from_config(
+            regularizer_config
+        )

-    def training_step(self, batch, batch_idx, optimizer_idx):
-        inputs = self.get_input(batch, self.image_key)
-        reconstructions, posterior = self(inputs)
-
-        if optimizer_idx == 0:
-            # train encoder+decoder+logvar
-            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
-                                            last_layer=self.get_last_layer(), split="train")
-            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
-            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
-            return aeloss
-
-        if optimizer_idx == 1:
-            # train the discriminator
-            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
-                                                last_layer=self.get_last_layer(), split="train")
-            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
-            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
-            return discloss
-
-    def validation_step(self, batch, batch_idx):
-        log_dict = self._validation_step(batch, batch_idx)
-        with self.ema_scope():
-            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
-        return log_dict
-
-    def _validation_step(self, batch, batch_idx, postfix=""):
-        inputs = self.get_input(batch, self.image_key)
-        reconstructions, posterior = self(inputs)
-        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
-                                        last_layer=self.get_last_layer(), split="val"+postfix)
-
-        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
-                                            last_layer=self.get_last_layer(), split="val"+postfix)
-
-        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
-        self.log_dict(log_dict_ae)
-        self.log_dict(log_dict_disc)
-        return self.log_dict
-
-    def configure_optimizers(self):
-        lr = self.learning_rate
-        ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
-        if self.learn_logvar:
-            print(f"{self.__class__.__name__}: Learning logvar")
-            ae_params_list.append(self.loss.logvar)
-        opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9))
-        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
-        return [opt_ae, opt_disc], []
-
     def get_last_layer(self):
-        return self.decoder.conv_out.weight
+        return self.decoder.get_last_layer()

-    @torch.no_grad()
-    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
-        log = dict()
-        x = self.get_input(batch, self.image_key)
-        x = x.to(self.device)
-        if not only_inputs:
-            xrec, posterior = self(x)
-            if x.shape[1] > 3:
-                # colorize with random projection
-                assert xrec.shape[1] > 3
-                x = self.to_rgb(x)
-                xrec = self.to_rgb(xrec)
-            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
-            log["reconstructions"] = xrec
-            if log_ema or self.use_ema:
-                with self.ema_scope():
-                    xrec_ema, posterior_ema = self(x)
-                    if x.shape[1] > 3:
-                        # colorize with random projection
-                        assert xrec_ema.shape[1] > 3
-                        xrec_ema = self.to_rgb(xrec_ema)
-                    log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
-                    log["reconstructions_ema"] = xrec_ema
-        log["inputs"] = x
-        return log
-
-    def to_rgb(self, x):
-        assert self.image_key == "segmentation"
-        if not hasattr(self, "colorize"):
-            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
-        x = F.conv2d(x, weight=self.colorize)
-        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
-        return x
+    def encode(
+        self,
+        x: torch.Tensor,
+        return_reg_log: bool = False,
+        unregularized: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+        z = self.encoder(x)
+        if unregularized:
+            return z, dict()
+        z, reg_log = self.regularization(z)
+        if return_reg_log:
+            return z, reg_log
+        return z
+
+    def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
+        x = self.decoder(z, **kwargs)
+        return x
+
+    def forward(
+        self, x: torch.Tensor, **additional_decode_kwargs
+    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
+        z, reg_log = self.encode(x, return_reg_log=True)
+        dec = self.decode(z, **additional_decode_kwargs)
+        return z, dec, reg_log
+
+
+class AutoencodingEngineLegacy(AutoencodingEngine):
+    def __init__(self, embed_dim: int, **kwargs):
+        self.max_batch_size = kwargs.pop("max_batch_size", None)
+        ddconfig = kwargs.pop("ddconfig")
+        super().__init__(
+            encoder_config={
+                "target": "comfy.ldm.modules.diffusionmodules.model.Encoder",
+                "params": ddconfig,
+            },
+            decoder_config={
+                "target": "comfy.ldm.modules.diffusionmodules.model.Decoder",
+                "params": ddconfig,
+            },
+            **kwargs,
+        )
+        self.quant_conv = torch.nn.Conv2d(
+            (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
+            (1 + ddconfig["double_z"]) * embed_dim,
+            1,
+        )
+        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+        self.embed_dim = embed_dim
+
+    def get_autoencoder_params(self) -> list:
+        params = super().get_autoencoder_params()
+        return params
+
+    def encode(
+        self, x: torch.Tensor, return_reg_log: bool = False
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+        if self.max_batch_size is None:
+            z = self.encoder(x)
+            z = self.quant_conv(z)
+        else:
+            N = x.shape[0]
+            bs = self.max_batch_size
+            n_batches = int(math.ceil(N / bs))
+            z = list()
+            for i_batch in range(n_batches):
+                z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
+                z_batch = self.quant_conv(z_batch)
+                z.append(z_batch)
+            z = torch.cat(z, 0)
+
+        z, reg_log = self.regularization(z)
+        if return_reg_log:
+            return z, reg_log
+        return z
+
+    def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
+        if self.max_batch_size is None:
+            dec = self.post_quant_conv(z)
+            dec = self.decoder(dec, **decoder_kwargs)
+        else:
+            N = z.shape[0]
+            bs = self.max_batch_size
+            n_batches = int(math.ceil(N / bs))
+            dec = list()
+            for i_batch in range(n_batches):
+                dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
+                dec_batch = self.decoder(dec_batch, **decoder_kwargs)
+                dec.append(dec_batch)
+            dec = torch.cat(dec, 0)
+
+        return dec


-class IdentityFirstStage(torch.nn.Module):
-    def __init__(self, *args, vq_interface=False, **kwargs):
-        self.vq_interface = vq_interface
-        super().__init__()
-
-    def encode(self, x, *args, **kwargs):
-        return x
-
-    def decode(self, x, *args, **kwargs):
-        return x
-
-    def quantize(self, x, *args, **kwargs):
-        if self.vq_interface:
-            return x, None, [None, None, None]
-        return x
-
-    def forward(self, x, *args, **kwargs):
-        return x
+class AutoencoderKL(AutoencodingEngineLegacy):
+    def __init__(self, **kwargs):
+        if "lossconfig" in kwargs:
+            kwargs["loss_config"] = kwargs.pop("lossconfig")
+        super().__init__(
+            regularizer_config={
+                "target": (
+                    "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"
+                )
+            },
+            **kwargs,
+        )
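After this rewrite, AutoencoderKL is a thin legacy shim over the sgm-style engine: ddconfig still describes the conv backbone, but the encoder, decoder, and regularizer are wired up from config targets. A sketch of constructing it the way comfy/sd.py does below (random weights, shapes only; not code from the commit):

import torch
from comfy.ldm.models.autoencoder import AutoencoderKL

ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3,
            'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2,
            'attn_resolutions': [], 'dropout': 0.0}
vae = AutoencoderKL(ddconfig=ddconfig, embed_dim=4).eval()

with torch.no_grad():
    z = vae.encode(torch.randn(1, 3, 256, 256))  # regularized (sampled) latent, (1, 4, 32, 32)
    recon = vae.decode(z)                        # (1, 3, 256, 256)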
comfy/ldm/modules/diffusionmodules/model.py (+17 -14)

@@ -541,7 +541,10 @@ class Decoder(nn.Module):
     def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                  attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                  resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
-                 attn_type="vanilla", **ignorekwargs):
+                 conv_out_op=comfy.ops.Conv2d,
+                 resnet_op=ResnetBlock,
+                 attn_op=AttnBlock,
+                 **ignorekwargs):
         super().__init__()
         if use_linear_attn: attn_type = "linear"
         self.ch = ch
@@ -570,12 +573,12 @@ class Decoder(nn.Module):
         # middle
         self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+        self.mid.block_1 = resnet_op(in_channels=block_in,
                                        out_channels=block_in,
                                        temb_channels=self.temb_ch,
                                        dropout=dropout)
-        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.attn_1 = attn_op(block_in)
-        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+        self.mid.block_2 = resnet_op(in_channels=block_in,
                                        out_channels=block_in,
                                        temb_channels=self.temb_ch,
                                        dropout=dropout)
@@ -587,13 +590,13 @@ class Decoder(nn.Module):
             attn = nn.ModuleList()
             block_out = ch*ch_mult[i_level]
             for i_block in range(self.num_res_blocks+1):
-                block.append(ResnetBlock(in_channels=block_in,
+                block.append(resnet_op(in_channels=block_in,
                                          out_channels=block_out,
                                          temb_channels=self.temb_ch,
                                          dropout=dropout))
                 block_in = block_out
                 if curr_res in attn_resolutions:
-                    attn.append(make_attn(block_in, attn_type=attn_type))
+                    attn.append(attn_op(block_in))
             up = nn.Module()
             up.block = block
             up.attn = attn
@@ -604,13 +607,13 @@ class Decoder(nn.Module):
         # end
         self.norm_out = Normalize(block_in)
-        self.conv_out = comfy.ops.Conv2d(block_in,
+        self.conv_out = conv_out_op(block_in,
                                          out_ch,
                                          kernel_size=3,
                                          stride=1,
                                          padding=1)

-    def forward(self, z):
+    def forward(self, z, **kwargs):
         #assert z.shape[1:] == self.z_shape[1:]
         self.last_z_shape = z.shape
@@ -621,16 +624,16 @@ class Decoder(nn.Module):
         h = self.conv_in(z)

         # middle
-        h = self.mid.block_1(h, temb)
-        h = self.mid.attn_1(h)
-        h = self.mid.block_2(h, temb)
+        h = self.mid.block_1(h, temb, **kwargs)
+        h = self.mid.attn_1(h, **kwargs)
+        h = self.mid.block_2(h, temb, **kwargs)

         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
             for i_block in range(self.num_res_blocks+1):
-                h = self.up[i_level].block[i_block](h, temb)
+                h = self.up[i_level].block[i_block](h, temb, **kwargs)
                 if len(self.up[i_level].attn) > 0:
-                    h = self.up[i_level].attn[i_block](h)
+                    h = self.up[i_level].attn[i_block](h, **kwargs)
             if i_level != 0:
                 h = self.up[i_level].upsample(h)
@@ -640,7 +643,7 @@ class Decoder(nn.Module):
         h = self.norm_out(h)
         h = nonlinearity(h)
-        h = self.conv_out(h)
+        h = self.conv_out(h, **kwargs)
         if self.tanh_out:
             h = torch.tanh(h)
         return h
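The Decoder change above replaces the fixed make_attn/attn_type machinery with injectable block constructors, and threads forward() kwargs through every block. A sketch of the seam it opens, spelling out the defaults explicitly so a caller knows what can be substituted (the construction itself is hypothetical usage, not from the commit):

import torch
import comfy.ops
from comfy.ldm.modules.diffusionmodules.model import Decoder, ResnetBlock, AttnBlock

decoder = Decoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
                  attn_resolutions=[], in_channels=3, resolution=256, z_channels=4,
                  conv_out_op=comfy.ops.Conv2d,  # called as conv_out(h, **kwargs)
                  resnet_op=ResnetBlock,         # built as resnet_op(in_channels=..., ...)
                  attn_op=AttnBlock)             # built as attn_op(block_in)
dec = decoder(torch.randn(1, 4, 32, 32))         # extra kwargs to forward() reach every block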
comfy/sd.py (+19 -20)

@@ -4,7 +4,7 @@ import math
 from comfy import model_management
 from .ldm.util import instantiate_from_config
-from .ldm.models.autoencoder import AutoencoderKL
+from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
 import yaml

 import comfy.utils
@@ -140,22 +140,25 @@ class CLIP:
         return self.patcher.get_key_patches()

 class VAE:
-    def __init__(self, ckpt_path=None, device=None, config=None):
+    def __init__(self, sd=None, device=None, config=None):
+        if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
+            sd = diffusers_convert.convert_vae_state_dict(sd)
+
         if config is None:
             #default SD1.x/SD2.x VAE parameters
             ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
-            self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss")
+            self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
         else:
             self.first_stage_model = AutoencoderKL(**(config['params']))
         self.first_stage_model = self.first_stage_model.eval()
-        if ckpt_path is not None:
-            sd = comfy.utils.load_torch_file(ckpt_path)
-            if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
-                sd = diffusers_convert.convert_vae_state_dict(sd)
-            m, u = self.first_stage_model.load_state_dict(sd, strict=False)
-            if len(m) > 0:
-                print("Missing VAE keys", m)
+
+        m, u = self.first_stage_model.load_state_dict(sd, strict=False)
+        if len(m) > 0:
+            print("Missing VAE keys", m)
+
+        if len(u) > 0:
+            print("Leftover VAE keys", u)

         if device is None:
             device = model_management.vae_device()
         self.device = device
@@ -183,7 +186,7 @@ class VAE:
         steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = comfy.utils.ProgressBar(steps)

-        encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).sample().float()
+        encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
         samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
@@ -229,7 +232,7 @@ class VAE:
             samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
             for x in range(0, pixel_samples.shape[0], batch_number):
                 pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
-                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu().float()
+                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@@ -375,10 +378,8 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     model.load_model_weights(state_dict, "model.diffusion_model.")

     if output_vae:
-        w = WeightsLoader()
-        vae = VAE(config=vae_config)
-        w.first_stage_model = vae.first_stage_model
-        load_model_weights(w, state_dict)
+        vae_sd = comfy.utils.state_dict_prefix_replace(state_dict, {"first_stage_model.": ""}, filter_keys=True)
+        vae = VAE(sd=vae_sd, config=vae_config)

     if output_clip:
         w = WeightsLoader()
@@ -427,10 +428,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     model.load_model_weights(sd, "model.diffusion_model.")

     if output_vae:
-        vae = VAE()
-        w = WeightsLoader()
-        w.first_stage_model = vae.first_stage_model
-        load_model_weights(w, sd)
+        vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
+        vae = VAE(sd=vae_sd)

     if output_clip:
         w = WeightsLoader()
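Both load_checkpoint paths above now share one pattern: slice the VAE weights out of the full checkpoint by prefix, then build the VAE from that sub-dict, replacing the old WeightsLoader indirection. A condensed sketch of the flow (the checkpoint path is a placeholder):

import comfy.utils
from comfy.sd import VAE

full_sd = comfy.utils.load_torch_file("checkpoint.safetensors")  # placeholder path
# keep only keys under "first_stage_model." and strip that prefix
vae_sd = comfy.utils.state_dict_prefix_replace(
    full_sd, {"first_stage_model.": ""}, filter_keys=True)
vae = VAE(sd=vae_sd)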
comfy/utils.py (+8 -3)

@@ -47,12 +47,17 @@ def state_dict_key_replace(state_dict, keys_to_replace):
         state_dict[keys_to_replace[x]] = state_dict.pop(x)
     return state_dict

-def state_dict_prefix_replace(state_dict, replace_prefix):
+def state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=False):
+    if filter_keys:
+        out = {}
+    else:
+        out = state_dict
     for rp in replace_prefix:
         replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), state_dict.keys())))
         for x in replace:
-            state_dict[x[1]] = state_dict.pop(x[0])
-    return state_dict
+            w = state_dict.pop(x[0])
+            out[x[1]] = w
+    return out

 def transformers_convert(sd, prefix_from, prefix_to, number):
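The new filter_keys flag changes what the function returns rather than how it renames. A toy illustration of both modes (not from the commit; note the function pops matched keys from its input either way, hence the dict() copies):

sd = {"first_stage_model.encoder.w": 1, "model.diffusion_model.w": 2}

# filter_keys=False (old behavior): rename in place, unrelated keys survive
out = state_dict_prefix_replace(dict(sd), {"first_stage_model.": ""})
# out == {"model.diffusion_model.w": 2, "encoder.w": 1}

# filter_keys=True (new): return only the renamed keys
out = state_dict_prefix_replace(dict(sd), {"first_stage_model.": ""}, filter_keys=True)
# out == {"encoder.w": 1}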
nodes.py (+2 -1)

@@ -584,7 +584,8 @@ class VAELoader:
     #TODO: scale factor?
     def load_vae(self, vae_name):
         vae_path = folder_paths.get_full_path("vae", vae_name)
-        vae = comfy.sd.VAE(ckpt_path=vae_path)
+        sd = comfy.utils.load_torch_file(vae_path)
+        vae = comfy.sd.VAE(sd=sd)
         return (vae,)

 class ControlNetLoader: