chenpangpang / ComfyUI

Commit 6ec3f12c, authored Oct 27, 2023 by comfyanonymous
Parent: 434ce25e

Support SSD1B model and make it easier to support asymmetric unets.

Showing 6 changed files with 154 additions and 97 deletions:
comfy/cldm/cldm.py                                  +26 -21
comfy/ldm/modules/diffusionmodules/openaimodel.py   +18 -24
comfy/model_detection.py                            +86 -26
comfy/sd.py                                          +1  -1
comfy/supported_models.py                           +13  -3
comfy/utils.py                                      +10 -22
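At a glance, the mechanism of the change: instead of describing attention with a per-level `transformer_depth` plus an `attention_resolutions` set, the UNet config now carries one transformer depth per block, with separate lists for the down path (`transformer_depth`) and the up path (`transformer_depth_output`), plus an explicit `transformer_depth_middle` (-1 meaning no middle attention). A sketch of the two shapes using SDXL values from the diffs below; the `attention_resolutions` and `num_res_blocks` entries in the "before" dict are illustrative assumptions, not shown in this commit:

```python
# Before: symmetric, per-level description (SDXL-like; attention_resolutions
# and num_res_blocks here are illustrative assumptions, not from this diff).
old_config = {
    "channel_mult": [1, 2, 4],
    "num_res_blocks": 2,               # same count at every level
    "transformer_depth": [0, 2, 10],   # one depth per level
    "attention_resolutions": [2, 4],   # downsample rates that get attention
}

# After: per-block description; down and up paths may now differ.
new_config = {
    "channel_mult": [1, 2, 4],
    "num_res_blocks": [2, 2, 2],
    "transformer_depth": [0, 0, 2, 2, 10, 10],                    # one per down block
    "transformer_depth_output": [0, 0, 0, 2, 2, 2, 10, 10, 10],   # one per up block
    "transformer_depth_middle": 10,    # -1 would mean no middle attention
}
```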
comfy/cldm/cldm.py

```diff
@@ -27,7 +27,6 @@ class ControlNet(nn.Module):
         model_channels,
         hint_channels,
         num_res_blocks,
-        attention_resolutions,
         dropout=0,
         channel_mult=(1, 2, 4, 8),
         conv_resample=True,
@@ -52,6 +51,7 @@ class ControlNet(nn.Module):
         use_linear_in_transformer=False,
         adm_in_channels=None,
         transformer_depth_middle=None,
+        transformer_depth_output=None,
         device=None,
         operations=comfy.ops,
     ):
@@ -79,10 +79,7 @@ class ControlNet(nn.Module):
         self.image_size = image_size
         self.in_channels = in_channels
         self.model_channels = model_channels
-        if isinstance(transformer_depth, int):
-            transformer_depth = len(channel_mult) * [transformer_depth]
-        if transformer_depth_middle is None:
-            transformer_depth_middle = transformer_depth[-1]
+
         if isinstance(num_res_blocks, int):
             self.num_res_blocks = len(channel_mult) * [num_res_blocks]
         else:
@@ -90,18 +87,16 @@ class ControlNet(nn.Module):
                 raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                  "as a list/tuple (per-level) with the same length as channel_mult")
             self.num_res_blocks = num_res_blocks
         if disable_self_attentions is not None:
             # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
             assert len(disable_self_attentions) == len(channel_mult)
         if num_attention_blocks is not None:
             assert len(num_attention_blocks) == len(self.num_res_blocks)
             assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
-            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
-                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
-                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
-                  f"attention will still not be set.")
 
-        self.attention_resolutions = attention_resolutions
+        transformer_depth = transformer_depth[:]
+
         self.dropout = dropout
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
@@ -180,11 +175,14 @@ class ControlNet(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
-                        operations=operations
+                        dtype=self.dtype,
+                        device=device,
+                        operations=operations,
                     )
                 ]
                 ch = mult * model_channels
-                if ds in attention_resolutions:
+                num_transformers = transformer_depth.pop(0)
+                if num_transformers > 0:
                     if num_head_channels == -1:
                         dim_head = ch // num_heads
                     else:
@@ -201,9 +199,9 @@ class ControlNet(nn.Module):
                     if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                         layers.append(
                             SpatialTransformer(
-                                ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
+                                ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
                                 disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
-                                use_checkpoint=use_checkpoint, operations=operations
+                                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                             )
                         )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -223,11 +221,13 @@ class ControlNet(nn.Module):
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             down=True,
+                            dtype=self.dtype,
+                            device=device,
                             operations=operations
                         )
                         if resblock_updown
                         else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations
+                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
                         )
                     )
                 )
@@ -245,7 +245,7 @@ class ControlNet(nn.Module):
         if legacy:
             #num_heads = 1
             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
-        self.middle_block = TimestepEmbedSequential(
+        mid_block = [
             ResBlock(
                 ch,
                 time_embed_dim,
@@ -253,12 +253,15 @@ class ControlNet(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
                 dtype=self.dtype,
                 device=device,
                 operations=operations
-            ),
-            SpatialTransformer(  # always uses a self-attn
+            )]
+        if transformer_depth_middle >= 0:
+            mid_block += [SpatialTransformer(  # always uses a self-attn
                             ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                             disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
-                            use_checkpoint=use_checkpoint, operations=operations
+                            use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
             ),
             ResBlock(
                 ch,
@@ -267,9 +270,11 @@ class ControlNet(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
                 dtype=self.dtype,
                 device=device,
                 operations=operations
-            ),
-        )
+            )]
+        self.middle_block = TimestepEmbedSequential(*mid_block)
         self.middle_block_out = self.make_zero_conv(ch, operations=operations)
         self._feature_size += ch
```
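The `mid_block` rework above is what lets a model drop middle attention entirely. A minimal sketch of the new construction, with hypothetical `make_resblock`/`make_transformer` helpers standing in for the real `ResBlock(...)`/`SpatialTransformer(...)` calls; only the branching mirrors the diff:

```python
def build_mid_block(transformer_depth_middle, make_resblock, make_transformer):
    # hypothetical factories stand in for ResBlock(...) and SpatialTransformer(...)
    mid_block = [make_resblock()]
    if transformer_depth_middle >= 0:
        # ResBlock -> SpatialTransformer -> ResBlock, as before
        mid_block += [make_transformer(depth=transformer_depth_middle),
                      make_resblock()]
    # with transformer_depth_middle == -1 the middle block is a single ResBlock
    return mid_block

print(build_mid_block(10, lambda: "res", lambda depth: "attn"))  # ['res', 'attn', 'res']
print(build_mid_block(-1, lambda: "res", lambda depth: "attn"))  # ['res']
```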
comfy/ldm/modules/diffusionmodules/openaimodel.py

```diff
@@ -259,10 +259,6 @@ class UNetModel(nn.Module):
     :param model_channels: base channel count for the model.
     :param out_channels: channels in the output Tensor.
     :param num_res_blocks: number of residual blocks per downsample.
-    :param attention_resolutions: a collection of downsample rates at which
-        attention will take place. May be a set, list, or tuple.
-        For example, if this contains 4, then at 4x downsampling, attention
-        will be used.
     :param dropout: the dropout probability.
     :param channel_mult: channel multiplier for each level of the UNet.
     :param conv_resample: if True, use learned convolutions for upsampling and
@@ -289,7 +285,6 @@ class UNetModel(nn.Module):
         model_channels,
         out_channels,
         num_res_blocks,
-        attention_resolutions,
         dropout=0,
         channel_mult=(1, 2, 4, 8),
         conv_resample=True,
@@ -314,6 +309,7 @@ class UNetModel(nn.Module):
         use_linear_in_transformer=False,
         adm_in_channels=None,
         transformer_depth_middle=None,
+        transformer_depth_output=None,
         device=None,
         operations=comfy.ops,
     ):
@@ -341,10 +337,7 @@ class UNetModel(nn.Module):
         self.in_channels = in_channels
         self.model_channels = model_channels
         self.out_channels = out_channels
-        if isinstance(transformer_depth, int):
-            transformer_depth = len(channel_mult) * [transformer_depth]
-        if transformer_depth_middle is None:
-            transformer_depth_middle = transformer_depth[-1]
+
         if isinstance(num_res_blocks, int):
             self.num_res_blocks = len(channel_mult) * [num_res_blocks]
         else:
@@ -352,18 +345,16 @@ class UNetModel(nn.Module):
                 raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                  "as a list/tuple (per-level) with the same length as channel_mult")
             self.num_res_blocks = num_res_blocks
         if disable_self_attentions is not None:
             # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
             assert len(disable_self_attentions) == len(channel_mult)
         if num_attention_blocks is not None:
             assert len(num_attention_blocks) == len(self.num_res_blocks)
             assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
-            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
-                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
-                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
-                  f"attention will still not be set.")
 
-        self.attention_resolutions = attention_resolutions
+        transformer_depth = transformer_depth[:]
+        transformer_depth_output = transformer_depth_output[:]
+
         self.dropout = dropout
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
@@ -428,7 +419,8 @@ class UNetModel(nn.Module):
                     )
                 ]
                 ch = mult * model_channels
-                if ds in attention_resolutions:
+                num_transformers = transformer_depth.pop(0)
+                if num_transformers > 0:
                     if num_head_channels == -1:
                         dim_head = ch // num_heads
                     else:
@@ -444,7 +436,7 @@ class UNetModel(nn.Module):
                 if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                     layers.append(
                         SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
+                            ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                             use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                         )
@@ -488,7 +480,7 @@ class UNetModel(nn.Module):
         if legacy:
             #num_heads = 1
             dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
-        self.middle_block = TimestepEmbedSequential(
+        mid_block = [
             ResBlock(
                 ch,
                 time_embed_dim,
@@ -499,8 +491,9 @@ class UNetModel(nn.Module):
                 dtype=self.dtype,
                 device=device,
                 operations=operations
-            ),
-            SpatialTransformer(  # always uses a self-attn
+            )]
+        if transformer_depth_middle >= 0:
+            mid_block += [SpatialTransformer(  # always uses a self-attn
                             ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                             disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                             use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
@@ -515,8 +508,8 @@ class UNetModel(nn.Module):
                 dtype=self.dtype,
                 device=device,
                 operations=operations
-            ),
-        )
+            )]
+        self.middle_block = TimestepEmbedSequential(*mid_block)
         self._feature_size += ch
 
         self.output_blocks = nn.ModuleList([])
@@ -538,7 +531,8 @@ class UNetModel(nn.Module):
                     )
                 ]
                 ch = model_channels * mult
-                if ds in attention_resolutions:
+                num_transformers = transformer_depth_output.pop()
+                if num_transformers > 0:
                     if num_head_channels == -1:
                         dim_head = ch // num_heads
                     else:
@@ -555,7 +549,7 @@ class UNetModel(nn.Module):
                 if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                     layers.append(
                         SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
+                            ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                             use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                         )
```
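The pattern repeated through these hunks: each block pops its own depth off the per-block list, and a popped depth of 0 replaces the old `ds in attention_resolutions` test. A small self-contained illustration (block labels are just strings; the real code builds `ResBlock`/`SpatialTransformer` modules):

```python
# Per-block gating sketch, using the new SDXL down-path depths.
transformer_depth = [0, 0, 2, 2, 10, 10]  # one entry per down-path ResBlock
blocks = []
for level in range(3):           # len(channel_mult) levels
    for nr in range(2):          # num_res_blocks per level
        num_transformers = transformer_depth.pop(0)
        layers = ["ResBlock"]
        if num_transformers > 0:  # replaces "if ds in attention_resolutions"
            layers.append(f"SpatialTransformer(depth={num_transformers})")
        blocks.append(layers)

print(blocks[0])  # ['ResBlock']
print(blocks[2])  # ['ResBlock', 'SpatialTransformer(depth=2)']
```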
comfy/model_detection.py

```diff
@@ -14,6 +14,19 @@ def count_blocks(state_dict_keys, prefix_string):
             count += 1
     return count
 
+def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
+    context_dim = None
+    use_linear_in_transformer = False
+
+    transformer_prefix = prefix + "1.transformer_blocks."
+    transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
+    if len(transformer_keys) > 0:
+        last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
+        context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
+        use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
+        return last_transformer_depth, context_dim, use_linear_in_transformer
+    return None
+
 def detect_unet_config(state_dict, key_prefix, dtype):
     state_dict_keys = list(state_dict.keys())
@@ -40,6 +53,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
     channel_mult = []
     attention_resolutions = []
     transformer_depth = []
+    transformer_depth_output = []
     context_dim = None
     use_linear_in_transformer = False
@@ -48,60 +62,67 @@ def detect_unet_config(state_dict, key_prefix, dtype):
     count = 0
 
     last_res_blocks = 0
-    last_transformer_depth = 0
     last_channel_mult = 0
 
-    while True:
+    input_block_count = count_blocks(state_dict_keys, '{}input_blocks'.format(key_prefix) + '.{}.')
+    for count in range(input_block_count):
         prefix = '{}input_blocks.{}.'.format(key_prefix, count)
+        prefix_output = '{}output_blocks.{}.'.format(key_prefix, input_block_count - count - 1)
+
         block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys)))
         if len(block_keys) == 0:
             break
+
+        block_keys_output = sorted(list(filter(lambda a: a.startswith(prefix_output), state_dict_keys)))
+
         if "{}0.op.weight".format(prefix) in block_keys: #new layer
-            if last_transformer_depth > 0:
-                attention_resolutions.append(current_res)
-                transformer_depth.append(last_transformer_depth)
             num_res_blocks.append(last_res_blocks)
             channel_mult.append(last_channel_mult)
 
             current_res *= 2
             last_res_blocks = 0
-            last_transformer_depth = 0
             last_channel_mult = 0
+            out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
+            if out is not None:
+                transformer_depth_output.append(out[0])
+            else:
+                transformer_depth_output.append(0)
         else:
             res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
             if res_block_prefix in block_keys:
                 last_res_blocks += 1
                 last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels
 
-            transformer_prefix = prefix + "1.transformer_blocks."
-            transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
-            if len(transformer_keys) > 0:
-                last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
-                if context_dim is None:
-                    context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
-                    use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
+                out = calculate_transformer_depth(prefix, state_dict_keys, state_dict)
+                if out is not None:
+                    transformer_depth.append(out[0])
+                    if context_dim is None:
+                        context_dim = out[1]
+                        use_linear_in_transformer = out[2]
+                else:
+                    transformer_depth.append(0)
 
-        count += 1
+            res_block_prefix = "{}0.in_layers.0.weight".format(prefix_output)
+            if res_block_prefix in block_keys_output:
+                out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
+                if out is not None:
+                    transformer_depth_output.append(out[0])
+                else:
+                    transformer_depth_output.append(0)
 
-    if last_transformer_depth > 0:
-        attention_resolutions.append(current_res)
-        transformer_depth.append(last_transformer_depth)
     num_res_blocks.append(last_res_blocks)
     channel_mult.append(last_channel_mult)
-    transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
-
-    if len(set(num_res_blocks)) == 1:
-        num_res_blocks = num_res_blocks[0]
-
-    if len(set(transformer_depth)) == 1:
-        transformer_depth = transformer_depth[0]
+    if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
+        transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
+    else:
+        transformer_depth_middle = -1
 
     unet_config["in_channels"] = in_channels
     unet_config["model_channels"] = model_channels
     unet_config["num_res_blocks"] = num_res_blocks
-    unet_config["attention_resolutions"] = attention_resolutions
     unet_config["transformer_depth"] = transformer_depth
+    unet_config["transformer_depth_output"] = transformer_depth_output
     unet_config["channel_mult"] = channel_mult
     unet_config["transformer_depth_middle"] = transformer_depth_middle
     unet_config['use_linear_in_transformer'] = use_linear_in_transformer
@@ -124,6 +145,45 @@ def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_ma
     else:
         return model_config
 
+def convert_config(unet_config):
+    new_config = unet_config.copy()
+    num_res_blocks = new_config.get("num_res_blocks", None)
+    channel_mult = new_config.get("channel_mult", None)
+
+    if isinstance(num_res_blocks, int):
+        num_res_blocks = len(channel_mult) * [num_res_blocks]
+
+    if "attention_resolutions" in new_config:
+        attention_resolutions = new_config.pop("attention_resolutions")
+        transformer_depth = new_config.get("transformer_depth", None)
+        transformer_depth_middle = new_config.get("transformer_depth_middle", None)
+
+        if isinstance(transformer_depth, int):
+            transformer_depth = len(channel_mult) * [transformer_depth]
+        if transformer_depth_middle is None:
+            transformer_depth_middle = transformer_depth[-1]
+        t_in = []
+        t_out = []
+        s = 1
+        for i in range(len(num_res_blocks)):
+            res = num_res_blocks[i]
+            d = 0
+            if s in attention_resolutions:
+                d = transformer_depth[i]
+
+            t_in += [d] * res
+            t_out += [d] * (res + 1)
+            s *= 2
+        transformer_depth = t_in
+        transformer_depth_output = t_out
+        new_config["transformer_depth"] = t_in
+        new_config["transformer_depth_output"] = t_out
+        new_config["transformer_depth_middle"] = transformer_depth_middle
+
+    new_config["num_res_blocks"] = num_res_blocks
+    return new_config
+
+
 def unet_config_from_diffusers_unet(state_dict, dtype):
     match = {}
     attention_resolutions = []
@@ -200,7 +260,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
             matches = False
             break
     if matches:
-        return unet_config
+        return convert_config(unet_config)
     return None
 
 def model_config_from_diffusers_unet(state_dict, dtype):
```
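`convert_config` is the bridge for old-style configs. A standalone sketch of its expansion loop, run on SDXL-like values; the `attention_resolutions=[2, 4]` input is an assumption for illustration, only the `[0, 2, 10]` and `[0, 0, 2, 2, 10, 10]` lists appear in this commit:

```python
def expand(num_res_blocks, channel_mult, transformer_depth, attention_resolutions):
    # same loop as convert_config() above, extracted for illustration
    if isinstance(num_res_blocks, int):
        num_res_blocks = len(channel_mult) * [num_res_blocks]
    if isinstance(transformer_depth, int):
        transformer_depth = len(channel_mult) * [transformer_depth]
    t_in, t_out, s = [], [], 1
    for i, res in enumerate(num_res_blocks):
        d = transformer_depth[i] if s in attention_resolutions else 0
        t_in += [d] * res          # down path: one entry per ResBlock
        t_out += [d] * (res + 1)   # up path: one extra block per level
        s *= 2
    return t_in, t_out

t_in, t_out = expand(2, [1, 2, 4], [0, 2, 10], [2, 4])
print(t_in)   # [0, 0, 2, 2, 10, 10]           -> "transformer_depth"
print(t_out)  # [0, 0, 0, 2, 2, 2, 10, 10, 10] -> "transformer_depth_output"
```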
comfy/sd.py

```diff
@@ -360,7 +360,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
         from . import latent_formats
         model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
-        model_config.unet_config = unet_config
+        model_config.unet_config = model_detection.convert_config(unet_config)
 
     if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
         model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
```
comfy/supported_models.py

```diff
@@ -104,7 +104,7 @@ class SDXLRefiner(supported_models_base.BASE):
         "use_linear_in_transformer": True,
         "context_dim": 1280,
         "adm_in_channels": 2560,
-        "transformer_depth": [0, 4, 4, 0],
+        "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
     }
 
     latent_format = latent_formats.SDXL
@@ -139,7 +139,7 @@ class SDXL(supported_models_base.BASE):
     unet_config = {
         "model_channels": 320,
         "use_linear_in_transformer": True,
-        "transformer_depth": [0, 2, 10],
+        "transformer_depth": [0, 0, 2, 2, 10, 10],
         "context_dim": 2048,
         "adm_in_channels": 2816
     }
@@ -165,6 +165,7 @@ class SDXL(supported_models_base.BASE):
         replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
         state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
         keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
+        keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
         keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
 
         state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
@@ -189,5 +190,14 @@ class SDXL(supported_models_base.BASE):
     def clip_target(self):
         return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
 
+class SSD1B(SDXL):
+    unet_config = {
+        "model_channels": 320,
+        "use_linear_in_transformer": True,
+        "transformer_depth": [0, 0, 2, 2, 4, 4],
+        "context_dim": 2048,
+        "adm_in_channels": 2816
+    }
+
-models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL]
+models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]
```
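For context on how these partial `unet_config` dicts select a model class: detection compares every declared key against the detected config, so `SSD1B` wins only when the detected per-block depths are exactly `[0, 0, 2, 2, 4, 4]`. A sketch of that comparison; the real logic lives in `supported_models_base.BASE.matches`, and this body is an assumption:

```python
def matches(declared, detected):
    # every key a model class declares must equal the detected value;
    # keys the class does not declare are ignored
    return all(detected.get(k) == v for k, v in declared.items())

ssd1b = {"model_channels": 320, "use_linear_in_transformer": True,
         "transformer_depth": [0, 0, 2, 2, 4, 4],
         "context_dim": 2048, "adm_in_channels": 2816}

detected = dict(ssd1b, num_res_blocks=[2, 2, 2])      # extra detected keys are fine
print(matches(ssd1b, detected))                       # True
detected["transformer_depth"] = [0, 0, 2, 2, 10, 10]  # SDXL depths instead
print(matches(ssd1b, detected))                       # False
```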
comfy/utils.py

```diff
@@ -170,25 +170,12 @@ UNET_MAP_BASIC = {
 
 def unet_to_diffusers(unet_config):
     num_res_blocks = unet_config["num_res_blocks"]
-    attention_resolutions = unet_config["attention_resolutions"]
     channel_mult = unet_config["channel_mult"]
-    transformer_depth = unet_config["transformer_depth"]
+    transformer_depth = unet_config["transformer_depth"][:]
+    transformer_depth_output = unet_config["transformer_depth_output"][:]
     num_blocks = len(channel_mult)
-    if isinstance(num_res_blocks, int):
-        num_res_blocks = [num_res_blocks] * num_blocks
-    if isinstance(transformer_depth, int):
-        transformer_depth = [transformer_depth] * num_blocks
-
-    transformers_per_layer = []
-    res = 1
-    for i in range(num_blocks):
-        transformers = 0
-        if res in attention_resolutions:
-            transformers = transformer_depth[i]
-        transformers_per_layer.append(transformers)
-        res *= 2
 
-    transformers_mid = unet_config.get("transformer_depth_middle", transformer_depth[-1])
+    transformers_mid = unet_config.get("transformer_depth_middle", None)
 
     diffusers_unet_map = {}
     for x in range(num_blocks):
@@ -196,10 +183,11 @@ def unet_to_diffusers(unet_config):
         for i in range(num_res_blocks[x]):
            for b in UNET_MAP_RESNET:
                 diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
-            if transformers_per_layer[x] > 0:
+            num_transformers = transformer_depth.pop(0)
+            if num_transformers > 0:
                 for b in UNET_MAP_ATTENTIONS:
                     diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
-                for t in range(transformers_per_layer[x]):
+                for t in range(num_transformers):
                     for b in TRANSFORMER_BLOCKS:
                         diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
             n += 1
@@ -218,7 +206,6 @@ def unet_to_diffusers(unet_config):
             diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
 
     num_res_blocks = list(reversed(num_res_blocks))
-    transformers_per_layer = list(reversed(transformers_per_layer))
     for x in range(num_blocks):
         n = (num_res_blocks[x] + 1) * x
         l = num_res_blocks[x] + 1
@@ -227,11 +214,12 @@ def unet_to_diffusers(unet_config):
             for b in UNET_MAP_RESNET:
                 diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
             c += 1
-            if transformers_per_layer[x] > 0:
+            num_transformers = transformer_depth_output.pop()
+            if num_transformers > 0:
                 c += 1
                 for b in UNET_MAP_ATTENTIONS:
                     diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
-                for t in range(transformers_per_layer[x]):
+                for t in range(num_transformers):
                     for b in TRANSFORMER_BLOCKS:
                         diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
             if i == l - 1:
```
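One subtlety in the rewritten `unet_to_diffusers`: the down path consumes `transformer_depth` with `pop(0)` (front to back), while the up path, which walks levels deepest-first, consumes `transformer_depth_output` with `pop()` (back to front). A minimal sketch with the SDXL lists derived above:

```python
transformer_depth = [0, 0, 2, 2, 10, 10]
transformer_depth_output = [0, 0, 0, 2, 2, 2, 10, 10, 10]
num_res_blocks = [2, 2, 2]

for x, res in enumerate(num_res_blocks):            # down blocks, front to back
    for i in range(res):
        print("down", x, i, transformer_depth.pop(0))

for x, res in enumerate(reversed(num_res_blocks)):  # up blocks, deepest level first
    for i in range(res + 1):
        print("up", x, i, transformer_depth_output.pop())
# up-path pops yield 10, 10, 10 then 2, 2, 2 then 0, 0, 0
```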