ComfyUI · Commit 8607c2d4
Authored Jun 23, 2023 by comfyanonymous
Parent: 30a38619

Move latent scale factor from VAE to model.

Showing 7 changed files with 73 additions and 33 deletions.
comfy/latent_formats.py          +16  -0
comfy/model_base.py              +18  -9
comfy/samplers.py                 +4  -1
comfy/sd.py                      +19 -13
comfy/supported_models.py         +7  -6
comfy/supported_models_base.py    +4  -3
nodes.py                          +5  -1
comfy/latent_formats.py (new file, 0 → 100644) @ 8607c2d4

+class LatentFormat:
+    def process_in(self, latent):
+        return latent * self.scale_factor
+
+    def process_out(self, latent):
+        return latent / self.scale_factor
+
+class SD15(LatentFormat):
+    def __init__(self, scale_factor=0.18215):
+        self.scale_factor = scale_factor
+
+class SDXL(LatentFormat):
+    def __init__(self):
+        self.scale_factor = 0.13025
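
process_in and process_out are exact inverses, so a latent scaled into model space and back comes out unchanged. A minimal round-trip sketch of the new classes (the tensor here is an arbitrary stand-in for a real SD1.x latent, not from the commit):

import torch
from comfy import latent_formats

fmt = latent_formats.SD15()             # scale_factor defaults to 0.18215
z = torch.randn(1, 4, 64, 64)           # arbitrary stand-in for an SD1.x latent

scaled = fmt.process_in(z)              # z * 0.18215: into the model's working space
restored = fmt.process_out(scaled)      # scaled / 0.18215: back to raw VAE units
assert torch.allclose(z, restored)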
comfy/model_base.py @ 8607c2d4

@@ -6,9 +6,11 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
 import numpy as np

 class BaseModel(torch.nn.Module):
-    def __init__(self, unet_config, v_prediction=False):
+    def __init__(self, model_config, v_prediction=False):
         super().__init__()
+        unet_config = model_config.unet_config
+        self.latent_format = model_config.latent_format
         self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
         self.diffusion_model = UNetModel(**unet_config)
         self.v_prediction = v_prediction

@@ -75,9 +77,16 @@ class BaseModel(torch.nn.Module):
         del to_load
         return self

+    def process_latent_in(self, latent):
+        return self.latent_format.process_in(latent)
+
+    def process_latent_out(self, latent):
+        return self.latent_format.process_out(latent)
+
 class SD21UNCLIP(BaseModel):
-    def __init__(self, unet_config, noise_aug_config, v_prediction=True):
-        super().__init__(unet_config, v_prediction)
+    def __init__(self, model_config, noise_aug_config, v_prediction=True):
+        super().__init__(model_config, v_prediction)
         self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

     def encode_adm(self, **kwargs):

@@ -112,13 +121,13 @@ class SD21UNCLIP(BaseModel):
         return adm_out

 class SDInpaint(BaseModel):
-    def __init__(self, unet_config, v_prediction=False):
-        super().__init__(unet_config, v_prediction)
+    def __init__(self, model_config, v_prediction=False):
+        super().__init__(model_config, v_prediction)
         self.concat_keys = ("mask", "masked_image")

 class SDXLRefiner(BaseModel):
-    def __init__(self, unet_config, v_prediction=False):
-        super().__init__(unet_config, v_prediction)
+    def __init__(self, model_config, v_prediction=False):
+        super().__init__(model_config, v_prediction)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):

@@ -144,8 +153,8 @@ class SDXLRefiner(BaseModel):
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

 class SDXL(BaseModel):
-    def __init__(self, unet_config, v_prediction=False):
-        super().__init__(unet_config, v_prediction)
+    def __init__(self, model_config, v_prediction=False):
+        super().__init__(model_config, v_prediction)
         self.embedder = Timestep(256)

     def encode_adm(self, **kwargs):
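
With this change every BaseModel subclass takes a model_config object instead of a bare unet_config dict, and the constructor reads exactly two attributes from it: .unet_config and .latent_format. A hedged sketch of that minimal shape, using SimpleNamespace as a hypothetical stand-in (a real config also carries valid UNetModel kwargs, which are omitted here, so this object could not actually construct a BaseModel):

from types import SimpleNamespace
from comfy import latent_formats

# Hypothetical stand-in showing only the attributes this commit reads
# in BaseModel.__init__; the empty unet_config is a placeholder.
model_config = SimpleNamespace(
    unet_config={},
    latent_format=latent_formats.SDXL(),
)
print(model_config.latent_format.scale_factor)   # 0.13025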
comfy/samplers.py @ 8607c2d4

@@ -586,6 +586,9 @@ class KSampler:
         positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
         negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")

+        if latent_image is not None:
+            latent_image = self.model.process_latent_in(latent_image)
+
         extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}

         cond_concat = None

@@ -672,4 +675,4 @@ class KSampler:
         else:
             samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)

-        return samples.to(torch.float32)
+        return self.model.process_latent_out(samples.to(torch.float32))
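
The net effect on the sampler: the incoming latent is mapped into model space once, before the k-diffusion loop, and the result is mapped back before it is returned, so every node outside KSampler keeps working in raw VAE-latent units. A condensed, hedged sketch of that flow (_StubModel and the noise sum are stand-ins for the real model and sampling loop):

import torch
from comfy import latent_formats

class _StubModel:
    # hypothetical stand-in exposing just the two hooks KSampler now calls
    latent_format = latent_formats.SD15()
    def process_latent_in(self, latent):
        return self.latent_format.process_in(latent)
    def process_latent_out(self, latent):
        return self.latent_format.process_out(latent)

def sample(model, noise, latent_image=None):
    if latent_image is not None:
        latent_image = model.process_latent_in(latent_image)  # into model space
        noise = noise + latent_image                          # stand-in for the sampling loop
    return model.process_latent_out(noise.to(torch.float32))  # back to raw units

out = sample(_StubModel(), torch.randn(1, 4, 64, 64), torch.randn(1, 4, 64, 64))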
comfy/sd.py @ 8607c2d4

@@ -536,7 +536,7 @@ class CLIP:

 class VAE:
-    def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=None):
+    def __init__(self, ckpt_path=None, device=None, config=None):
         if config is None:
             #default SD1.x/SD2.x VAE parameters
             ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}

@@ -550,7 +550,6 @@ class VAE:
             sd = diffusers_convert.convert_vae_state_dict(sd)
             self.first_stage_model.load_state_dict(sd, strict=False)

-        self.scale_factor = scale_factor
         if device is None:
             device = model_management.get_torch_device()
         self.device = device

@@ -561,7 +560,7 @@ class VAE:
         steps += samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = utils.ProgressBar(steps)

-        decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
+        decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.device)) + 1.0)
         output = torch.clamp((
             (utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
             utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +

@@ -575,7 +574,7 @@ class VAE:
         steps += pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
         pbar = utils.ProgressBar(steps)

-        encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample() * self.scale_factor
+        encode_fn = lambda a: self.first_stage_model.encode(2. * a.to(self.device) - 1.).sample()
         samples = utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
         samples += utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)

@@ -593,7 +592,7 @@ class VAE:
             pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.device)
-                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(1. / self.scale_factor * samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
+                pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples) + 1.0) / 2.0, min=0.0, max=1.0).cpu()
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
             pixel_samples = self.decode_tiled_(samples_in)

@@ -620,7 +619,7 @@ class VAE:
             samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
             for x in range(0, pixel_samples.shape[0], batch_number):
                 pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.device)
-                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu() * self.scale_factor
+                samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).sample().cpu()
         except model_management.OOM_EXCEPTION as e:
             print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")

@@ -958,6 +957,7 @@ def load_gligen(ckpt_path):
     return model

 def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None):
+    #TODO: this function is a mess and should be removed eventually
     if config is None:
         with open(config_path, 'r') as stream:
             config = yaml.safe_load(stream)

@@ -992,12 +992,20 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     if state_dict is None:
         state_dict = utils.load_torch_file(ckpt_path)

+    class EmptyClass:
+        pass
+
+    model_config = EmptyClass()
+    model_config.unet_config = unet_config
+    from . import latent_formats
+    model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
+
     if config['model']["target"].endswith("LatentInpaintDiffusion"):
-        model = model_base.SDInpaint(unet_config, v_prediction=v_prediction)
+        model = model_base.SDInpaint(model_config, v_prediction=v_prediction)
     elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
-        model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction)
+        model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], v_prediction=v_prediction)
     else:
-        model = model_base.BaseModel(unet_config, v_prediction=v_prediction)
+        model = model_base.BaseModel(model_config, v_prediction=v_prediction)

     if fp16:
         model = model.half()

@@ -1006,14 +1014,12 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
     if output_vae:
         w = WeightsLoader()
-        vae = VAE(scale_factor=scale_factor, config=vae_config)
+        vae = VAE(config=vae_config)
         w.first_stage_model = vae.first_stage_model
         load_model_weights(w, state_dict)

     if output_clip:
         w = WeightsLoader()
-        class EmptyClass:
-            pass
         clip_target = EmptyClass()
         clip_target.params = clip_config.get("params", {})
         if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):

@@ -1055,7 +1061,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         model.load_model_weights(sd, "model.diffusion_model.")

     if output_vae:
-        vae = VAE(scale_factor=model_config.vae_scale_factor)
+        vae = VAE()
         w = WeightsLoader()
         w.first_stage_model = vae.first_stage_model
         load_model_weights(w, sd)
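
Two things happen in sd.py: the VAE loses its scale_factor entirely (encode and decode now produce and consume raw first-stage latents), and the legacy load_checkpoint path gains a throwaway shim so the old YAML-config route can feed the new constructors; an EmptyClass instance is given just the unet_config and latent_format attributes the models expect. A sketch of that shim in isolation (the unet_config dict here is a placeholder; the real one comes from the YAML config):

from comfy import latent_formats

class EmptyClass:
    pass

# duck-typed stand-in for a supported_models config, as built in load_checkpoint
model_config = EmptyClass()
model_config.unet_config = {}   # placeholder for the YAML-derived UNet dict
model_config.latent_format = latent_formats.SD15(scale_factor=0.18215)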
comfy/supported_models.py @ 8607c2d4

@@ -7,6 +7,7 @@ from . import sd2_clip
 from . import sdxl_clip
 from . import supported_models_base
+from . import latent_formats

 class SD15(supported_models_base.BASE):
     unet_config = {

@@ -21,7 +22,7 @@ class SD15(supported_models_base.BASE):
         "num_head_channels": -1,
     }

-    vae_scale_factor = 0.18215
+    latent_format = latent_formats.SD15

     def process_clip_state_dict(self, state_dict):
         k = list(state_dict.keys())

@@ -48,7 +49,7 @@ class SD20(supported_models_base.BASE):
         "adm_in_channels": None,
     }

-    vae_scale_factor = 0.18215
+    latent_format = latent_formats.SD15

     def v_prediction(self, state_dict):
         if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction

@@ -97,10 +98,10 @@ class SDXLRefiner(supported_models_base.BASE):
         "transformer_depth": [0, 4, 4, 0],
     }

-    vae_scale_factor = 0.13025
+    latent_format = latent_formats.SDXL

     def get_model(self, state_dict):
-        return model_base.SDXLRefiner(self.unet_config)
+        return model_base.SDXLRefiner(self)

     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}

@@ -124,10 +125,10 @@ class SDXL(supported_models_base.BASE):
         "adm_in_channels": 2816
     }

-    vae_scale_factor = 0.13025
+    latent_format = latent_formats.SDXL

     def get_model(self, state_dict):
-        return model_base.SDXL(self.unet_config)
+        return model_base.SDXL(self)

     def process_clip_state_dict(self, state_dict):
         keys_to_replace = {}
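
Each model family now declares its latent format as a class attribute (latent_formats.SD15 for the SD1.x/SD2.x families, latent_formats.SDXL for the XL models) in place of a bare vae_scale_factor float, and get_model hands the whole config object (self) to the model constructor. A quick check of the factors those format classes carry:

from comfy import latent_formats

# the supported-model classes store the format *class*; instantiating
# it yields the per-family scale factor used by process_in/process_out
for cls in (latent_formats.SD15, latent_formats.SDXL):
    print(cls.__name__, cls().scale_factor)   # SD15 0.18215, SDXL 0.13025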
comfy/supported_models_base.py @ 8607c2d4

@@ -49,16 +49,17 @@ class BASE:
     def __init__(self, unet_config):
         self.unet_config = unet_config
+        self.latent_format = self.latent_format()
         for x in self.unet_extra_config:
             self.unet_config[x] = self.unet_extra_config[x]

     def get_model(self, state_dict):
         if self.inpaint_model():
-            return model_base.SDInpaint(self.unet_config, v_prediction=self.v_prediction(state_dict))
+            return model_base.SDInpaint(self, v_prediction=self.v_prediction(state_dict))
         elif self.noise_aug_config is not None:
-            return model_base.SD21UNCLIP(self.unet_config, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
+            return model_base.SD21UNCLIP(self, self.noise_aug_config, v_prediction=self.v_prediction(state_dict))
         else:
-            return model_base.BaseModel(self.unet_config, v_prediction=self.v_prediction(state_dict))
+            return model_base.BaseModel(self, v_prediction=self.v_prediction(state_dict))

     def process_clip_state_dict(self, state_dict):
         return state_dict
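
Note the idiom in BASE.__init__: self.latent_format = self.latent_format() reads the class attribute, which holds a LatentFormat subclass, and shadows it with an instance on self, so after construction the attribute is an object rather than a class. A minimal, self-contained illustration of the pattern (hypothetical names, detached from the repo):

class Format:
    scale_factor = 0.18215

class Config:
    latent_format = Format                 # class attribute: holds the class itself

    def __init__(self):
        # instantiate and shadow: the instance attribute now holds an object
        self.latent_format = self.latent_format()

cfg = Config()
print(Config.latent_format)                # <class '__main__.Format'>
print(type(cfg.latent_format) is Format)   # True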
nodes.py @ 8607c2d4

@@ -284,6 +284,7 @@ class SaveLatent:
         output = {}
         output["latent_tensor"] = samples["samples"]
+        output["latent_format_version_0"] = torch.tensor([])

         safetensors.torch.save_file(output, file, metadata=metadata)

@@ -305,7 +306,10 @@ class LoadLatent:
     def load(self, latent):
         latent_path = folder_paths.get_annotated_filepath(latent)
         latent = safetensors.torch.load_file(latent_path, device="cpu")
-        samples = {"samples": latent["latent_tensor"].float()}
+        multiplier = 1.0
+        if "latent_format_version_0" not in latent:
+            multiplier = 1.0 / 0.18215
+        samples = {"samples": latent["latent_tensor"].float() * multiplier}
         return (samples, )

     @classmethod
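
Under the old convention, saved latents already carried the SD1.x factor, so LoadLatent treats the absence of the new latent_format_version_0 marker as "old file" and divides by 0.18215 to recover raw latent units. A hedged sketch of just that compatibility rule (load_compat is a hypothetical helper mirroring the diff):

import torch

def load_compat(latent: dict) -> torch.Tensor:
    # files without the marker predate this commit and carry
    # pre-scaled (x * 0.18215) tensors; undo that once on load
    multiplier = 1.0 if "latent_format_version_0" in latent else 1.0 / 0.18215
    return latent["latent_tensor"].float() * multiplier

old_file = {"latent_tensor": torch.randn(1, 4, 64, 64) * 0.18215}
new_file = {"latent_tensor": torch.randn(1, 4, 64, 64),
            "latent_format_version_0": torch.tensor([])}
assert torch.allclose(load_compat(old_file), old_file["latent_tensor"] / 0.18215)
assert torch.allclose(load_compat(new_file), new_file["latent_tensor"])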