chenpangpang / ComfyUI · Commits

Commit 6ec3f12c, authored Oct 27, 2023 by comfyanonymous
Parent: 434ce25e

    Support SSD1B model and make it easier to support asymmetric unets.
Showing 6 changed files with 154 additions and 97 deletions (+154 -97):

    comfy/cldm/cldm.py                                  +26 -21
    comfy/ldm/modules/diffusionmodules/openaimodel.py   +18 -24
    comfy/model_detection.py                            +86 -26
    comfy/sd.py                                          +1  -1
    comfy/supported_models.py                           +13  -3
    comfy/utils.py                                      +10 -22
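The core of the change: transformer_depth used to hold one depth per resolution level, with attention_resolutions deciding which levels get attention at all. After this commit the config carries one depth per block (plus a separate transformer_depth_output list for the decoder), so every block can have a different depth and an asymmetric unet like SSD1B becomes expressible. A minimal sketch of the two conventions, using SDXL's numbers from the comfy/supported_models.py diff below (the dicts themselves are illustrative, not part of the diff):

    # Old style: one depth per resolution level, gated by attention_resolutions.
    old_style = {
        "channel_mult": [1, 2, 4],
        "num_res_blocks": 2,
        "transformer_depth": [0, 2, 10],   # per level
        "attention_resolutions": [2, 4],   # attention only at 2x and 4x downsampling
    }

    # New style: one depth per block, so each block can differ -- which is
    # what an asymmetric unet like SSD1B needs.
    new_style = {
        "channel_mult": [1, 2, 4],
        "num_res_blocks": [2, 2, 2],
        "transformer_depth": [0, 0, 2, 2, 10, 10],                   # input blocks
        "transformer_depth_output": [0, 0, 0, 2, 2, 2, 10, 10, 10],  # output blocks
    }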
--- a/comfy/cldm/cldm.py
+++ b/comfy/cldm/cldm.py
@@ -27,7 +27,6 @@ class ControlNet(nn.Module):
         model_channels,
         hint_channels,
         num_res_blocks,
-        attention_resolutions,
         dropout=0,
         channel_mult=(1, 2, 4, 8),
         conv_resample=True,
@@ -52,6 +51,7 @@ class ControlNet(nn.Module):
         use_linear_in_transformer=False,
         adm_in_channels=None,
         transformer_depth_middle=None,
+        transformer_depth_output=None,
         device=None,
         operations=comfy.ops,
     ):
@@ -79,10 +79,7 @@ class ControlNet(nn.Module):
         self.image_size = image_size
         self.in_channels = in_channels
         self.model_channels = model_channels
-        if isinstance(transformer_depth, int):
-            transformer_depth = len(channel_mult) * [transformer_depth]
-        if transformer_depth_middle is None:
-            transformer_depth_middle = transformer_depth[-1]
         if isinstance(num_res_blocks, int):
             self.num_res_blocks = len(channel_mult) * [num_res_blocks]
         else:
@@ -90,18 +87,16 @@ class ControlNet(nn.Module):
                 raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                  "as a list/tuple (per-level) with the same length as channel_mult")
             self.num_res_blocks = num_res_blocks
         if disable_self_attentions is not None:
             # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
             assert len(disable_self_attentions) == len(channel_mult)
         if num_attention_blocks is not None:
             assert len(num_attention_blocks) == len(self.num_res_blocks)
             assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
-            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
-                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
-                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
-                  f"attention will still not be set.")

-        self.attention_resolutions = attention_resolutions
+        transformer_depth = transformer_depth[:]

         self.dropout = dropout
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
@@ -180,11 +175,14 @@ class ControlNet(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
-                        operations=operations
+                        dtype=self.dtype,
+                        device=device,
+                        operations=operations,
                     )
                 ]
                 ch = mult * model_channels
-                if ds in attention_resolutions:
+                num_transformers = transformer_depth.pop(0)
+                if num_transformers > 0:
                     if num_head_channels == -1:
                         dim_head = ch // num_heads
                     else:
@@ -201,9 +199,9 @@ class ControlNet(nn.Module):
                     if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                         layers.append(
                             SpatialTransformer(
-                                ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
+                                ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
                                 disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                                use_checkpoint=use_checkpoint, operations=operations
+                                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                             )
                         )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -223,11 +221,13 @@ class ControlNet(nn.Module):
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
                         down=True,
-                        operations=operations
+                        dtype=self.dtype,
+                        device=device,
+                        operations=operations
                     )
                     if resblock_updown
                     else Downsample(
-                        ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations
+                        ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
                     )
                 )
             )
@@ -245,7 +245,7 @@ class ControlNet(nn.Module):
         if legacy:
             #num_heads = 1
             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
-        self.middle_block = TimestepEmbedSequential(
+        mid_block = [
             ResBlock(
                 ch,
                 time_embed_dim,
@@ -253,12 +253,15 @@ class ControlNet(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
-                operations=operations
-            ),
-            SpatialTransformer(  # always uses a self-attn
+                dtype=self.dtype,
+                device=device,
+                operations=operations
+            )]
+        if transformer_depth_middle >= 0:
+            mid_block += [SpatialTransformer(  # always uses a self-attn
                 ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                 disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                use_checkpoint=use_checkpoint, operations=operations
+                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
             ),
             ResBlock(
                 ch,
@@ -267,9 +270,11 @@ class ControlNet(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
-                operations=operations
-            ),
-        )
+                dtype=self.dtype,
+                device=device,
+                operations=operations
+            )]
+        self.middle_block = TimestepEmbedSequential(*mid_block)
         self.middle_block_out = self.make_zero_conv(ch, operations=operations)
         self._feature_size += ch
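One structural change worth noting in this file: the middle block is now assembled as a plain Python list (mid_block) before being wrapped in TimestepEmbedSequential, so the SpatialTransformer can be omitted entirely when transformer_depth_middle is negative. A toy sketch of that gating, with stand-in strings instead of real modules (not part of the diff):

    def build_middle(transformer_depth_middle):
        mid_block = ["ResBlock"]
        if transformer_depth_middle >= 0:   # a negative value now means "no middle attention"
            mid_block += ["SpatialTransformer", "ResBlock"]
        return mid_block

    assert build_middle(10) == ["ResBlock", "SpatialTransformer", "ResBlock"]
    assert build_middle(-1) == ["ResBlock"]   # e.g. a unet detected without a middle transformer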
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -259,10 +259,6 @@ class UNetModel(nn.Module):
     :param model_channels: base channel count for the model.
     :param out_channels: channels in the output Tensor.
     :param num_res_blocks: number of residual blocks per downsample.
-    :param attention_resolutions: a collection of downsample rates at which
-        attention will take place. May be a set, list, or tuple.
-        For example, if this contains 4, then at 4x downsampling, attention
-        will be used.
     :param dropout: the dropout probability.
     :param channel_mult: channel multiplier for each level of the UNet.
     :param conv_resample: if True, use learned convolutions for upsampling and
@@ -289,7 +285,6 @@ class UNetModel(nn.Module):
         model_channels,
         out_channels,
         num_res_blocks,
-        attention_resolutions,
         dropout=0,
         channel_mult=(1, 2, 4, 8),
         conv_resample=True,
@@ -314,6 +309,7 @@ class UNetModel(nn.Module):
         use_linear_in_transformer=False,
         adm_in_channels=None,
         transformer_depth_middle=None,
+        transformer_depth_output=None,
         device=None,
         operations=comfy.ops,
     ):
@@ -341,10 +337,7 @@ class UNetModel(nn.Module):
         self.in_channels = in_channels
         self.model_channels = model_channels
         self.out_channels = out_channels
-        if isinstance(transformer_depth, int):
-            transformer_depth = len(channel_mult) * [transformer_depth]
-        if transformer_depth_middle is None:
-            transformer_depth_middle = transformer_depth[-1]
         if isinstance(num_res_blocks, int):
             self.num_res_blocks = len(channel_mult) * [num_res_blocks]
         else:
@@ -352,18 +345,16 @@ class UNetModel(nn.Module):
                 raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                  "as a list/tuple (per-level) with the same length as channel_mult")
             self.num_res_blocks = num_res_blocks
         if disable_self_attentions is not None:
             # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
             assert len(disable_self_attentions) == len(channel_mult)
         if num_attention_blocks is not None:
             assert len(num_attention_blocks) == len(self.num_res_blocks)
-            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
-            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
-                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
-                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
-                  f"attention will still not be set.")

-        self.attention_resolutions = attention_resolutions
+        transformer_depth = transformer_depth[:]
+        transformer_depth_output = transformer_depth_output[:]

         self.dropout = dropout
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
@@ -428,7 +419,8 @@ class UNetModel(nn.Module):
                     )
                 ]
                 ch = mult * model_channels
-                if ds in attention_resolutions:
+                num_transformers = transformer_depth.pop(0)
+                if num_transformers > 0:
                     if num_head_channels == -1:
                         dim_head = ch // num_heads
                     else:
@@ -444,7 +436,7 @@ class UNetModel(nn.Module):
                     if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                         layers.append(SpatialTransformer(
-                                ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
+                                ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
                                 disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                 use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                             )
                         )
@@ -488,7 +480,7 @@ class UNetModel(nn.Module):
         if legacy:
             #num_heads = 1
             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
-        self.middle_block = TimestepEmbedSequential(
+        mid_block = [
             ResBlock(
                 ch,
                 time_embed_dim,
@@ -499,8 +491,9 @@ class UNetModel(nn.Module):
                 dtype=self.dtype,
                 device=device,
                 operations=operations
-            ),
-            SpatialTransformer(  # always uses a self-attn
+            )]
+        if transformer_depth_middle >= 0:
+            mid_block += [SpatialTransformer(  # always uses a self-attn
                 ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                 disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                 use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
@@ -515,8 +508,8 @@ class UNetModel(nn.Module):
                 dtype=self.dtype,
                 device=device,
                 operations=operations
-            ),
-        )
+            )]
+        self.middle_block = TimestepEmbedSequential(*mid_block)
         self._feature_size += ch
         self.output_blocks = nn.ModuleList([])
@@ -538,7 +531,8 @@ class UNetModel(nn.Module):
                     )
                 ]
                 ch = model_channels * mult
-                if ds in attention_resolutions:
+                num_transformers = transformer_depth_output.pop()
+                if num_transformers > 0:
                     if num_head_channels == -1:
                         dim_head = ch // num_heads
                     else:
@@ -555,7 +549,7 @@ class UNetModel(nn.Module):
                     if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                         layers.append(
                             SpatialTransformer(
-                                ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
+                                ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
                                 disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                 use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                             )
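The encoder and decoder now consume separate per-block lists from opposite ends: input blocks are built shallow-to-deep and take depths from the front with pop(0), while output blocks are built deepest-first and take depths from the back with pop(). A simplified consumption sketch (values are SDXL's; the loop bounds are illustrative, not the real construction code):

    transformer_depth = [0, 0, 2, 2, 10, 10]
    transformer_depth_output = [0, 0, 0, 2, 2, 2, 10, 10, 10]

    for level in range(3):                          # encoder: shallow -> deep
        for block in range(2):
            depth = transformer_depth.pop(0)        # consume from the front
            # build input block; add attention only if depth > 0

    for level in reversed(range(3)):                # decoder is built deepest-first
        for block in range(3):
            depth = transformer_depth_output.pop()  # consume from the back
            # build output block; add attention only if depth > 0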
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -14,6 +14,19 @@ def count_blocks(state_dict_keys, prefix_string):
         count += 1
     return count

+def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
+    context_dim = None
+    use_linear_in_transformer = False
+
+    transformer_prefix = prefix + "1.transformer_blocks."
+    transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
+    if len(transformer_keys) > 0:
+        last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
+        context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
+        use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
+        return last_transformer_depth, context_dim, use_linear_in_transformer
+    return None
+
 def detect_unet_config(state_dict, key_prefix, dtype):
     state_dict_keys = list(state_dict.keys())
@@ -40,6 +53,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
     channel_mult = []
     attention_resolutions = []
     transformer_depth = []
+    transformer_depth_output = []
     context_dim = None
     use_linear_in_transformer = False
@@ -48,60 +62,67 @@ def detect_unet_config(state_dict, key_prefix, dtype):
     count = 0

     last_res_blocks = 0
-    last_transformer_depth = 0
     last_channel_mult = 0

-    while True:
+    input_block_count = count_blocks(state_dict_keys, '{}input_blocks'.format(key_prefix) + '.{}.')
+    for count in range(input_block_count):
         prefix = '{}input_blocks.{}.'.format(key_prefix, count)
+        prefix_output = '{}output_blocks.{}.'.format(key_prefix, input_block_count - count - 1)
+
         block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys)))
         if len(block_keys) == 0:
             break

+        block_keys_output = sorted(list(filter(lambda a: a.startswith(prefix_output), state_dict_keys)))
+
         if "{}0.op.weight".format(prefix) in block_keys: #new layer
-            if last_transformer_depth > 0:
-                attention_resolutions.append(current_res)
-            transformer_depth.append(last_transformer_depth)
             num_res_blocks.append(last_res_blocks)
             channel_mult.append(last_channel_mult)

             current_res *= 2
             last_res_blocks = 0
-            last_transformer_depth = 0
             last_channel_mult = 0
+            out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
+            if out is not None:
+                transformer_depth_output.append(out[0])
+            else:
+                transformer_depth_output.append(0)
         else:
             res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
             if res_block_prefix in block_keys:
                 last_res_blocks += 1
                 last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels

-            transformer_prefix = prefix + "1.transformer_blocks."
-            transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
-            if len(transformer_keys) > 0:
-                last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
-                if context_dim is None:
-                    context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
-                    use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
-
-        count += 1
+                out = calculate_transformer_depth(prefix, state_dict_keys, state_dict)
+                if out is not None:
+                    transformer_depth.append(out[0])
+                    if context_dim is None:
+                        context_dim = out[1]
+                        use_linear_in_transformer = out[2]
+                else:
+                    transformer_depth.append(0)
+
+            res_block_prefix = "{}0.in_layers.0.weight".format(prefix_output)
+            if res_block_prefix in block_keys_output:
+                out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
+                if out is not None:
+                    transformer_depth_output.append(out[0])
+                else:
+                    transformer_depth_output.append(0)

-    if last_transformer_depth > 0:
-        attention_resolutions.append(current_res)
-    transformer_depth.append(last_transformer_depth)
     num_res_blocks.append(last_res_blocks)
     channel_mult.append(last_channel_mult)
-    transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
-
-    if len(set(num_res_blocks)) == 1:
-        num_res_blocks = num_res_blocks[0]
-
-    if len(set(transformer_depth)) == 1:
-        transformer_depth = transformer_depth[0]
+    if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
+        transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
+    else:
+        transformer_depth_middle = -1

     unet_config["in_channels"] = in_channels
     unet_config["model_channels"] = model_channels
     unet_config["num_res_blocks"] = num_res_blocks
-    unet_config["attention_resolutions"] = attention_resolutions
     unet_config["transformer_depth"] = transformer_depth
+    unet_config["transformer_depth_output"] = transformer_depth_output
     unet_config["channel_mult"] = channel_mult
     unet_config["transformer_depth_middle"] = transformer_depth_middle
     unet_config['use_linear_in_transformer'] = use_linear_in_transformer
@@ -124,6 +145,45 @@ def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_ma
     else:
         return model_config

+def convert_config(unet_config):
+    new_config = unet_config.copy()
+    num_res_blocks = new_config.get("num_res_blocks", None)
+    channel_mult = new_config.get("channel_mult", None)
+
+    if isinstance(num_res_blocks, int):
+        num_res_blocks = len(channel_mult) * [num_res_blocks]
+
+    if "attention_resolutions" in new_config:
+        attention_resolutions = new_config.pop("attention_resolutions")
+        transformer_depth = new_config.get("transformer_depth", None)
+        transformer_depth_middle = new_config.get("transformer_depth_middle", None)
+
+        if isinstance(transformer_depth, int):
+            transformer_depth = len(channel_mult) * [transformer_depth]
+        if transformer_depth_middle is None:
+            transformer_depth_middle = transformer_depth[-1]
+        t_in = []
+        t_out = []
+        s = 1
+        for i in range(len(num_res_blocks)):
+            res = num_res_blocks[i]
+            d = 0
+            if s in attention_resolutions:
+                d = transformer_depth[i]
+
+            t_in += [d] * res
+            t_out += [d] * (res + 1)
+            s *= 2
+        transformer_depth = t_in
+        transformer_depth_output = t_out
+        new_config["transformer_depth"] = t_in
+        new_config["transformer_depth_output"] = t_out
+        new_config["transformer_depth_middle"] = transformer_depth_middle
+
+    new_config["num_res_blocks"] = num_res_blocks
+    return new_config
+
+
 def unet_config_from_diffusers_unet(state_dict, dtype):
     match = {}
     attention_resolutions = []
@@ -200,7 +260,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
             matches = False
             break
     if matches:
-        return unet_config
+        return convert_config(unet_config)
     return None

 def model_config_from_diffusers_unet(state_dict, dtype):
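convert_config() is the bridge from the old convention: it expands a per-level transformer_depth plus attention_resolutions into the new per-block lists. A worked example with SDXL-style inputs (the asserts are mine, not from the diff):

    from comfy import model_detection

    cfg = {
        "num_res_blocks": 2,
        "channel_mult": [1, 2, 4],
        "transformer_depth": [0, 2, 10],
        "attention_resolutions": [2, 4],
    }
    new = model_detection.convert_config(cfg)

    # s walks 1 -> 2 -> 4; d is nonzero only where s is in attention_resolutions
    assert new["num_res_blocks"] == [2, 2, 2]
    assert new["transformer_depth"] == [0, 0, 2, 2, 10, 10]                   # res blocks per level
    assert new["transformer_depth_output"] == [0, 0, 0, 2, 2, 2, 10, 10, 10]  # res + 1 per level
    assert new["transformer_depth_middle"] == 10        # defaults to transformer_depth[-1]
    assert "attention_resolutions" not in new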
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -360,7 +360,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
         from . import latent_formats
         model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
-        model_config.unet_config = unet_config
+        model_config.unet_config = model_detection.convert_config(unet_config)
     if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
         model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -104,7 +104,7 @@ class SDXLRefiner(supported_models_base.BASE):
         "use_linear_in_transformer": True,
         "context_dim": 1280,
         "adm_in_channels": 2560,
-        "transformer_depth": [0, 4, 4, 0],
+        "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
     }

     latent_format = latent_formats.SDXL
@@ -139,7 +139,7 @@ class SDXL(supported_models_base.BASE):
     unet_config = {
         "model_channels": 320,
         "use_linear_in_transformer": True,
-        "transformer_depth": [0, 2, 10],
+        "transformer_depth": [0, 0, 2, 2, 10, 10],
         "context_dim": 2048,
         "adm_in_channels": 2816
     }
@@ -165,6 +165,7 @@ class SDXL(supported_models_base.BASE):
         replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
         state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
         keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
+        keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
         keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
         state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
@@ -189,5 +190,14 @@ class SDXL(supported_models_base.BASE):
     def clip_target(self):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)

+class SSD1B(SDXL):
+    unet_config = {
+        "model_channels": 320,
+        "use_linear_in_transformer": True,
+        "transformer_depth": [0, 0, 2, 2, 4, 4],
+        "context_dim": 2048,
+        "adm_in_channels": 2816
+    }
+
-models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL]
+models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
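With per-block depths in place, SSD1B can be declared as an SDXL subclass that differs only in transformer_depth ([0, 0, 2, 2, 4, 4] instead of SDXL's [0, 0, 2, 2, 10, 10]); every other matched field is identical. A sketch of why that one field suffices for detection (matches() is a hypothetical stand-in for the field-equality check in model_detection, not the real code):

    sdxl_fields  = {"model_channels": 320, "transformer_depth": [0, 0, 2, 2, 10, 10]}
    ssd1b_fields = {"model_channels": 320, "transformer_depth": [0, 0, 2, 2, 4, 4]}

    def matches(unet_config, detected):
        # every field listed in the class's unet_config must equal the detected value
        return all(detected.get(k) == v for k, v in unet_config.items())

Under the old per-level convention both models would have looked the same at this level of description, so this distinction is exactly what "easier to support asymmetric unets" buys.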
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -170,25 +170,12 @@ UNET_MAP_BASIC = {
 def unet_to_diffusers(unet_config):
     num_res_blocks = unet_config["num_res_blocks"]
-    attention_resolutions = unet_config["attention_resolutions"]
     channel_mult = unet_config["channel_mult"]
-    transformer_depth = unet_config["transformer_depth"]
+    transformer_depth = unet_config["transformer_depth"][:]
+    transformer_depth_output = unet_config["transformer_depth_output"][:]
     num_blocks = len(channel_mult)
-    if isinstance(num_res_blocks, int):
-        num_res_blocks = [num_res_blocks] * num_blocks
-    if isinstance(transformer_depth, int):
-        transformer_depth = [transformer_depth] * num_blocks
-
-    transformers_per_layer = []
-    res = 1
-    for i in range(num_blocks):
-        transformers = 0
-        if res in attention_resolutions:
-            transformers = transformer_depth[i]
-        transformers_per_layer.append(transformers)
-        res *= 2
-
-    transformers_mid = unet_config.get("transformer_depth_middle", transformer_depth[-1])
+
+    transformers_mid = unet_config.get("transformer_depth_middle", None)

     diffusers_unet_map = {}
     for x in range(num_blocks):
@@ -196,10 +183,11 @@ def unet_to_diffusers(unet_config):
         for i in range(num_res_blocks[x]):
             for b in UNET_MAP_RESNET:
                 diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
-            if transformers_per_layer[x] > 0:
+            num_transformers = transformer_depth.pop(0)
+            if num_transformers > 0:
                 for b in UNET_MAP_ATTENTIONS:
                     diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
-                for t in range(transformers_per_layer[x]):
+                for t in range(num_transformers):
                     for b in TRANSFORMER_BLOCKS:
                         diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
             n += 1
@@ -218,7 +206,6 @@ def unet_to_diffusers(unet_config):
             diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)

     num_res_blocks = list(reversed(num_res_blocks))
-    transformers_per_layer = list(reversed(transformers_per_layer))
     for x in range(num_blocks):
         n = (num_res_blocks[x] + 1) * x
         l = num_res_blocks[x] + 1
@@ -227,11 +214,12 @@ def unet_to_diffusers(unet_config):
             for b in UNET_MAP_RESNET:
                 diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
             c += 1
-            if transformers_per_layer[x] > 0:
+            num_transformers = transformer_depth_output.pop()
+            if num_transformers > 0:
                 c += 1
                 for b in UNET_MAP_ATTENTIONS:
                     diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
-                for t in range(transformers_per_layer[x]):
+                for t in range(num_transformers):
                     for b in TRANSFORMER_BLOCKS:
                         diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
             if i == l - 1:
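unet_to_diffusers() now pops the same per-block lists while building the diffusers key map, so attention entries are emitted only for blocks whose popped depth is positive. A hypothetical invocation (the transformer_depth_output values here are illustrative, not necessarily SSD1B's real detected ones):

    from comfy import utils

    unet_config = {
        "num_res_blocks": [2, 2, 2],
        "channel_mult": [1, 2, 4],
        "transformer_depth": [0, 0, 2, 2, 4, 4],
        "transformer_depth_output": [0, 0, 0, 2, 2, 2, 4, 4, 4],
        "transformer_depth_middle": 4,
    }
    key_map = utils.unet_to_diffusers(unet_config)
    # attention entries exist only where a popped depth was > 0, e.g.
    # "down_blocks.1.attentions.0.*" keys are present, "down_blocks.0.attentions.*" are not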