chenpangpang / ComfyUI
Commit 667c9281, authored Feb 16, 2024 by comfyanonymous
Stable Cascade Stage B.
Parent: f83109f0
Changes: 10 changed files with 430 additions and 8 deletions (+430 -8)
comfy/latent_formats.py                 +3   -1
comfy/ldm/cascade/common.py             +1   -1
comfy/ldm/cascade/stage_b.py            +257 -0
comfy/ldm/cascade/stage_c.py            +4   -4
comfy/model_base.py                     +25  -0
comfy/ops.py                            +48  -1
comfy/supported_models.py               +15  -1
comfy/utils.py                          +2   -0
comfy_extras/nodes_stable_cascade.py    +74  -0
nodes.py                                +1   -0
comfy/latent_formats.py
...
@@ -42,4 +42,6 @@ class SC_Prior(LatentFormat):
     def __init__(self):
         self.scale_factor = 1.0

+class SC_B(LatentFormat):
+    def __init__(self):
+        self.scale_factor = 1.0
comfy/ldm/cascade/common.py
...
@@ -84,7 +84,7 @@ class GlobalResponseNorm(nn.Module):
     def forward(self, x):
         Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
         Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
-        return self.gamma * (x * Nx) + self.beta + x
+        return self.gamma.to(x.device) * (x * Nx) + self.beta.to(x.device) + x

 class ResBlock(nn.Module):
...
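As a quick illustration (not part of the commit): the single changed line above casts gamma and beta to the activation's device before the affine step, which keeps the layer working when parameters are held on a different device than the input. A minimal standalone sketch of a GlobalResponseNorm layer with the same forward logic, assuming NHWC input and (1, 1, 1, dim) parameter shapes:

import torch
from torch import nn

class GlobalResponseNormSketch(nn.Module):
    # Hedged sketch mirroring the forward() in the diff above; parameter
    # shapes and input layout (N, H, W, C) are assumptions for illustration.
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        # L2 norm over the spatial axes, normalized by its channel-wise mean.
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        # Casting gamma/beta to x.device, as the commit does, avoids device
        # mismatches when weights are offloaded.
        return self.gamma.to(x.device) * (x * Nx) + self.beta.to(x.device) + x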
comfy/ldm/cascade/stage_b.py  0 → 100644 (new file)

"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""

import math
import numpy as np
import torch
from torch import nn
from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock


class StageB(nn.Module):
    def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
                 nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
                 block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'],
                 c_clip=1280, c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0],
                 self_attn=True, t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None):
        super().__init__()
        self.dtype = dtype
        self.c_r = c_r
        self.t_conds = t_conds
        self.c_clip_seq = c_clip_seq
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)
        if not isinstance(self_attn, list):
            self_attn = [self_attn] * len(c_hidden)

        # CONDITIONING
        self.effnet_mapper = nn.Sequential(
            operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )
        self.pixels_mapper = nn.Sequential(
            operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )
        self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device)
        self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
            if block_type == 'C':
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'A':
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'F':
                return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
            elif block_type == 'T':
                return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
            else:
                raise Exception(f'Block type {block_type} not supported')

        # BLOCKS
        # -- down blocks
        self.down_blocks = nn.ModuleList()
        self.down_downscalers = nn.ModuleList()
        self.down_repeat_mappers = nn.ModuleList()
        for i in range(len(c_hidden)):
            if i > 0:
                self.down_downscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
                    operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device),
                ))
            else:
                self.down_downscalers.append(nn.Identity())
            down_block = nn.ModuleList()
            for _ in range(blocks[0][i]):
                for block_type in level_config[i]:
                    block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
                    down_block.append(block)
            self.down_blocks.append(down_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[0][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.down_repeat_mappers.append(block_repeat_mappers)

        # -- up blocks
        self.up_blocks = nn.ModuleList()
        self.up_upscalers = nn.ModuleList()
        self.up_repeat_mappers = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            if i > 0:
                self.up_upscalers.append(nn.Sequential(
                    LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
                    operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device),
                ))
            else:
                self.up_upscalers.append(nn.Identity())
            up_block = nn.ModuleList()
            for j in range(blocks[1][::-1][i]):
                for k, block_type in enumerate(level_config[i]):
                    c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
                    block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], self_attn=self_attn[i])
                    up_block.append(block)
            self.up_blocks.append(up_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[1][::-1][i] - 1):
                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                self.up_repeat_mappers.append(block_repeat_mappers)

        # OUTPUT
        self.clf = nn.Sequential(
            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
            operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
            nn.PixelShuffle(patch_size),
        )

        # --- WEIGHT INIT ---
        # self.apply(self._init_weights)  # General init
        # nn.init.normal_(self.clip_mapper.weight, std=0.02)  # conditionings
        # nn.init.normal_(self.effnet_mapper[0].weight, std=0.02)  # conditionings
        # nn.init.normal_(self.effnet_mapper[2].weight, std=0.02)  # conditionings
        # nn.init.normal_(self.pixels_mapper[0].weight, std=0.02)  # conditionings
        # nn.init.normal_(self.pixels_mapper[2].weight, std=0.02)  # conditionings
        # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02)  # inputs
        # nn.init.constant_(self.clf[1].weight, 0)  # outputs
        #
        # # blocks
        # for level_block in self.down_blocks + self.up_blocks:
        #     for block in level_block:
        #         if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
        #             block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
        #         elif isinstance(block, TimestepBlock):
        #             for layer in block.modules():
        #                 if isinstance(layer, nn.Linear):
        #                     nn.init.constant_(layer.weight, 0)
        #
        # def _init_weights(self, m):
        #     if isinstance(m, (nn.Conv2d, nn.Linear)):
        #         torch.nn.init.xavier_uniform_(m.weight)
        #         if m.bias is not None:
        #             nn.init.constant_(m.bias, 0)

    def gen_r_embedding(self, r, max_positions=10000):
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode='constant')
        return emb

    def gen_c_embeddings(self, clip):
        if len(clip.shape) == 2:
            clip = clip.unsqueeze(1)
        clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
        clip = self.clip_norm(clip)
        return clip

    def _down_encode(self, x, r_embed, clip):
        level_outputs = []
        block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
        for down_block, downscaler, repmap in block_group:
            x = downscaler(x)
            for i in range(len(repmap) + 1):
                for block in down_block:
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, ResBlock)):
                        x = block(x)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if i < len(repmap):
                    x = repmap[i](x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, clip):
        x = level_outputs[0]
        block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
        for i, (up_block, upscaler, repmap) in enumerate(block_group):
            for j in range(len(repmap) + 1):
                for k, block in enumerate(up_block):
                    if isinstance(block, ResBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, ResBlock)):
                        skip = level_outputs[i] if k == 0 and i > 0 else None
                        if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
                            x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear', align_corners=True)
                        x = block(x, skip)
                    elif isinstance(block, AttnBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, AttnBlock)):
                        x = block(x, clip)
                    elif isinstance(block, TimestepBlock) or (
                            hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, TimestepBlock)):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if j < len(repmap):
                    x = repmap[j](x)
            x = upscaler(x)
        return x

    def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
        if pixels is None:
            pixels = x.new_zeros(x.size(0), 3, 8, 8)

        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
        for c in self.t_conds:
            t_cond = kwargs.get(c, torch.zeros_like(r))
            r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
        clip = self.gen_c_embeddings(clip)

        # Model Blocks
        x = self.embedding(x)
        x = x + self.effnet_mapper(
            nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
        x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear', align_corners=True)
        level_outputs = self._down_encode(x, r_embed, clip)
        x = self._up_decode(level_outputs, r_embed, clip)
        return self.clf(x)

    def update_weights_ema(self, src_model, beta=0.999):
        for self_params, src_params in zip(self.parameters(), src_model.parameters()):
            self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
        for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
            self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
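For orientation (not part of the commit): StageB.gen_r_embedding above builds a standard sinusoidal embedding of the normalized timestep r with c_r channels. A small self-contained sketch of the same computation, with illustrative values for c_r and max_positions:

import math
import torch

def sinusoidal_r_embedding(r, c_r=64, max_positions=10000):
    # Hedged sketch of the embedding used by StageB.gen_r_embedding.
    r = r * max_positions
    half_dim = c_r // 2
    emb = math.log(max_positions) / (half_dim - 1)
    emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
    emb = r[:, None] * emb[None, :]                  # (batch, half_dim)
    emb = torch.cat([emb.sin(), emb.cos()], dim=1)   # (batch, c_r)
    if c_r % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1), mode='constant')
    return emb

# Example: a batch of 2 normalized timesteps yields a (2, 64) embedding.
print(sinusoidal_r_embedding(torch.tensor([0.25, 0.75])).shape)  # torch.Size([2, 64])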
comfy/ldm/cascade/stage_c.py
...
@@ -42,7 +42,7 @@ class StageC(nn.Module):
     def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
                  blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
                  c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
-                 dropout=[0.1, 0.1], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
+                 dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
                  dtype=None, device=None, operations=None):
         super().__init__()
         self.dtype = dtype
...
@@ -100,7 +100,7 @@ class StageC(nn.Module):
             if block_repeat is not None:
                 block_repeat_mappers = nn.ModuleList()
                 for _ in range(block_repeat[0][i] - 1):
-                    block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
+                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                 self.down_repeat_mappers.append(block_repeat_mappers)

         # -- up blocks
...
@@ -126,12 +126,12 @@ class StageC(nn.Module):
             if block_repeat is not None:
                 block_repeat_mappers = nn.ModuleList()
                 for _ in range(block_repeat[1][::-1][i] - 1):
-                    block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
+                    block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
                 self.up_repeat_mappers.append(block_repeat_mappers)

         # OUTPUT
         self.clf = nn.Sequential(
-            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6),
+            LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
             operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
             nn.PixelShuffle(patch_size),
         )
...
comfy/model_base.py

 import torch
 from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
 from comfy.ldm.cascade.stage_c import StageC
+from comfy.ldm.cascade.stage_b import StageB
 from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
 from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
 import comfy.model_management
...
@@ -461,3 +462,27 @@ class StableCascade_C(BaseModel):
         out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
         return out

+class StableCascade_B(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=StageB)
+        self.diffusion_model.eval().requires_grad_(False)
+
+    def extra_conds(self, **kwargs):
+        out = {}
+        noise = kwargs.get("noise", None)
+
+        clip_text_pooled = kwargs["pooled_output"]
+        if clip_text_pooled is not None:
+            out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
+
+        #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
+        prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
+
+        out["effnet"] = comfy.conds.CONDRegular(prior)
+        out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
+
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['clip'] = comfy.conds.CONDCrossAttn(cross_attn)
+        return out
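A note on the placeholder-prior shape (an illustration, not part of the commit): the Stage B noise latent is 1/4 of the image resolution, while the Stage C "effnet" prior is roughly 1/42 of it, so the zero prior above is sized as (noise_h * 4) // 42 by (noise_w * 4) // 42. A hedged sketch of that arithmetic:

import torch

def placeholder_prior(noise):
    # Mirrors the zero-prior construction in StableCascade_B.extra_conds above.
    h = (noise.shape[2] * 4) // 42
    w = (noise.shape[3] * 4) // 42
    return torch.zeros((1, 16, h, w), dtype=noise.dtype, layout=noise.layout, device=noise.device)

# Example: a 1024x1024 image gives a (1, 4, 256, 256) Stage B noise tensor,
# and the placeholder prior comes out as (1, 16, 24, 24).
noise = torch.zeros((1, 4, 256, 256))
print(placeholder_prior(noise).shape)  # torch.Size([1, 16, 24, 24])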
comfy/ops.py

+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Stability AI
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program.  If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import torch
 import comfy.model_management
...
@@ -78,7 +96,11 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
+            if self.weight is not None:
+                weight, bias = cast_bias_weight(self, input)
+            else:
+                weight = None
+                bias = None
             return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

         def forward(self, *args, **kwargs):
...
@@ -87,6 +109,28 @@ class disable_weight_init:
             else:
                 return super().forward(*args, **kwargs)

+    class ConvTranspose2d(torch.nn.ConvTranspose2d):
+        comfy_cast_weights = False
+        def reset_parameters(self):
+            return None
+
+        def forward_comfy_cast_weights(self, input, output_size=None):
+            num_spatial_dims = 2
+            output_padding = self._output_padding(
+                input, output_size, self.stride, self.padding, self.kernel_size,
+                num_spatial_dims, self.dilation)
+
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.conv_transpose2d(
+                input, weight, bias, self.stride, self.padding,
+                output_padding, self.groups, self.dilation)
+
+        def forward(self, *args, **kwargs):
+            if self.comfy_cast_weights:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
+
     @classmethod
     def conv_nd(s, dims, *args, **kwargs):
         if dims == 2:
...
@@ -112,3 +156,6 @@ class manual_cast(disable_weight_init):
     class LayerNorm(disable_weight_init.LayerNorm):
         comfy_cast_weights = True
+
+    class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
+        comfy_cast_weights = True
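For context (an illustration, not part of the commit): the new ConvTranspose2d and the LayerNorm change follow the same cast-at-forward pattern, casting weights to the input's dtype and device at call time, and only when an affine weight actually exists. A minimal standalone sketch of that pattern; cast_to_input here is a hypothetical stand-in for ComfyUI's cast_bias_weight helper, which lives elsewhere in comfy/ops.py:

import torch

def cast_to_input(weight, bias, x):
    # Hypothetical helper: cast parameters to the input's dtype/device.
    w = weight.to(dtype=x.dtype, device=x.device) if weight is not None else None
    b = bias.to(dtype=x.dtype, device=x.device) if bias is not None else None
    return w, b

class CastLayerNorm(torch.nn.LayerNorm):
    def forward(self, x):
        # As in the diff: only cast when an affine weight exists
        # (elementwise_affine=False layers have weight/bias set to None).
        if self.weight is not None:
            weight, bias = cast_to_input(self.weight, self.bias, x)
        else:
            weight, bias = None, None
        return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)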
comfy/supported_models.py
...
@@ -338,6 +338,20 @@ class Stable_Cascade_C(supported_models_base.BASE):
     def clip_target(self):
         return None

+class Stable_Cascade_B(Stable_Cascade_C):
+    unet_config = {
+        "stable_cascade_stage": 'b',
+    }
+
+    unet_extra_config = {}
+
+    latent_format = latent_formats.SC_B
+    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.StableCascade_B(self, device=device)
+        return out
+
-models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C]
+models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]

 models += [SVD_img2vid]
comfy/utils.py
...
@@ -169,6 +169,8 @@ UNET_MAP_BASIC = {
 }

 def unet_to_diffusers(unet_config):
+    if "num_res_blocks" not in unet_config:
+        return {}
     num_res_blocks = unet_config["num_res_blocks"]
     channel_mult = unet_config["channel_mult"]
     transformer_depth = unet_config["transformer_depth"][:]
...
comfy_extras/nodes_stable_cascade.py  0 → 100644 (new file)

"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""

import torch
import nodes


class StableCascade_EmptyLatentImage:
    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
            "height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
            "compression": ("INT", {"default": 42, "min": 32, "max": 64, "step": 1}),
            "batch_size": ("INT", {"default": 1, "min": 1, "max": 64})
        }}
    RETURN_TYPES = ("LATENT", "LATENT")
    RETURN_NAMES = ("stage_c", "stage_b")
    FUNCTION = "generate"

    CATEGORY = "_for_testing/stable_cascade"

    def generate(self, width, height, compression, batch_size=1):
        c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
        b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
        return ({"samples": c_latent, }, {"samples": b_latent, })


class StableCascade_StageB_Conditioning:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "conditioning": ("CONDITIONING",),
            "stage_c": ("LATENT",),
        }}
    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "set_prior"

    CATEGORY = "_for_testing/stable_cascade"

    def set_prior(self, conditioning, stage_c):
        c = []
        for t in conditioning:
            d = t[1].copy()
            d['stable_cascade_prior'] = stage_c['samples']
            n = [t[0], d]
            c.append(n)
        return (c, )


NODE_CLASS_MAPPINGS = {
    "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
    "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
}
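As a quick worked example (not part of the commit): StableCascade_EmptyLatentImage.generate() returns a 16-channel Stage C latent at width // compression by height // compression and a 4-channel Stage B latent at width // 4 by height // 4. With the default compression of 42 and a 1024x1024 request:

# Latent shapes produced for a 1024x1024 image with the node's defaults.
width, height, compression, batch_size = 1024, 1024, 42, 1
c_shape = (batch_size, 16, height // compression, width // compression)  # (1, 16, 24, 24)
b_shape = (batch_size, 4, height // 4, width // 4)                       # (1, 4, 256, 256)
print(c_shape, b_shape)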
nodes.py
...
@@ -1967,6 +1967,7 @@ def init_custom_nodes():
         "nodes_sdupscale.py",
         "nodes_photomaker.py",
         "nodes_cond.py",
+        "nodes_stable_cascade.py",
     ]

     for node_file in extras_files:
...