chenpangpang / ComfyUI · Commit 667c9281
Authored Feb 16, 2024 by comfyanonymous
Parent: f83109f0

    Stable Cascade Stage B.

Showing 10 changed files with 430 additions and 8 deletions (+430 -8)
comfy/latent_formats.py                 +3    -1
comfy/ldm/cascade/common.py             +1    -1
comfy/ldm/cascade/stage_b.py          +257    -0
comfy/ldm/cascade/stage_c.py            +4    -4
comfy/model_base.py                    +25    -0
comfy/ops.py                           +48    -1
comfy/supported_models.py              +15    -1
comfy/utils.py                          +2    -0
comfy_extras/nodes_stable_cascade.py   +74    -0
nodes.py                                +1    -0
comfy/latent_formats.py
@@ -42,4 +42,6 @@ class SC_Prior(LatentFormat):
     def __init__(self):
         self.scale_factor = 1.0

+class SC_B(LatentFormat):
+    def __init__(self):
+        self.scale_factor = 1.0
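
The new SC_B format registers the Stage B latent space with a scale factor of
1.0, i.e. these latents pass through unscaled. For context, a minimal sketch of
how ComfyUI's LatentFormat base class applies this factor (paraphrased from
comfy/latent_formats.py, not part of this diff):

    class LatentFormat:
        scale_factor = 1.0

        def process_in(self, latent):
            return latent * self.scale_factor

        def process_out(self, latent):
            return latent / self.scale_factor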
comfy/ldm/cascade/common.py
@@ -84,7 +84,7 @@ class GlobalResponseNorm(nn.Module):
     def forward(self, x):
         Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
         Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
-        return self.gamma * (x * Nx) + self.beta + x
+        return self.gamma.to(x.device) * (x * Nx) + self.beta.to(x.device) + x

 class ResBlock(nn.Module):
     ...
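
The .to(x.device) casts let gamma and beta live on a different device than the
activations, which matters under ComfyUI's model offloading. A self-contained
sketch of the same GlobalResponseNorm math, runnable outside ComfyUI (the
channels-last layout is an assumption carried over from the cascade code):

    import torch
    import torch.nn as nn

    class GRN(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
            self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

        def forward(self, x):  # x: (B, H, W, C)
            Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)   # L2 over H, W
            Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)    # mean-normalize
            return self.gamma.to(x.device) * (x * Nx) + self.beta.to(x.device) + x

    print(GRN(16)(torch.randn(1, 8, 8, 16)).shape)  # torch.Size([1, 8, 8, 16])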
comfy/ldm/cascade/stage_b.py (new file, 0 → 100644)

"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""

import math
import numpy as np
import torch
from torch import nn
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock


class StageB(nn.Module):
    def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
                 nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
                 block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'],
                 c_clip=1280, c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3,
                 dropout=[0, 0, 0.0, 0.0], self_attn=True, t_conds=['sca'], stable_cascade_stage=None,
                 dtype=None, device=None, operations=None):
        super().__init__()
        self.dtype = dtype
        self.c_r = c_r
        self.t_conds = t_conds
        self.c_clip_seq = c_clip_seq
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)
        if not isinstance(self_attn, list):
            self_attn = [self_attn] * len(c_hidden)

        # CONDITIONING
        self.effnet_mapper = nn.Sequential(
            operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )
        self.pixels_mapper = nn.Sequential(
            operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )
        self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
            if block_type == 'C':
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'A':
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'F':
                return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'T':
                return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
            else:
                raise Exception(f'Block type {block_type} not supported')
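
        # get_block legend: 'C' = ResBlock (conv), 'T' = TimestepBlock,
        # 'A' = AttnBlock, 'F' = FeedForwardBlock. The default level_config
        # ['CT', 'CT', 'CTA', 'CTA'] therefore only enables attention at the
        # two deepest levels.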
        # BLOCKS
        # -- down blocks
        self.down_blocks = nn.ModuleList()
        self.down_downscalers = nn.ModuleList()
        self.down_repeat_mappers = nn.ModuleList()
        for i in range(len(c_hidden)):
            if i > 0:
                self.down_downscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
                    operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device),
                ))
            else:
                self.down_downscalers.append(nn.Identity())
            down_block = nn.ModuleList()
            for _ in range(blocks[0][i]):
                for block_type in level_config[i]:
                    block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
                    down_block.append(block)
            self.down_blocks.append(down_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[0][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.down_repeat_mappers.append(block_repeat_mappers)
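
        # Each level's blocks are run block_repeat[0][i] times in _down_encode,
        # with these 1x1 "repeat mapper" convs bridging consecutive passes.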
        # -- up blocks
        self.up_blocks = nn.ModuleList()
        self.up_upscalers = nn.ModuleList()
        self.up_repeat_mappers = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            if i > 0:
                self.up_upscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
                    operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device),
                ))
            else:
                self.up_upscalers.append(nn.Identity())
            up_block = nn.ModuleList()
            for j in range(blocks[1][::-1][i]):
                for k, block_type in enumerate(level_config[i]):
                    c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
                    block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], self_attn=self_attn[i])
                    up_block.append(block)
            self.up_blocks.append(up_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[1][::-1][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.up_repeat_mappers.append(block_repeat_mappers)

        # OUTPUT
        self.clf = nn.Sequential(
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
            operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
            nn.PixelShuffle(patch_size),
        )

        # --- WEIGHT INIT ---
        # self.apply(self._init_weights)  # General init
        # nn.init.normal_(self.clip_mapper.weight, std=0.02)  # conditionings
        # nn.init.normal_(self.effnet_mapper[0].weight, std=0.02)  # conditionings
        # nn.init.normal_(self.effnet_mapper[2].weight, std=0.02)  # conditionings
        # nn.init.normal_(self.pixels_mapper[0].weight, std=0.02)  # conditionings
        # nn.init.normal_(self.pixels_mapper[2].weight, std=0.02)  # conditionings
        # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02)  # inputs
        # nn.init.constant_(self.clf[1].weight, 0)  # outputs
        #
        # # blocks
        # for level_block in self.down_blocks + self.up_blocks:
        #     for block in level_block:
        #         if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
        #             block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
        #         elif isinstance(block, TimestepBlock):
        #             for layer in block.modules():
        #                 if isinstance(layer, nn.Linear):
        #                     nn.init.constant_(layer.weight, 0)
        #
        # def _init_weights(self, m):
        #     if isinstance(m, (nn.Conv2d, nn.Linear)):
        #         torch.nn.init.xavier_uniform_(m.weight)
        #         if m.bias is not None:
        #             nn.init.constant_(m.bias, 0)
    def gen_r_embedding(self, r, max_positions=10000):
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb
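
    # gen_r_embedding above is a standard sinusoidal timestep embedding: the
    # ratio r (typically in [0, 1]) is scaled by max_positions, projected onto
    # c_r // 2 geometrically spaced frequencies, and the sin/cos halves are
    # concatenated, with a zero pad when c_r is odd.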
    def gen_c_embeddings(self, clip):
        if len(clip.shape) == 2:
            clip = clip.unsqueeze(1)
        clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
        clip = self.clip_norm(clip)
        return clip
    def _down_encode(self, x, r_embed, clip):
        level_outputs = []
        block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
        for down_block, downscaler, repmap in block_group:
            x = downscaler(x)
            for i in range(len(repmap) + 1):
                for block in down_block:
                    if isinstance(block, ResBlock) or (hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, ResBlock)):
                        x = block(x)
                    elif isinstance(block, AttnBlock) or (hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if i < len(repmap):
                    x = repmap[i](x)
            level_outputs.insert(0, x)
        return level_outputs
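
    # The hasattr(block, '_fsdp_wrapped_module') branches above (and mirrored
    # in _up_decode) keep the isinstance dispatch working when blocks are
    # wrapped by PyTorch FSDP during training.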
    def _up_decode(self, level_outputs, r_embed, clip):
        x = level_outputs[0]
        block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
        for i, (up_block, upscaler, repmap) in enumerate(block_group):
            for j in range(len(repmap) + 1):
                for k, block in enumerate(up_block):
                    if isinstance(block, ResBlock) or (hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, ResBlock)):
                        skip = level_outputs[i] if k == 0 and i > 0 else None
                        if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
                            x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear', align_corners=True)
                        x = block(x, skip)
                    elif isinstance(block, AttnBlock) or (hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if j < len(repmap):
                    x = repmap[j](x)
            x = upscaler(x)
        return x
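
    # Only the first ResBlock of each level past the deepest receives an
    # encoder skip; x is bilinearly resized to the skip's spatial size first,
    # so resolutions that don't divide evenly through the downscalers still work.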
    def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
        if pixels is None:
            pixels = x.new_zeros(x.size(0), 3, 8, 8)

        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
        for c in self.t_conds:
            t_cond = kwargs.get(c, torch.zeros_like(r))
            r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
        clip = self.gen_c_embeddings(clip)

        # Model Blocks
        x = self.embedding(x)
        x = x + self.effnet_mapper(nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
        x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear', align_corners=True)
        level_outputs = self._down_encode(x, r_embed, clip)
        x = self._up_decode(level_outputs, r_embed, clip)
        return self.clf(x)
    def update_weights_ema(self, src_model, beta=0.999):
        for self_params, src_params in zip(self.parameters(), src_model.parameters()):
            self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
        for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
            self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
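
A quick way to sanity-check the new module is to instantiate a shrunken
configuration; the sizes below are illustrative assumptions, not the shipped
model (which uses the defaults in __init__ above):

    import torch
    import comfy.ops
    from comfy.ldm.cascade.stage_b import StageB

    model = StageB(c_hidden=[32, 64], nhead=[-1, 4], blocks=[[2, 2], [2, 2]],
                   block_repeat=[[1, 1], [1, 1]], level_config=['CT', 'CTA'],
                   operations=comfy.ops.disable_weight_init)

    x = torch.randn(1, 4, 64, 64)      # stage B latent (pixels // 4 for 256x256)
    r = torch.rand(1)                  # timestep ratio
    effnet = torch.randn(1, 16, 6, 6)  # stage C prior; resized inside forward
    clip = torch.randn(1, 77, 1280)    # text conditioning sequence
    print(model(x, r, effnet, clip).shape)  # torch.Size([1, 4, 64, 64])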
comfy/ldm/cascade/stage_c.py
@@ -42,7 +42,7 @@ class StageC(nn.Module):
     def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
                  blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
                  c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
-                 dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
+                 dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
                  dtype=None, device=None, operations=None):
         super().__init__()
         self.dtype = dtype
...
@@ -100,7 +100,7 @@ class StageC(nn.Module):
             if block_repeat is not None:
                 block_repeat_mappers = nn.ModuleList()
                 for _ in range(block_repeat[0][i] - 1):
-                    block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
+                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                 self.down_repeat_mappers.append(block_repeat_mappers)

         # -- up blocks
...
@@ -126,12 +126,12 @@ class StageC(nn.Module):
             if block_repeat is not None:
                 block_repeat_mappers = nn.ModuleList()
                 for _ in range(block_repeat[1][::-1][i] - 1):
-                    block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
+                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                 self.up_repeat_mappers.append(block_repeat_mappers)

         # OUTPUT
         self.clf = nn.Sequential(
-            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6),
+            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
             operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
             nn.PixelShuffle(patch_size),
         )
     ...
comfy/model_base.py
 import torch
 from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
 from comfy.ldm.cascade.stage_c import StageC
+from comfy.ldm.cascade.stage_b import StageB
 from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
 from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
 import comfy.model_management
...
@@ -461,3 +462,27 @@ class StableCascade_C(BaseModel):
         out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
         return out

+class StableCascade_B(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=StageB)
+        self.diffusion_model.eval().requires_grad_(False)
+
+    def extra_conds(self, **kwargs):
+        out = {}
+        noise = kwargs.get("noise", None)
+
+        clip_text_pooled = kwargs["pooled_output"]
+        if clip_text_pooled is not None:
+            out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
+
+        #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
+        prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
+
+        out["effnet"] = comfy.conds.CONDRegular(prior)
+        out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
+
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['clip'] = comfy.conds.CONDCrossAttn(cross_attn)
+        return out
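
The 42 in the fallback prior shape is the default stage C compression factor:
the stage B noise latent is pixels // 4, so multiplying its spatial dims by 4
recovers the pixel size, and dividing by 42 gives the stage C latent size.

    # Worked example for an assumed 1024x1024 generation:
    # stage B noise:  1024 // 4 = 256       -> noise.shape[2] == 256
    # fallback prior: (256 * 4) // 42 = 24  -> zeros of shape (1, 16, 24, 24)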
comfy/ops.py
"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""

import torch
import comfy.model_management

...
@@ -78,7 +96,11 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
+            if self.weight is not None:
+                weight, bias = cast_bias_weight(self, input)
+            else:
+                weight = None
+                bias = None
             return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

         def forward(self, *args, **kwargs):
...
@@ -87,6 +109,28 @@ class disable_weight_init:
             else:
                 return super().forward(*args, **kwargs)

+    class ConvTranspose2d(torch.nn.ConvTranspose2d):
+        comfy_cast_weights = False
+        def reset_parameters(self):
+            return None
+
+        def forward_comfy_cast_weights(self, input, output_size=None):
+            num_spatial_dims = 2
+            output_padding = self._output_padding(
+                input, output_size, self.stride, self.padding, self.kernel_size,
+                num_spatial_dims, self.dilation)
+
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.conv_transpose2d(
+                input, weight, bias, self.stride, self.padding,
+                output_padding, self.groups, self.dilation)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     @classmethod
     def conv_nd(s, dims, *args, **kwargs):
         if dims == 2:
...
@@ -112,3 +156,6 @@ class manual_cast(disable_weight_init):
     class LayerNorm(disable_weight_init.LayerNorm):
         comfy_cast_weights = True
+
+    class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
+        comfy_cast_weights = True
comfy/supported_models.py
@@ -338,6 +338,20 @@ class Stable_Cascade_C(supported_models_base.BASE):
     def clip_target(self):
         return None

+class Stable_Cascade_B(Stable_Cascade_C):
+    unet_config = {
+        "stable_cascade_stage": 'b',
+    }
+
+    unet_extra_config = {}
+
+    latent_format = latent_formats.SC_B
+    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.StableCascade_B(self, device=device)
+        return out
+
-models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C]
+models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]

 models += [SVD_img2vid]
comfy/utils.py
@@ -169,6 +169,8 @@ UNET_MAP_BASIC = {
 }

 def unet_to_diffusers(unet_config):
+    if "num_res_blocks" not in unet_config:
+        return {}
     num_res_blocks = unet_config["num_res_blocks"]
     channel_mult = unet_config["channel_mult"]
     transformer_depth = unet_config["transformer_depth"][:]
     ...
comfy_extras/nodes_stable_cascade.py (new file, 0 → 100644)

"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""

import torch
import nodes


class StableCascade_EmptyLatentImage:
    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
            "height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
            "compression": ("INT", {"default": 42, "min": 32, "max": 64, "step": 1}),
            "batch_size": ("INT", {"default": 1, "min": 1, "max": 64})
        }}
    RETURN_TYPES = ("LATENT", "LATENT")
    RETURN_NAMES = ("stage_c", "stage_b")
    FUNCTION = "generate"

    CATEGORY = "_for_testing/stable_cascade"

    def generate(self, width, height, compression, batch_size=1):
        c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
        b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
        return ({"samples": c_latent, }, {"samples": b_latent, })


class StableCascade_StageB_Conditioning:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"conditioning": ("CONDITIONING",),
                             "stage_c": ("LATENT",),
                             }}
    RETURN_TYPES = ("CONDITIONING",)

    FUNCTION = "set_prior"

    CATEGORY = "_for_testing/stable_cascade"

    def set_prior(self, conditioning, stage_c):
        c = []
        for t in conditioning:
            d = t[1].copy()
            d['stable_cascade_prior'] = stage_c['samples']
            n = [t[0], d]
            c.append(n)
        return (c, )


NODE_CLASS_MAPPINGS = {
    "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
    "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
}
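
StableCascade_StageB_Conditioning copies the sampled stage_c latent into each
conditioning entry as 'stable_cascade_prior', which StableCascade_B.extra_conds
in model_base.py then forwards to StageB as its "effnet" input.

    # Shape check (defaults: width=height=1024, compression=42, batch_size=1):
    # stage_c: [1, 16, 1024 // 42, 1024 // 42] == [1, 16, 24, 24]
    # stage_b: [1, 4, 1024 // 4, 1024 // 4]    == [1, 4, 256, 256]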
nodes.py
@@ -1967,6 +1967,7 @@ def init_custom_nodes():
         "nodes_sdupscale.py",
         "nodes_photomaker.py",
         "nodes_cond.py",
+        "nodes_stable_cascade.py",
     ]

     for node_file in extras_files:
     ...