chenpangpang / ComfyUI · Commits

Commit 8607c2d4 authored Jun 23, 2023 by comfyanonymous

    Move latent scale factor from VAE to model.

Parent: 30a38619

Showing 7 changed files with 73 additions and 33 deletions (+73 −33):

    comfy/latent_formats.py          +16   −0
    comfy/model_base.py              +18   −9
    comfy/samplers.py                 +4   −1
    comfy/sd.py                      +19  −13
    comfy/supported_models.py         +7   −6
    comfy/supported_models_base.py    +4   −3
    nodes.py                          +5   −1
comfy/latent_formats.py  (new file, 0 → 100644)

+class LatentFormat:
+    def process_in(self, latent):
+        return latent * self.scale_factor
+
+    def process_out(self, latent):
+        return latent / self.scale_factor
+
+class SD15(LatentFormat):
+    def __init__(self, scale_factor=0.18215):
+        self.scale_factor = scale_factor
+
+class SDXL(LatentFormat):
+    def __init__(self):
+        self.scale_factor = 0.13025
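For orientation, a minimal standalone sketch (not part of the diff) of how these classes round-trip a latent; the tensor shape is only illustrative:

import torch
from comfy.latent_formats import SD15

fmt = SD15()                          # scale_factor defaults to 0.18215
raw = torch.randn(1, 4, 64, 64)       # latent as the VAE produces it
scaled = fmt.process_in(raw)          # raw * 0.18215: into diffusion-model space
restored = fmt.process_out(scaled)    # scaled / 0.18215: back to VAE space
assert torch.allclose(raw, restored)  # the round trip is the identity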
comfy/model_base.py

@@ -6,9 +6,11 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
 import numpy as np

 class BaseModel(torch.nn.Module):
-    def __init__(self, unet_config, v_prediction=False):
+    def __init__(self, model_config, v_prediction=False):
         super().__init__()
+        unet_config = model_config.unet_config
+        self.latent_format = model_config.latent_format
         self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
         self.diffusion_model = UNetModel(**unet_config)
         self.v_prediction = v_prediction
@@ -75,9 +77,16 @@ class BaseModel(torch.nn.Module):
         del to_load
         return self

+    def process_latent_in(self, latent):
+        return self.latent_format.process_in(latent)
+
+    def process_latent_out(self, latent):
+        return self.latent_format.process_out(latent)
+
 class SD21UNCLIP(BaseModel):
-    def __init__(self, unet_config, noise_aug_config, v_prediction=True):
-        super().__init__(unet_config, v_prediction)
+    def __init__(self, model_config, noise_aug_config, v_prediction=True):
+        super().__init__(model_config, v_prediction)
         self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

     def encode_adm(self, **kwargs):
@@ -112,13 +121,13 @@ class SD21UNCLIP(BaseModel):
         return adm_out

 class SDInpaint(BaseModel):
-    def __init__(self, unet_config, v_prediction=False):
-        super().__init__(unet_config, v_prediction)
+    def __init__(self, model_config, v_prediction=False):
+        super().__init__(model_config, v_prediction)
         self.concat_keys = ("mask", "masked_image")

 class SDXLRefiner(BaseModel):
-    def __init__(self, unet_config, v_prediction=False):
-        super().__init__(unet_config, v_prediction)
+    def __init__(self, model_config, v_prediction=False):
+        super().__init__(model_config, v_prediction)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):
@@ -144,8 +153,8 @@ class SDXLRefiner(BaseModel):
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

 class SDXL(BaseModel):
-    def __init__(self, unet_config, v_prediction=False):
-        super().__init__(unet_config, v_prediction)
+    def __init__(self, model_config, v_prediction=False):
+        super().__init__(model_config, v_prediction)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):
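The constructors now take a config object rather than a bare unet_config dict. A hedged sketch of the shape that object needs (StubConfig is a hypothetical stand-in for the EmptyClass pattern used in comfy/sd.py; real UNet kwargs are omitted, so the construction itself is shown commented out):

from comfy import latent_formats

class StubConfig:                     # hypothetical stand-in
    pass

cfg = StubConfig()
cfg.unet_config = {}                  # real UNet kwargs omitted in this sketch
cfg.latent_format = latent_formats.SD15()

# model = model_base.BaseModel(cfg)              # would build UNetModel(**cfg.unet_config)
# model.process_latent_in(x)   -> x * 0.18215    # delegates to cfg.latent_format
# model.process_latent_out(x)  -> x / 0.18215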
comfy/samplers.py

@@ -586,6 +586,9 @@ class KSampler:
         positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
         negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")

+        if latent_image is not None:
+            latent_image = self.model.process_latent_in(latent_image)
+
         extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}

         cond_concat = None
@@ -672,4 +675,4 @@ class KSampler:
         else:
             samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)

-        return samples.to(torch.float32)
+        return self.model.process_latent_out(samples.to(torch.float32))
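Net effect: latents now enter and leave KSampler in raw VAE space, with the scale applied only around the sampling loop. A minimal sketch of that contract (sample_latents and run_sampler are hypothetical names, assuming a model exposing the new process_latent_* methods):

import torch

def sample_latents(model, run_sampler, noise, latent_image=None):
    if latent_image is not None:
        latent_image = model.process_latent_in(latent_image)    # raw -> model space
    samples = run_sampler(noise, latent_image)                  # sampling runs in model space
    return model.process_latent_out(samples.to(torch.float32))  # model -> raw space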
comfy/sd.py

@@ -536,7 +536,7 @@ class CLIP:

 class VAE:
-    def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=None):
+    def __init__(self, ckpt_path=None, device=None, config=None):
         if config is None:
             #default SD1.x/SD2.x VAE parameters
             ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -550,7 +550,6 @@ class VAE:
                 sd = diffusers_convert.convert_vae_state_dict(sd)
             self.first_stage_model.load_state_dict(sd, strict=False)

-        self.scale_factor = scale_factor
         if device is None:
             device = model_management.get_torch_device()
         self.device = device
@@ -561,7 +560,7 @@ class VAE:
         steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = utils.ProgressBar(steps)

-        decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
+        decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.device)) + 1.0)
         output = torch.clamp((
             (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
             utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
@@ -575,7 +574,7 @@ class VAE:
         steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = utils.ProgressBar(steps)

-        encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample() * self.scale_factor
+        encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample()
         samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
@@ -593,7 +592,7 @@ class VAE:
             pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.device)
-                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(1. / self.scale_factor * samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
+                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)
@@ -620,7 +619,7 @@ class VAE:
             samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
             for x in range(0, pixel_samples.shape[0], batch_number):
                 pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device)
-                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu() * self.scale_factor
+                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu()
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
@@ -958,6 +957,7 @@ def load_gligen(ckpt_path):
     return model

 def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
+    #TODO: this function is a mess and should be removed eventually
     if config is None:
         with open(config_path, 'r') as stream:
             config = yaml.safe_load(stream)
@@ -992,12 +992,20 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     if state_dict is None:
         state_dict = utils.load_torch_file(ckpt_path)

+    class EmptyClass:
+        pass
+
+    model_config = EmptyClass()
+    model_config.unet_config = unet_config
+    from . import latent_formats
+    model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
+
     if config['model']["target"].endswith("LatentInpaintDiffusion"):
-        model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
+        model = model_base.SDInpaint(model_config, v_prediction=v_prediction)
     elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
-        model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
+        model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], v_prediction=v_prediction)
     else:
-        model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
+        model = model_base.BaseModel(model_config, v_prediction=v_prediction)

     if fp16:
         model = model.half()
@@ -1006,14 +1014,12 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     if output_vae:
         w = WeightsLoader()
-        vae = VAE(scale_factor=scale_factor, config=vae_config)
+        vae = VAE(config=vae_config)
         w.first_stage_model = vae.first_stage_model
         load_model_weights(w, state_dict)

     if output_clip:
         w = WeightsLoader()
-        class EmptyClass:
-            pass
         clip_target = EmptyClass()
         clip_target.params = clip_config.get("params", {})
         if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
@@ -1055,7 +1061,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     model.load_model_weights(sd, "model.diffusion_model.")

     if output_vae:
-        vae = VAE(scale_factor=model_config.vae_scale_factor)
+        vae = VAE()
         w = WeightsLoader()
         w.first_stage_model = vae.first_stage_model
         load_model_weights(w, sd)
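With scale_factor gone from VAE, encode and decode operate on unscaled latents: the 1./scale_factor and * scale_factor terms disappear from both the batched and tiled paths. A hedged usage sketch (the checkpoint path is hypothetical; pixels stands for an image tensor in the layout the VAE nodes supply):

from comfy.sd import VAE

vae = VAE(ckpt_path="models/vae/sd15_vae.safetensors")  # hypothetical path; no scale_factor arg anymore
latent = vae.encode(pixels)   # returns a raw latent: no * 0.18215 applied inside
image = vae.decode(latent)    # expects a raw latent: no 1/0.18215 applied inside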
comfy/supported_models.py

@@ -7,6 +7,7 @@ from . import sd2_clip
 from . import sdxl_clip

 from . import supported_models_base
+from . import latent_formats

 class SD15(supported_models_base.BASE):
     unet_config = {
@@ -21,7 +22,7 @@ class SD15(supported_models_base.BASE):
         "num_head_channels": -1,
     }

-    vae_scale_factor = 0.18215
+    latent_format = latent_formats.SD15

     def process_clip_state_dict(self, state_dict):
         k = list(state_dict.keys())
@@ -48,7 +49,7 @@ class SD20(supported_models_base.BASE):
         "adm_in_channels": None,
     }

-    vae_scale_factor = 0.18215
+    latent_format = latent_formats.SD15

     def v_prediction(self, state_dict):
         if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
@@ -97,10 +98,10 @@ class SDXLRefiner(supported_models_base.BASE):
         "transformer_depth": [0, 4, 4, 0],
     }

-    vae_scale_factor = 0.13025
+    latent_format = latent_formats.SDXL

     def get_model(self, state_dict):
-        return model_base.SDXLRefiner(self.unet_config)
+        return model_base.SDXLRefiner(self)

     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
@@ -124,10 +125,10 @@ class SDXL(supported_models_base.BASE):
         "adm_in_channels": 2816
     }

-    vae_scale_factor = 0.13025
+    latent_format = latent_formats.SDXL

     def get_model(self, state_dict):
-        return model_base.SDXL(self.unet_config)
+        return model_base.SDXL(self)

     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
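The per-model vae_scale_factor constants are replaced by latent_format classes: SD15 and SD20 share latent_formats.SD15 (0.18215), while SDXL and its refiner use latent_formats.SDXL (0.13025). A quick standalone check of those defaults, using values taken from the new file above:

from comfy import latent_formats

assert latent_formats.SD15().scale_factor == 0.18215
assert latent_formats.SDXL().scale_factor == 0.13025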
comfy/supported_models_base.py

@@ -49,16 +49,17 @@ class BASE:
     def __init__(self, unet_config):
         self.unet_config = unet_config
+        self.latent_format = self.latent_format()
         for x in self.unet_extra_config:
             self.unet_config[x] = self.unet_extra_config[x]

     def get_model(self, state_dict):
         if self.inpaint_model():
-            return model_base.SDInpaint(self.unet_config, v_prediction=self.v_prediction(state_dict))
+            return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict))
         elif self.noise_aug_config is not None:
-            return model_base.SD21UNCLIP(self.unet_config, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
+            return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
         else:
-            return model_base.BaseModel(self.unet_config, v_prediction=self.v_prediction(state_dict))
+            return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict))

     def process_clip_state_dict(self, state_dict):
         return state_dict
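Note the idiom in __init__: the class attribute latent_format holds the class itself, and the constructor shadows it with an instance on the object. A standalone sketch of that pattern (Demo is a hypothetical stand-in for BASE):

from comfy import latent_formats

class Demo:                                        # hypothetical stand-in for BASE
    latent_format = latent_formats.SD15            # class attribute: the class object

    def __init__(self):
        self.latent_format = self.latent_format()  # instance attribute: an instance

d = Demo()
print(Demo.latent_format)            # the class: <class 'comfy.latent_formats.SD15'>
print(d.latent_format.scale_factor)  # 0.18215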
nodes.py

@@ -284,6 +284,7 @@ class SaveLatent:
         output = {}
         output["latent_tensor"] = samples["samples"]
+        output["latent_format_version_0"] = torch.tensor([])

         safetensors.torch.save_file(output, file, metadata=metadata)
@@ -305,7 +306,10 @@ class LoadLatent:
     def load(self, latent):
         latent_path = folder_paths.get_annotated_filepath(latent)
         latent = safetensors.torch.load_file(latent_path, device="cpu")
-        samples = {"samples": latent["latent_tensor"].float()}
+        multiplier = 1.0
+        if "latent_format_version_0" not in latent:
+            multiplier = 1.0 / 0.18215
+        samples = {"samples": latent["latent_tensor"].float() * multiplier}
         return (samples, )

     @classmethod
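Backward compatibility worked through: latent files saved before this commit stored values pre-multiplied by 0.18215, so files missing the new latent_format_version_0 marker are divided back by that factor on load. In numbers:

scale = 0.18215
raw = 0.5                                # a raw VAE-space latent value
old_stored = raw * scale                 # how pre-commit files stored it
restored = old_stored * (1.0 / scale)    # LoadLatent's multiplier for old files
assert abs(restored - raw) < 1e-12       # recovers the raw-space value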