chenpangpang / ComfyUI · Commit f83109f0

Stable Cascade Stage C.

Authored Feb 16, 2024 by comfyanonymous
Parent: 5e06baf1
Showing 11 changed files with 619 additions and 31 deletions (+619 −31).
Changed files:

    comfy/controlnet.py             +13   −4
    comfy/latent_formats.py          +6   −0
    comfy/ldm/cascade/common.py    +161   −0   (new file)
    comfy/ldm/cascade/stage_c.py   +271   −0   (new file)
    comfy/model_base.py             +37   −3
    comfy/model_detection.py        +23   −7
    comfy/model_management.py       +29   −5
    comfy/model_sampling.py         +30   −0
    comfy/sd.py                     +11   −9
    comfy/supported_models.py       +34   −1
    comfy/supported_models_base.py   +4   −2
comfy/controlnet.py:

```diff
@@ -318,9 +318,10 @@ def load_controlnet(ckpt_path, model=None):
         return ControlLora(controlnet_data)

     controlnet_config = None
+    supported_inference_dtypes = None
     if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
-        unet_dtype = comfy.model_management.unet_dtype()
-        controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
+        controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
         diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
         diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
         diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
@@ -380,12 +381,20 @@ def load_controlnet(ckpt_path, model=None):
         return net

     if controlnet_config is None:
-        unet_dtype = comfy.model_management.unet_dtype()
-        controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
+        model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
+        supported_inference_dtypes = model_config.supported_inference_dtypes
+        controlnet_config = model_config.unet_config

     load_device = comfy.model_management.get_torch_device()
+    if supported_inference_dtypes is None:
+        unet_dtype = comfy.model_management.unet_dtype()
+    else:
+        unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
+
     manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
     if manual_cast_dtype is not None:
         controlnet_config["operations"] = comfy.ops.manual_cast
+    controlnet_config["dtype"] = unet_dtype
     controlnet_config.pop("out_channels")
     controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
     control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
```
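These hunks thread the detected model config's supported_inference_dtypes into the dtype choice. A small sketch of the resulting selection flow, assuming a ComfyUI checkout at this revision on the import path (the dtype list is illustrative):

```python
import torch
import comfy.model_management

# Illustrative value: Stable Cascade Stage C advertises bf16/fp32 only.
supported_inference_dtypes = [torch.bfloat16, torch.float32]

# Mirrors the branch added in load_controlnet(): constrain unet_dtype()
# by the model's supported dtypes when they are known.
if supported_inference_dtypes is None:
    unet_dtype = comfy.model_management.unet_dtype()
else:
    unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
print(unet_dtype)  # e.g. torch.bfloat16 on an Ampere-class GPU
```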
comfy/latent_formats.py:

```diff
@@ -37,3 +37,9 @@ class SDXL(LatentFormat):
 class SD_X4(LatentFormat):
     def __init__(self):
         self.scale_factor = 0.08333
+
+class SC_Prior(LatentFormat):
+    def __init__(self):
+        self.scale_factor = 1.0
```
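SC_Prior's scale_factor of 1.0 means Stage C prior latents pass through unscaled, unlike SD's 0.18215 or SD_X4's 0.08333. A standalone sketch of how a latent-format scale factor is typically applied (illustrative class, not the real ComfyUI one):

```python
class LatentFormatSketch:
    scale_factor = 1.0

    def process_in(self, latent):
        # model-space latent = raw latent * scale_factor
        return latent * self.scale_factor

    def process_out(self, latent):
        return latent / self.scale_factor

# With scale_factor = 1.0 (SC_Prior), both directions are the identity:
# Stage C consumes and produces its latents without rescaling.
```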
comfy/ldm/cascade/common.py (new file, mode 100644):

```python
"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""

import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention

class Linear(torch.nn.Linear):
    def reset_parameters(self):
        return None

class Conv2d(torch.nn.Conv2d):
    def reset_parameters(self):
        return None

class OptimizedAttention(nn.Module):
    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = nhead

        self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
        self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
        self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

        self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)

    def forward(self, q, k, v):
        q = self.to_q(q)
        k = self.to_k(k)
        v = self.to_v(v)

        out = optimized_attention(q, k, v, self.heads)

        return self.out_proj(out)

class Attention2D(nn.Module):
    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
        # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)

    def forward(self, x, kv, self_attn=False):
        orig_shape = x.shape
        x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)  # Bx4xHxW -> Bx(HxW)x4
        if self_attn:
            kv = torch.cat([x, kv], dim=1)
        # x = self.attn(x, kv, kv, need_weights=False)[0]
        x = self.attn(x, kv, kv)
        x = x.permute(0, 2, 1).view(*orig_shape)
        return x

def LayerNorm2d_op(operations):
    class LayerNorm2d(operations.LayerNorm):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def forward(self, x):
            return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    return LayerNorm2d

class GlobalResponseNorm(nn.Module):
    "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
    def __init__(self, dim, dtype=None, device=None):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x

class ResBlock(nn.Module):
    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None):  # , num_heads=4, expansion=2):
        super().__init__()
        self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device)
        # self.depthwise = SAMBlock(c, num_heads, expansion)
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.channelwise = nn.Sequential(
            operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device),
            nn.GELU(),
            GlobalResponseNorm(c * 4, dtype=dtype, device=device),
            nn.Dropout(dropout),
            operations.Linear(c * 4, c, dtype=dtype, device=device)
        )

    def forward(self, x, x_skip=None):
        x_res = x
        x = self.norm(self.depthwise(x))
        if x_skip is not None:
            x = torch.cat([x, x_skip], dim=1)
        x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x + x_res

class AttnBlock(nn.Module):
    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.self_attn = self_attn
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations)
        self.kv_mapper = nn.Sequential(
            nn.SiLU(),
            operations.Linear(c_cond, c, dtype=dtype, device=device)
        )

    def forward(self, x, kv):
        kv = self.kv_mapper(kv)
        x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
        return x

class FeedForwardBlock(nn.Module):
    def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.channelwise = nn.Sequential(
            operations.Linear(c, c * 4, dtype=dtype, device=device),
            nn.GELU(),
            GlobalResponseNorm(c * 4, dtype=dtype, device=device),
            nn.Dropout(dropout),
            operations.Linear(c * 4, c, dtype=dtype, device=device)
        )

    def forward(self, x):
        x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x

class TimestepBlock(nn.Module):
    def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None):
        super().__init__()
        self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
        self.conds = conds
        for cname in conds:
            setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))

    def forward(self, x, t):
        t = t.chunk(len(self.conds) + 1, dim=1)
        a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
        for i, c in enumerate(self.conds):
            ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
            a, b = a + ac, b + bc
        return x * (1 + a) + b
```
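TimestepBlock above applies FiLM-style conditioning: the concatenated timestep embedding is chunked into one slice per condition ('sca', 'crp', ...), each slice is mapped to a scale/shift pair, and the pairs are summed into x * (1 + a) + b. A standalone sketch of that modulation, with plain nn.Linear standing in for the operations.Linear wrapper (an assumption made only for the demo):

```python
import torch
import torch.nn as nn

c, c_timestep = 8, 4
mapper = nn.Linear(c_timestep, c * 2)        # base timestep mapper
mapper_sca = nn.Linear(c_timestep, c * 2)    # one extra cond, like 'sca'

x = torch.randn(2, c, 16, 16)                # B x C x H x W feature map
t = torch.randn(2, c_timestep * 2)           # embeddings for [r, sca], concatenated on dim 1

t0, t1 = t.chunk(2, dim=1)
a, b = mapper(t0)[:, :, None, None].chunk(2, dim=1)
ac, bc = mapper_sca(t1)[:, :, None, None].chunk(2, dim=1)
a, b = a + ac, b + bc
out = x * (1 + a) + b                        # FiLM-style scale/shift, shape unchanged
assert out.shape == x.shape
```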
comfy/ldm/cascade/stage_c.py (new file, mode 100644):

```python
"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""

import torch
from torch import nn
import numpy as np
import math
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
# from .controlnet import ControlNetDeliverer

class UpDownBlock2d(nn.Module):
    def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None):
        super().__init__()
        assert mode in ['up', 'down']
        interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear',
                                    align_corners=True) if enabled else nn.Identity()
        mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device)
        self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


class StageC(nn.Module):
    def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
                 blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
                 c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
                 dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False],
                 stable_cascade_stage=None, dtype=None, device=None, operations=None):
        super().__init__()
        self.dtype = dtype
        self.c_r = c_r
        self.t_conds = t_conds
        self.c_clip_seq = c_clip_seq
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)
        if not isinstance(self_attn, list):
            self_attn = [self_attn] * len(c_hidden)

        # CONDITIONING
        self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device)
        self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6)
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
            if block_type == 'C':
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'A':
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'F':
                return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'T':
                return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
            else:
                raise Exception(f'Block type {block_type} not supported')

        # BLOCKS
        # -- down blocks
        self.down_blocks = nn.ModuleList()
        self.down_downscalers = nn.ModuleList()
        self.down_repeat_mappers = nn.ModuleList()
        for i in range(len(c_hidden)):
            if i > 0:
                self.down_downscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
                    UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
                ))
            else:
                self.down_downscalers.append(nn.Identity())
            down_block = nn.ModuleList()
            for _ in range(blocks[0][i]):
                for block_type in level_config[i]:
                    block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
                    down_block.append(block)
            self.down_blocks.append(down_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[0][i] - 1):
                    block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
                self.down_repeat_mappers.append(block_repeat_mappers)

        # -- up blocks
        self.up_blocks = nn.ModuleList()
        self.up_upscalers = nn.ModuleList()
        self.up_repeat_mappers = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            if i > 0:
                self.up_upscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6),
                    UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
                ))
            else:
                self.up_upscalers.append(nn.Identity())
            up_block = nn.ModuleList()
            for j in range(blocks[1][::-1][i]):
                for k, block_type in enumerate(level_config[i]):
                    c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
                    block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], self_attn=self_attn[i])
                    up_block.append(block)
            self.up_blocks.append(up_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[1][::-1][i] - 1):
                    block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
                self.up_repeat_mappers.append(block_repeat_mappers)

        # OUTPUT
        self.clf = nn.Sequential(
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6),
            operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
            nn.PixelShuffle(patch_size),
        )

        # --- WEIGHT INIT ---
        # self.apply(self._init_weights)  # General init
        # nn.init.normal_(self.clip_txt_mapper.weight, std=0.02)  # conditionings
        # nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02)  # conditionings
        # nn.init.normal_(self.clip_img_mapper.weight, std=0.02)  # conditionings
        # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02)  # inputs
        # nn.init.constant_(self.clf[1].weight, 0)  # outputs
        #
        # # blocks
        # for level_block in self.down_blocks + self.up_blocks:
        #     for block in level_block:
        #         if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
        #             block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
        #         elif isinstance(block, TimestepBlock):
        #             for layer in block.modules():
        #                 if isinstance(layer, nn.Linear):
        #                     nn.init.constant_(layer.weight, 0)
        #
        # def _init_weights(self, m):
        #     if isinstance(m, (nn.Conv2d, nn.Linear)):
        #         torch.nn.init.xavier_uniform_(m.weight)
        #         if m.bias is not None:
        #             nn.init.constant_(m.bias, 0)

    def gen_r_embedding(self, r, max_positions=10000):
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb

    def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
        clip_txt = self.clip_txt_mapper(clip_txt)
        if len(clip_txt_pooled.shape) == 2:
            clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
        if len(clip_img.shape) == 2:
            clip_img = clip_img.unsqueeze(1)
        clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
        clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
        clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
        clip = self.clip_norm(clip)
        return clip

    def _down_encode(self, x, r_embed, clip, cnet=None):
        level_outputs = []
        block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
        for down_block, downscaler, repmap in block_group:
            x = downscaler(x)
            for i in range(len(repmap) + 1):
                for block in down_block:
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, ResBlock)):
                        if cnet is not None:
                            next_cnet = cnet()
                            if next_cnet is not None:
                                x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', align_corners=True)
                        x = block(x)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if i < len(repmap):
                    x = repmap[i](x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
        x = level_outputs[0]
        block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
        for i, (up_block, upscaler, repmap) in enumerate(block_group):
            for j in range(len(repmap) + 1):
                for k, block in enumerate(up_block):
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, ResBlock)):
                        skip = level_outputs[i] if k == 0 and i > 0 else None
                        if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
                            x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear', align_corners=True)
                        if cnet is not None:
                            next_cnet = cnet()
                            if next_cnet is not None:
                                x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', align_corners=True)
                        x = block(x, skip)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if j < len(repmap):
                    x = repmap[j](x)
            x = upscaler(x)
        return x

    def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs):
        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
        for c in self.t_conds:
            t_cond = kwargs.get(c, torch.zeros_like(r))
            r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
        clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)

        # Model Blocks
        x = self.embedding(x)
        if cnet is not None:
            cnet = ControlNetDeliverer(cnet)
        level_outputs = self._down_encode(x, r_embed, clip, cnet)
        x = self._up_decode(level_outputs, r_embed, clip, cnet)
        return self.clf(x)

    def update_weights_ema(self, src_model, beta=0.999):
        for self_params, src_params in zip(self.parameters(), src_model.parameters()):
            self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
        for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
            self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
```
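A minimal smoke test of the StageC module above, assuming a ComfyUI checkout at this revision on the import path. Plain torch.nn stands in for the operations namespace (its Linear/Conv2d/LayerNorm accept the same device/dtype keywords, minus the weight-init skipping of comfy.ops), and the tiny channel widths are illustrative, not a real configuration:

```python
import torch
import torch.nn as nn
from comfy.ldm.cascade.stage_c import StageC

# Toy-sized model: full Stage C uses c_hidden=[2048, 2048], nhead=[32, 32],
# blocks=[[8, 24], [24, 8]]; shrunk here so the forward pass runs instantly.
model = StageC(c_cond=64, c_hidden=[64, 64], nhead=[4, 4],
               blocks=[[1, 2], [2, 1]], operations=nn).eval()

b = 1
x = torch.randn(b, 16, 24, 24)            # Stage C latent (c_in=16)
r = torch.rand(b)                         # timestep ratio in [0, 1]
clip_text = torch.randn(b, 77, 1280)      # per-token CLIP text embeddings
clip_text_pooled = torch.randn(b, 1, 1280)
clip_img = torch.randn(b, 1, 768)

with torch.no_grad():
    # 'sca'/'crp' default to zeros_like(r) inside forward()
    out = model(x, r, clip_text, clip_text_pooled, clip_img)
print(out.shape)                          # torch.Size([1, 16, 24, 24])
```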
comfy/model_base.py:

```diff
 import torch
 from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
+from comfy.ldm.cascade.stage_c import StageC
 from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
 from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
 import comfy.model_management
@@ -12,9 +13,10 @@ class ModelType(Enum):
     EPS = 1
     V_PREDICTION = 2
     V_PREDICTION_EDM = 3
+    STABLE_CASCADE = 4

-from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM
+from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling

 def model_sampling(model_config, model_type):
@@ -27,6 +29,9 @@ def model_sampling(model_config, model_type):
     elif model_type == ModelType.V_PREDICTION_EDM:
         c = V_PREDICTION
         s = ModelSamplingContinuousEDM
+    elif model_type == ModelType.STABLE_CASCADE:
+        c = EPS
+        s = StableCascadeSampling

     class ModelSampling(s, c):
         pass
@@ -35,7 +40,7 @@ def model_sampling(model_config, model_type):

 class BaseModel(torch.nn.Module):
-    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
         super().__init__()

         unet_config = model_config.unet_config
@@ -48,7 +53,7 @@ class BaseModel(torch.nn.Module):
                 operations = comfy.ops.manual_cast
             else:
                 operations = comfy.ops.disable_weight_init
-            self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
+            self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
         self.model_type = model_type
         self.model_sampling = model_sampling(model_config, model_type)
@@ -427,3 +432,32 @@ class SD_X4Upscaler(BaseModel):
         out['c_concat'] = comfy.conds.CONDNoiseShape(image)
         out['y'] = comfy.conds.CONDRegular(noise_level)
         return out
+
+class StableCascade_C(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=StageC)
+        self.diffusion_model.eval().requires_grad_(False)
+
+    def extra_conds(self, **kwargs):
+        out = {}
+        clip_text_pooled = kwargs["pooled_output"]
+        if clip_text_pooled is not None:
+            out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
+
+        if "unclip_conditioning" in kwargs:
+            embeds = []
+            for unclip_cond in kwargs["unclip_conditioning"]:
+                weight = unclip_cond["strength"]
+                embeds.append(unclip_cond["clip_vision_output"].image_embeds.unsqueeze(0) * weight)
+            clip_img = torch.cat(embeds, dim=1)
+        else:
+            clip_img = torch.zeros((1, 1, 768))
+        out["clip_img"] = comfy.conds.CONDRegular(clip_img)
+        out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
+        out["crp"] = comfy.conds.CONDRegular(torch.zeros((1,)))
+
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
+        return out
```
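The unclip branch of extra_conds scales each CLIP vision embedding by its strength and concatenates the results along the sequence dimension. A standalone sketch of just that tensor manipulation, with dummy tensors standing in for clip_vision_output.image_embeds:

```python
import torch

# Two image conditionings with different strengths (illustrative values).
image_embeds_a = torch.randn(1, 768)   # stands in for clip_vision_output.image_embeds
image_embeds_b = torch.randn(1, 768)

embeds = [image_embeds_a.unsqueeze(0) * 0.8,
          image_embeds_b.unsqueeze(0) * 0.5]
clip_img = torch.cat(embeds, dim=1)    # shape (1, 2, 768): two weighted image conds
print(clip_img.shape)

# With no unclip conditioning, the model falls back to zeros of shape (1, 1, 768).
```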
comfy/model_detection.py:

```diff
@@ -28,9 +28,26 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
         return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack
     return None

-def detect_unet_config(state_dict, key_prefix, dtype):
+def detect_unet_config(state_dict, key_prefix):
     state_dict_keys = list(state_dict.keys())

+    if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
+        unet_config = {}
+        text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
+        if text_mapper_name in state_dict_keys:
+            unet_config['stable_cascade_stage'] = 'c'
+            w = state_dict[text_mapper_name]
+            if w.shape[0] == 1536: #stage c lite
+                unet_config['c_cond'] = 1536
+                unet_config['c_hidden'] = [1536, 1536]
+                unet_config['nhead'] = [24, 24]
+                unet_config['blocks'] = [[4, 12], [12, 4]]
+            elif w.shape[0] == 2048: #stage c full
+                unet_config['c_cond'] = 2048
+        elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
+            unet_config['stable_cascade_stage'] = 'b'
+        return unet_config
+
     unet_config = {
         "use_checkpoint": False,
         "image_size": 32,
@@ -45,7 +62,6 @@ def detect_unet_config(state_dict, key_prefix):
     else:
         unet_config["adm_in_channels"] = None

-    unet_config["dtype"] = dtype
     model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
     in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
@@ -159,8 +175,8 @@ def model_config_from_unet_config(unet_config):
     print("no match", unet_config)
     return None

-def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
-    unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
+def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
+    unet_config = detect_unet_config(state_dict, unet_key_prefix)
     model_config = model_config_from_unet_config(unet_config)
     if model_config is None and use_base_if_no_match:
         return comfy.supported_models_base.BASE(unet_config)
@@ -206,7 +222,7 @@ def convert_config(unet_config):
     return new_config

-def unet_config_from_diffusers_unet(state_dict, dtype):
+def unet_config_from_diffusers_unet(state_dict, dtype=None):
     match = {}
     transformer_depth = []
@@ -313,8 +329,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
             return convert_config(unet_config)
     return None

-def model_config_from_diffusers_unet(state_dict, dtype):
-    unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
+def model_config_from_diffusers_unet(state_dict):
+    unet_config = unet_config_from_diffusers_unet(state_dict)
     if unet_config is not None:
         return model_config_from_unet_config(unet_config)
     return None
```
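The new detection path keys off state-dict shape alone: the presence of clf.1.weight marks a cascade model, and the width of clip_txt_mapper.weight distinguishes the lite (1536) from the full (2048) Stage C variant. A quick sketch against this revision's function, assuming a ComfyUI checkout on the import path and using a minimal fake state dict:

```python
import torch
from comfy.model_detection import detect_unet_config

# Only the keys/shapes the detector inspects are needed for this sketch.
sd = {
    "clf.1.weight": torch.zeros(16, 1536, 1, 1),
    "clip_txt_mapper.weight": torch.zeros(1536, 1280),   # 1536 -> stage c lite
}
print(detect_unet_config(sd, ""))
# {'stable_cascade_stage': 'c', 'c_cond': 1536, 'c_hidden': [1536, 1536],
#  'nhead': [24, 24], 'blocks': [[4, 12], [12, 4]]}
```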
comfy/model_management.py:

```diff
@@ -487,7 +487,7 @@ def unet_inital_load_device(parameters, dtype):
     else:
         return cpu_dev

-def unet_dtype(device=None, model_params=0):
+def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     if args.bf16_unet:
         return torch.bfloat16
     if args.fp16_unet:
@@ -497,20 +497,31 @@ def unet_dtype(device=None, model_params=0):
     if args.fp8_e5m2_unet:
         return torch.float8_e5m2
     if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
-        return torch.float16
+        if torch.float16 in supported_dtypes:
+            return torch.float16
+    if should_use_bf16(device):
+        if torch.bfloat16 in supported_dtypes:
+            return torch.bfloat16
     return torch.float32

 # None means no manual cast
-def unet_manual_cast(weight_dtype, inference_device):
+def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     if weight_dtype == torch.float32:
         return None

-    fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False)
+    fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
     if fp16_supported and weight_dtype == torch.float16:
         return None

-    if fp16_supported:
+    bf16_supported = should_use_bf16(inference_device)
+    if bf16_supported and weight_dtype == torch.bfloat16:
+        return None
+
+    if fp16_supported and torch.float16 in supported_dtypes:
         return torch.float16
+    elif bf16_supported and torch.bfloat16 in supported_dtypes:
+        return torch.bfloat16
     else:
         return torch.float32
@@ -760,6 +771,19 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     return True

+def should_use_bf16(device=None):
+    if is_intel_xpu():
+        return True
+
+    if device is None:
+        device = torch.device("cuda")
+
+    props = torch.cuda.get_device_properties(device)
+    if props.major >= 8:
+        return True
+
+    return False
+
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
```
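Taken together, unet_dtype and should_use_bf16 give a fallback chain: fp16 if the device prefers it and the model supports it, else bf16 under the same rule, else fp32. A standalone sketch of that order (no ComfyUI import needed; the boolean flags stand in for the should_use_fp16/should_use_bf16 results, and command-line dtype overrides are ignored):

```python
import torch

def pick_dtype(fp16_ok, bf16_ok, supported_dtypes):
    # Mirrors the selection order in unet_dtype() above.
    if fp16_ok and torch.float16 in supported_dtypes:
        return torch.float16
    if bf16_ok and torch.bfloat16 in supported_dtypes:
        return torch.bfloat16
    return torch.float32

# Stage C declares supported_inference_dtypes = [bfloat16, float32], so even
# on an fp16-capable GPU the selection lands on bf16 rather than fp16:
print(pick_dtype(True, True, [torch.bfloat16, torch.float32]))  # torch.bfloat16
```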
comfy/model_sampling.py:

```diff
@@ -132,3 +132,33 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
         log_sigma_min = math.log(self.sigma_min)
         return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
+
+class StableCascadeSampling(ModelSamplingDiscrete):
+    def __init__(self, model_config=None):
+        super().__init__()
+
+        self.num_timesteps = 1000
+        cosine_s = 8e-3
+        self.cosine_s = torch.tensor([cosine_s])
+        sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
+
+        self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2
+        for x in range(self.num_timesteps):
+            t = x / self.num_timesteps
+            sigmas[x] = self.sigma(t)
+
+        self.set_sigmas(sigmas)
+
+    def sigma(self, timestep):
+        alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod).clamp(0.0001, 0.9999)
+        return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5
+
+    def timestep(self, sigma):
+        return super().timestep(sigma) / 1000.0
+
+    def percent_to_sigma(self, percent):
+        if percent <= 0.0:
+            return 999999999.9
+        if percent >= 1.0:
+            return 0.0
+
+        percent = 1.0 - percent
+        return self.sigma(torch.tensor(percent))
```
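StableCascadeSampling implements a Nichol/Dhariwal-style cosine schedule, ᾱ(t) = cos²(((t + s)/(1 + s))·π/2) / ᾱ(0) with s = 8e-3, converted to the usual σ(t) = √((1 − ᾱ(t))/ᾱ(t)). A quick standalone numeric check of the two endpoints, mirroring the formula above in plain PyTorch:

```python
import torch

s = torch.tensor([8e-3])
init = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2  # alpha_cumprod at t = 0

def sigma(t):
    alpha_cumprod = (torch.cos((t + s) / (1 + s) * torch.pi * 0.5) ** 2 / init).clamp(0.0001, 0.9999)
    return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5

print(sigma(torch.tensor(0.0)))  # ~0.01: alpha_cumprod clamped at 0.9999 -> near-zero noise
print(sigma(torch.tensor(1.0)))  # ~99.99: alpha_cumprod clamped at 0.0001 -> near-pure noise
```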
comfy/sd.py:

```diff
@@ -450,15 +450,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
     clip_target = None

     parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
-    unet_dtype = model_management.unet_dtype(model_params=parameters)
     load_device = model_management.get_torch_device()
-    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)

     class WeightsLoader(torch.nn.Module):
         pass

-    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
-    model_config.set_manual_cast(manual_cast_dtype)
+    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
+    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+    model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

     if model_config is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
@@ -507,16 +507,15 @@ def load_unet_state_dict(sd): #load unet in diffusers format
     parameters = comfy.utils.calculate_parameters(sd)
-    unet_dtype = model_management.unet_dtype(model_params=parameters)
     load_device = model_management.get_torch_device()
-    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)

-    if "input_blocks.0.0.weight" in sd: #ldm
-        model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
+    if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
+        model_config = model_detection.model_config_from_unet(sd, "")
         if model_config is None:
             return None
         new_sd = sd
     else: #diffusers
-        model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
+        model_config = model_detection.model_config_from_diffusers_unet(sd)
         if model_config is None:
             return None
@@ -528,8 +527,11 @@ def load_unet_state_dict(sd): #load unet in diffusers format
             new_sd[diffusers_keys[k]] = sd.pop(k)
         else:
             print(diffusers_keys[k], k)

     offload_device = model_management.unet_offload_device()
-    model_config.set_manual_cast(manual_cast_dtype)
+    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+    model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
     model = model_config.get_model(new_sd, "")
     model = model.to(offload_device)
     model.load_model_weights(new_sd, "")
```
comfy/supported_models.py:

```diff
@@ -306,5 +306,38 @@ class SD_X4Upscaler(SD20):
         out = model_base.SD_X4Upscaler(self, device=device)
         return out

-models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
+class Stable_Cascade_C(supported_models_base.BASE):
+    unet_config = {
+        "stable_cascade_stage": 'c',
+    }
+
+    unet_extra_config = {}
+
+    latent_format = latent_formats.SC_Prior
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+    def process_unet_state_dict(self, state_dict):
+        key_list = list(state_dict.keys())
+        for y in ["weight", "bias"]:
+            suffix = "in_proj_{}".format(y)
+            keys = filter(lambda a: a.endswith(suffix), key_list)
+            for k_from in keys:
+                weights = state_dict.pop(k_from)
+                prefix = k_from[:-(len(suffix) + 1)]
+                shape_from = weights.shape[0] // 3
+                for x in range(3):
+                    p = ["to_q", "to_k", "to_v"]
+                    k_to = "{}.{}.{}".format(prefix, p[x], y)
+                    state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
+        return state_dict
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.StableCascade_C(self, device=device)
+        return out
+
+    def clip_target(self):
+        return None
+
+models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C]

 models += [SVD_img2vid]
```
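process_unet_state_dict exists because the reference Stage C checkpoints store fused attention projections (keys ending in in_proj_weight / in_proj_bias, shape (3c, c)), while the cascade blocks in this commit use separate to_q/to_k/to_v. A standalone sketch of the split on a dummy key (hypothetical key name, same slicing logic as above):

```python
import torch

c = 8
state_dict = {"blocks.0.attention.attn.in_proj_weight": torch.randn(3 * c, c)}

suffix = "in_proj_weight"
k_from = "blocks.0.attention.attn.in_proj_weight"
weights = state_dict.pop(k_from)
prefix = k_from[:-(len(suffix) + 1)]      # strips the trailing ".in_proj_weight"
third = weights.shape[0] // 3
for x, name in enumerate(["to_q", "to_k", "to_v"]):
    # Each projection gets one contiguous third of the fused tensor.
    state_dict["{}.{}.weight".format(prefix, name)] = weights[third * x:third * (x + 1)]

print(sorted(state_dict))
# ['blocks.0.attention.attn.to_k.weight', 'blocks.0.attention.attn.to_q.weight',
#  'blocks.0.attention.attn.to_v.weight']
```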
comfy/supported_models_base.py:

```diff
@@ -22,13 +22,14 @@ class BASE:
     sampling_settings = {}
     latent_format = latent_formats.LatentFormat
     vae_key_prefix = ["first_stage_model."]

+    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
     manual_cast_dtype = None

     @classmethod
     def matches(s, unet_config):
         for k in s.unet_config:
-            if s.unet_config[k] != unet_config[k]:
+            if k not in unet_config or s.unet_config[k] != unet_config[k]:
                 return False
         return True
@@ -80,5 +81,6 @@ class BASE:
         replace_prefix = {"": "first_stage_model."}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)

-    def set_manual_cast(self, manual_cast_dtype):
+    def set_inference_dtype(self, dtype, manual_cast_dtype):
+        self.unet_config['dtype'] = dtype
         self.manual_cast_dtype = manual_cast_dtype
```
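The `k not in unet_config` guard in matches() matters because Stage C detection returns a minimal config without the UNet keys other model classes test for, so the old lookup would raise KeyError instead of simply not matching. A standalone sketch of the difference (illustrative dicts only):

```python
template = {"image_size": 32}              # config keys of, say, an SD model class
detected = {"stable_cascade_stage": "c"}   # cascade config: no 'image_size' key

# New behavior: a missing key means "no match", short-circuiting before lookup.
matches = all(k in detected and template[k] == detected[k] for k in template)
print(matches)  # False, rather than a KeyError from detected['image_size']
```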