Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
ComfyUI
Commits
6ec3f12c
"docs/vscode:/vscode.git/clone" did not exist on "39d24d9ddb60ec6f43ac9255753a9ea315bd0349"
Commit
6ec3f12c
authored
Oct 27, 2023
by
comfyanonymous
Browse files
Support SSD1B model and make it easier to support asymmetric unets.
parent
434ce25e
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
154 additions
and
97 deletions
+154
-97
comfy/cldm/cldm.py
comfy/cldm/cldm.py
+26
-21
comfy/ldm/modules/diffusionmodules/openaimodel.py
comfy/ldm/modules/diffusionmodules/openaimodel.py
+18
-24
comfy/model_detection.py
comfy/model_detection.py
+86
-26
comfy/sd.py
comfy/sd.py
+1
-1
comfy/supported_models.py
comfy/supported_models.py
+13
-3
comfy/utils.py
comfy/utils.py
+10
-22
No files found.
comfy/cldm/cldm.py
View file @
6ec3f12c
...
...
@@ -27,7 +27,6 @@ class ControlNet(nn.Module):
model_channels
,
hint_channels
,
num_res_blocks
,
attention_resolutions
,
dropout
=
0
,
channel_mult
=
(
1
,
2
,
4
,
8
),
conv_resample
=
True
,
...
...
@@ -52,6 +51,7 @@ class ControlNet(nn.Module):
use_linear_in_transformer
=
False
,
adm_in_channels
=
None
,
transformer_depth_middle
=
None
,
transformer_depth_output
=
None
,
device
=
None
,
operations
=
comfy
.
ops
,
):
...
...
@@ -79,10 +79,7 @@ class ControlNet(nn.Module):
self
.
image_size
=
image_size
self
.
in_channels
=
in_channels
self
.
model_channels
=
model_channels
if
isinstance
(
transformer_depth
,
int
):
transformer_depth
=
len
(
channel_mult
)
*
[
transformer_depth
]
if
transformer_depth_middle
is
None
:
transformer_depth_middle
=
transformer_depth
[
-
1
]
if
isinstance
(
num_res_blocks
,
int
):
self
.
num_res_blocks
=
len
(
channel_mult
)
*
[
num_res_blocks
]
else
:
...
...
@@ -90,18 +87,16 @@ class ControlNet(nn.Module):
raise
ValueError
(
"provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult"
)
self
.
num_res_blocks
=
num_res_blocks
if
disable_self_attentions
is
not
None
:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert
len
(
disable_self_attentions
)
==
len
(
channel_mult
)
if
num_attention_blocks
is
not
None
:
assert
len
(
num_attention_blocks
)
==
len
(
self
.
num_res_blocks
)
assert
all
(
map
(
lambda
i
:
self
.
num_res_blocks
[
i
]
>=
num_attention_blocks
[
i
],
range
(
len
(
num_attention_blocks
))))
print
(
f
"Constructor of UNetModel received num_attention_blocks=
{
num_attention_blocks
}
. "
f
"This option has LESS priority than attention_resolutions
{
attention_resolutions
}
, "
f
"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f
"attention will still not be set."
)
self
.
attention_resolutions
=
attention_resolutions
transformer_depth
=
transformer_depth
[:]
self
.
dropout
=
dropout
self
.
channel_mult
=
channel_mult
self
.
conv_resample
=
conv_resample
...
...
@@ -180,11 +175,14 @@ class ControlNet(nn.Module):
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
operations
=
operations
dtype
=
self
.
dtype
,
device
=
device
,
operations
=
operations
,
)
]
ch
=
mult
*
model_channels
if
ds
in
attention_resolutions
:
num_transformers
=
transformer_depth
.
pop
(
0
)
if
num_transformers
>
0
:
if
num_head_channels
==
-
1
:
dim_head
=
ch
//
num_heads
else
:
...
...
@@ -201,9 +199,9 @@ class ControlNet(nn.Module):
if
not
exists
(
num_attention_blocks
)
or
nr
<
num_attention_blocks
[
level
]:
layers
.
append
(
SpatialTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer
_depth
[
level
]
,
context_dim
=
context_dim
,
ch
,
num_heads
,
dim_head
,
depth
=
num_
transformer
s
,
context_dim
=
context_dim
,
disable_self_attn
=
disabled_sa
,
use_linear
=
use_linear_in_transformer
,
use_checkpoint
=
use_checkpoint
,
operations
=
operations
use_checkpoint
=
use_checkpoint
,
dtype
=
self
.
dtype
,
device
=
device
,
operations
=
operations
)
)
self
.
input_blocks
.
append
(
TimestepEmbedSequential
(
*
layers
))
...
...
@@ -223,11 +221,13 @@ class ControlNet(nn.Module):
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
down
=
True
,
dtype
=
self
.
dtype
,
device
=
device
,
operations
=
operations
)
if
resblock_updown
else
Downsample
(
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
,
operations
=
operations
ch
,
conv_resample
,
dims
=
dims
,
out_channels
=
out_ch
,
dtype
=
self
.
dtype
,
device
=
device
,
operations
=
operations
)
)
)
...
...
@@ -245,7 +245,7 @@ class ControlNet(nn.Module):
if
legacy
:
#num_heads = 1
dim_head
=
ch
//
num_heads
if
use_spatial_transformer
else
num_head_channels
self
.
middle_block
=
TimestepEmbedSequential
(
mid_block
=
[
ResBlock
(
ch
,
time_embed_dim
,
...
...
@@ -253,12 +253,15 @@ class ControlNet(nn.Module):
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
dtype
=
self
.
dtype
,
device
=
device
,
operations
=
operations
),
SpatialTransformer
(
# always uses a self-attn
)]
if
transformer_depth_middle
>=
0
:
mid_block
+=
[
SpatialTransformer
(
# always uses a self-attn
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth_middle
,
context_dim
=
context_dim
,
disable_self_attn
=
disable_middle_self_attn
,
use_linear
=
use_linear_in_transformer
,
use_checkpoint
=
use_checkpoint
,
operations
=
operations
use_checkpoint
=
use_checkpoint
,
dtype
=
self
.
dtype
,
device
=
device
,
operations
=
operations
),
ResBlock
(
ch
,
...
...
@@ -267,9 +270,11 @@ class ControlNet(nn.Module):
dims
=
dims
,
use_checkpoint
=
use_checkpoint
,
use_scale_shift_norm
=
use_scale_shift_norm
,
dtype
=
self
.
dtype
,
device
=
device
,
operations
=
operations
)
,
)
)
]
self
.
middle_block
=
TimestepEmbedSequential
(
*
mid_block
)
self
.
middle_block_out
=
self
.
make_zero_conv
(
ch
,
operations
=
operations
)
self
.
_feature_size
+=
ch
...
...
comfy/ldm/modules/diffusionmodules/openaimodel.py
View file @
6ec3f12c
...
...
@@ -259,10 +259,6 @@ class UNetModel(nn.Module):
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
...
...
@@ -289,7 +285,6 @@ class UNetModel(nn.Module):
model_channels
,
out_channels
,
num_res_blocks
,
attention_resolutions
,
dropout
=
0
,
channel_mult
=
(
1
,
2
,
4
,
8
),
conv_resample
=
True
,
...
...
@@ -314,6 +309,7 @@ class UNetModel(nn.Module):
use_linear_in_transformer
=
False
,
adm_in_channels
=
None
,
transformer_depth_middle
=
None
,
transformer_depth_output
=
None
,
device
=
None
,
operations
=
comfy
.
ops
,
):
...
...
@@ -341,10 +337,7 @@ class UNetModel(nn.Module):
self
.
in_channels
=
in_channels
self
.
model_channels
=
model_channels
self
.
out_channels
=
out_channels
if
isinstance
(
transformer_depth
,
int
):
transformer_depth
=
len
(
channel_mult
)
*
[
transformer_depth
]
if
transformer_depth_middle
is
None
:
transformer_depth_middle
=
transformer_depth
[
-
1
]
if
isinstance
(
num_res_blocks
,
int
):
self
.
num_res_blocks
=
len
(
channel_mult
)
*
[
num_res_blocks
]
else
:
...
...
@@ -352,18 +345,16 @@ class UNetModel(nn.Module):
raise
ValueError
(
"provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult"
)
self
.
num_res_blocks
=
num_res_blocks
if
disable_self_attentions
is
not
None
:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert
len
(
disable_self_attentions
)
==
len
(
channel_mult
)
if
num_attention_blocks
is
not
None
:
assert
len
(
num_attention_blocks
)
==
len
(
self
.
num_res_blocks
)
assert
all
(
map
(
lambda
i
:
self
.
num_res_blocks
[
i
]
>=
num_attention_blocks
[
i
],
range
(
len
(
num_attention_blocks
))))
print
(
f
"Constructor of UNetModel received num_attention_blocks=
{
num_attention_blocks
}
. "
f
"This option has LESS priority than attention_resolutions
{
attention_resolutions
}
, "
f
"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f
"attention will still not be set."
)
self
.
attention_resolutions
=
attention_resolutions
transformer_depth
=
transformer_depth
[:]
transformer_depth_output
=
transformer_depth_output
[:]
self
.
dropout
=
dropout
self
.
channel_mult
=
channel_mult
self
.
conv_resample
=
conv_resample
...
...
@@ -428,7 +419,8 @@ class UNetModel(nn.Module):
)
]
ch
=
mult
*
model_channels
if
ds
in
attention_resolutions
:
num_transformers
=
transformer_depth
.
pop
(
0
)
if
num_transformers
>
0
:
if
num_head_channels
==
-
1
:
dim_head
=
ch
//
num_heads
else
:
...
...
@@ -444,7 +436,7 @@ class UNetModel(nn.Module):
if
not
exists
(
num_attention_blocks
)
or
nr
<
num_attention_blocks
[
level
]:
layers
.
append
(
SpatialTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer
_depth
[
level
]
,
context_dim
=
context_dim
,
ch
,
num_heads
,
dim_head
,
depth
=
num_
transformer
s
,
context_dim
=
context_dim
,
disable_self_attn
=
disabled_sa
,
use_linear
=
use_linear_in_transformer
,
use_checkpoint
=
use_checkpoint
,
dtype
=
self
.
dtype
,
device
=
device
,
operations
=
operations
)
...
...
@@ -488,7 +480,7 @@ class UNetModel(nn.Module):
if
legacy
:
#num_heads = 1
dim_head
=
ch
//
num_heads
if
use_spatial_transformer
else
num_head_channels
self
.
middle_block
=
TimestepEmbedSequential
(
mid_block
=
[
ResBlock
(
ch
,
time_embed_dim
,
...
...
@@ -499,8 +491,9 @@ class UNetModel(nn.Module):
dtype
=
self
.
dtype
,
device
=
device
,
operations
=
operations
),
SpatialTransformer
(
# always uses a self-attn
)]
if
transformer_depth_middle
>=
0
:
mid_block
+=
[
SpatialTransformer
(
# always uses a self-attn
ch
,
num_heads
,
dim_head
,
depth
=
transformer_depth_middle
,
context_dim
=
context_dim
,
disable_self_attn
=
disable_middle_self_attn
,
use_linear
=
use_linear_in_transformer
,
use_checkpoint
=
use_checkpoint
,
dtype
=
self
.
dtype
,
device
=
device
,
operations
=
operations
...
...
@@ -515,8 +508,8 @@ class UNetModel(nn.Module):
dtype
=
self
.
dtype
,
device
=
device
,
operations
=
operations
)
,
)
)
]
self
.
middle_block
=
TimestepEmbedSequential
(
*
mid_block
)
self
.
_feature_size
+=
ch
self
.
output_blocks
=
nn
.
ModuleList
([])
...
...
@@ -538,7 +531,8 @@ class UNetModel(nn.Module):
)
]
ch
=
model_channels
*
mult
if
ds
in
attention_resolutions
:
num_transformers
=
transformer_depth_output
.
pop
()
if
num_transformers
>
0
:
if
num_head_channels
==
-
1
:
dim_head
=
ch
//
num_heads
else
:
...
...
@@ -555,7 +549,7 @@ class UNetModel(nn.Module):
if
not
exists
(
num_attention_blocks
)
or
i
<
num_attention_blocks
[
level
]:
layers
.
append
(
SpatialTransformer
(
ch
,
num_heads
,
dim_head
,
depth
=
transformer
_depth
[
level
]
,
context_dim
=
context_dim
,
ch
,
num_heads
,
dim_head
,
depth
=
num_
transformer
s
,
context_dim
=
context_dim
,
disable_self_attn
=
disabled_sa
,
use_linear
=
use_linear_in_transformer
,
use_checkpoint
=
use_checkpoint
,
dtype
=
self
.
dtype
,
device
=
device
,
operations
=
operations
)
...
...
comfy/model_detection.py
View file @
6ec3f12c
...
...
@@ -14,6 +14,19 @@ def count_blocks(state_dict_keys, prefix_string):
count
+=
1
return
count
def
calculate_transformer_depth
(
prefix
,
state_dict_keys
,
state_dict
):
context_dim
=
None
use_linear_in_transformer
=
False
transformer_prefix
=
prefix
+
"1.transformer_blocks."
transformer_keys
=
sorted
(
list
(
filter
(
lambda
a
:
a
.
startswith
(
transformer_prefix
),
state_dict_keys
)))
if
len
(
transformer_keys
)
>
0
:
last_transformer_depth
=
count_blocks
(
state_dict_keys
,
transformer_prefix
+
'{}'
)
context_dim
=
state_dict
[
'{}0.attn2.to_k.weight'
.
format
(
transformer_prefix
)].
shape
[
1
]
use_linear_in_transformer
=
len
(
state_dict
[
'{}1.proj_in.weight'
.
format
(
prefix
)].
shape
)
==
2
return
last_transformer_depth
,
context_dim
,
use_linear_in_transformer
return
None
def
detect_unet_config
(
state_dict
,
key_prefix
,
dtype
):
state_dict_keys
=
list
(
state_dict
.
keys
())
...
...
@@ -40,6 +53,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
channel_mult
=
[]
attention_resolutions
=
[]
transformer_depth
=
[]
transformer_depth_output
=
[]
context_dim
=
None
use_linear_in_transformer
=
False
...
...
@@ -48,60 +62,67 @@ def detect_unet_config(state_dict, key_prefix, dtype):
count
=
0
last_res_blocks
=
0
last_transformer_depth
=
0
last_channel_mult
=
0
while
True
:
input_block_count
=
count_blocks
(
state_dict_keys
,
'{}input_blocks'
.
format
(
key_prefix
)
+
'.{}.'
)
for
count
in
range
(
input_block_count
):
prefix
=
'{}input_blocks.{}.'
.
format
(
key_prefix
,
count
)
prefix_output
=
'{}output_blocks.{}.'
.
format
(
key_prefix
,
input_block_count
-
count
-
1
)
block_keys
=
sorted
(
list
(
filter
(
lambda
a
:
a
.
startswith
(
prefix
),
state_dict_keys
)))
if
len
(
block_keys
)
==
0
:
break
block_keys_output
=
sorted
(
list
(
filter
(
lambda
a
:
a
.
startswith
(
prefix_output
),
state_dict_keys
)))
if
"{}0.op.weight"
.
format
(
prefix
)
in
block_keys
:
#new layer
if
last_transformer_depth
>
0
:
attention_resolutions
.
append
(
current_res
)
transformer_depth
.
append
(
last_transformer_depth
)
num_res_blocks
.
append
(
last_res_blocks
)
channel_mult
.
append
(
last_channel_mult
)
current_res
*=
2
last_res_blocks
=
0
last_transformer_depth
=
0
last_channel_mult
=
0
out
=
calculate_transformer_depth
(
prefix_output
,
state_dict_keys
,
state_dict
)
if
out
is
not
None
:
transformer_depth_output
.
append
(
out
[
0
])
else
:
transformer_depth_output
.
append
(
0
)
else
:
res_block_prefix
=
"{}0.in_layers.0.weight"
.
format
(
prefix
)
if
res_block_prefix
in
block_keys
:
last_res_blocks
+=
1
last_channel_mult
=
state_dict
[
"{}0.out_layers.3.weight"
.
format
(
prefix
)].
shape
[
0
]
//
model_channels
transformer_prefix
=
prefix
+
"1.transformer_blocks."
transformer_keys
=
sorted
(
list
(
filter
(
lambda
a
:
a
.
startswith
(
transformer_prefix
),
state_dict_keys
)))
if
len
(
transformer_keys
)
>
0
:
last_transformer_depth
=
count_blocks
(
state_dict_keys
,
transformer_prefix
+
'{}'
)
if
context_dim
is
None
:
context_dim
=
state_dict
[
'{}0.attn2.to_k.weight'
.
format
(
transformer_prefix
)].
shape
[
1
]
use_linear_in_transformer
=
len
(
state_dict
[
'{}1.proj_in.weight'
.
format
(
prefix
)].
shape
)
==
2
out
=
calculate_transformer_depth
(
prefix
,
state_dict_keys
,
state_dict
)
if
out
is
not
None
:
transformer_depth
.
append
(
out
[
0
])
if
context_dim
is
None
:
context_dim
=
out
[
1
]
use_linear_in_transformer
=
out
[
2
]
else
:
transformer_depth
.
append
(
0
)
res_block_prefix
=
"{}0.in_layers.0.weight"
.
format
(
prefix_output
)
if
res_block_prefix
in
block_keys_output
:
out
=
calculate_transformer_depth
(
prefix_output
,
state_dict_keys
,
state_dict
)
if
out
is
not
None
:
transformer_depth_output
.
append
(
out
[
0
])
else
:
transformer_depth_output
.
append
(
0
)
count
+=
1
if
last_transformer_depth
>
0
:
attention_resolutions
.
append
(
current_res
)
transformer_depth
.
append
(
last_transformer_depth
)
num_res_blocks
.
append
(
last_res_blocks
)
channel_mult
.
append
(
last_channel_mult
)
transformer_depth_middle
=
count_blocks
(
state_dict_keys
,
'{}middle_block.1.transformer_blocks.'
.
format
(
key_prefix
)
+
'{}'
)
if
len
(
set
(
num_res_blocks
))
==
1
:
num_res_blocks
=
num_res_blocks
[
0
]
if
len
(
set
(
transformer_depth
))
==
1
:
transformer_depth
=
transformer_depth
[
0
]
if
"{}middle_block.1.proj_in.weight"
.
format
(
key_prefix
)
in
state_dict_keys
:
transformer_depth_middle
=
count_blocks
(
state_dict_keys
,
'{}middle_block.1.transformer_blocks.'
.
format
(
key_prefix
)
+
'{}'
)
else
:
transformer_depth_middle
=
-
1
unet_config
[
"in_channels"
]
=
in_channels
unet_config
[
"model_channels"
]
=
model_channels
unet_config
[
"num_res_blocks"
]
=
num_res_blocks
unet_config
[
"attention_resolutions"
]
=
attention_resolutions
unet_config
[
"transformer_depth"
]
=
transformer_depth
unet_config
[
"transformer_depth_output"
]
=
transformer_depth_output
unet_config
[
"channel_mult"
]
=
channel_mult
unet_config
[
"transformer_depth_middle"
]
=
transformer_depth_middle
unet_config
[
'use_linear_in_transformer'
]
=
use_linear_in_transformer
...
...
@@ -124,6 +145,45 @@ def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_ma
else
:
return
model_config
def
convert_config
(
unet_config
):
new_config
=
unet_config
.
copy
()
num_res_blocks
=
new_config
.
get
(
"num_res_blocks"
,
None
)
channel_mult
=
new_config
.
get
(
"channel_mult"
,
None
)
if
isinstance
(
num_res_blocks
,
int
):
num_res_blocks
=
len
(
channel_mult
)
*
[
num_res_blocks
]
if
"attention_resolutions"
in
new_config
:
attention_resolutions
=
new_config
.
pop
(
"attention_resolutions"
)
transformer_depth
=
new_config
.
get
(
"transformer_depth"
,
None
)
transformer_depth_middle
=
new_config
.
get
(
"transformer_depth_middle"
,
None
)
if
isinstance
(
transformer_depth
,
int
):
transformer_depth
=
len
(
channel_mult
)
*
[
transformer_depth
]
if
transformer_depth_middle
is
None
:
transformer_depth_middle
=
transformer_depth
[
-
1
]
t_in
=
[]
t_out
=
[]
s
=
1
for
i
in
range
(
len
(
num_res_blocks
)):
res
=
num_res_blocks
[
i
]
d
=
0
if
s
in
attention_resolutions
:
d
=
transformer_depth
[
i
]
t_in
+=
[
d
]
*
res
t_out
+=
[
d
]
*
(
res
+
1
)
s
*=
2
transformer_depth
=
t_in
transformer_depth_output
=
t_out
new_config
[
"transformer_depth"
]
=
t_in
new_config
[
"transformer_depth_output"
]
=
t_out
new_config
[
"transformer_depth_middle"
]
=
transformer_depth_middle
new_config
[
"num_res_blocks"
]
=
num_res_blocks
return
new_config
def
unet_config_from_diffusers_unet
(
state_dict
,
dtype
):
match
=
{}
attention_resolutions
=
[]
...
...
@@ -200,7 +260,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
matches
=
False
break
if
matches
:
return
unet_config
return
convert_config
(
unet_config
)
return
None
def
model_config_from_diffusers_unet
(
state_dict
,
dtype
):
...
...
comfy/sd.py
View file @
6ec3f12c
...
...
@@ -360,7 +360,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
from
.
import
latent_formats
model_config
.
latent_format
=
latent_formats
.
SD15
(
scale_factor
=
scale_factor
)
model_config
.
unet_config
=
unet_config
model_config
.
unet_config
=
model_detection
.
convert_config
(
unet_config
)
if
config
[
'model'
][
"target"
].
endswith
(
"ImageEmbeddingConditionedLatentDiffusion"
):
model
=
model_base
.
SD21UNCLIP
(
model_config
,
noise_aug_config
[
"params"
],
model_type
=
model_type
)
...
...
comfy/supported_models.py
View file @
6ec3f12c
...
...
@@ -104,7 +104,7 @@ class SDXLRefiner(supported_models_base.BASE):
"use_linear_in_transformer"
:
True
,
"context_dim"
:
1280
,
"adm_in_channels"
:
2560
,
"transformer_depth"
:
[
0
,
4
,
4
,
0
],
"transformer_depth"
:
[
0
,
0
,
4
,
4
,
4
,
4
,
0
,
0
],
}
latent_format
=
latent_formats
.
SDXL
...
...
@@ -139,7 +139,7 @@ class SDXL(supported_models_base.BASE):
unet_config
=
{
"model_channels"
:
320
,
"use_linear_in_transformer"
:
True
,
"transformer_depth"
:
[
0
,
2
,
10
],
"transformer_depth"
:
[
0
,
0
,
2
,
2
,
10
,
10
],
"context_dim"
:
2048
,
"adm_in_channels"
:
2816
}
...
...
@@ -165,6 +165,7 @@ class SDXL(supported_models_base.BASE):
replace_prefix
[
"conditioner.embedders.0.transformer.text_model"
]
=
"cond_stage_model.clip_l.transformer.text_model"
state_dict
=
utils
.
transformers_convert
(
state_dict
,
"conditioner.embedders.1.model."
,
"cond_stage_model.clip_g.transformer.text_model."
,
32
)
keys_to_replace
[
"conditioner.embedders.1.model.text_projection"
]
=
"cond_stage_model.clip_g.text_projection"
keys_to_replace
[
"conditioner.embedders.1.model.text_projection.weight"
]
=
"cond_stage_model.clip_g.text_projection"
keys_to_replace
[
"conditioner.embedders.1.model.logit_scale"
]
=
"cond_stage_model.clip_g.logit_scale"
state_dict
=
utils
.
state_dict_prefix_replace
(
state_dict
,
replace_prefix
)
...
...
@@ -189,5 +190,14 @@ class SDXL(supported_models_base.BASE):
def
clip_target
(
self
):
return
supported_models_base
.
ClipTarget
(
sdxl_clip
.
SDXLTokenizer
,
sdxl_clip
.
SDXLClipModel
)
class
SSD1B
(
SDXL
):
unet_config
=
{
"model_channels"
:
320
,
"use_linear_in_transformer"
:
True
,
"transformer_depth"
:
[
0
,
0
,
2
,
2
,
4
,
4
],
"context_dim"
:
2048
,
"adm_in_channels"
:
2816
}
models
=
[
SD15
,
SD20
,
SD21UnclipL
,
SD21UnclipH
,
SDXLRefiner
,
SDXL
]
models
=
[
SD15
,
SD20
,
SD21UnclipL
,
SD21UnclipH
,
SDXLRefiner
,
SDXL
,
SSD1B
]
comfy/utils.py
View file @
6ec3f12c
...
...
@@ -170,25 +170,12 @@ UNET_MAP_BASIC = {
def
unet_to_diffusers
(
unet_config
):
num_res_blocks
=
unet_config
[
"num_res_blocks"
]
attention_resolutions
=
unet_config
[
"attention_resolutions"
]
channel_mult
=
unet_config
[
"channel_mult"
]
transformer_depth
=
unet_config
[
"transformer_depth"
]
transformer_depth
=
unet_config
[
"transformer_depth"
][:]
transformer_depth_output
=
unet_config
[
"transformer_depth_output"
][:]
num_blocks
=
len
(
channel_mult
)
if
isinstance
(
num_res_blocks
,
int
):
num_res_blocks
=
[
num_res_blocks
]
*
num_blocks
if
isinstance
(
transformer_depth
,
int
):
transformer_depth
=
[
transformer_depth
]
*
num_blocks
transformers_per_layer
=
[]
res
=
1
for
i
in
range
(
num_blocks
):
transformers
=
0
if
res
in
attention_resolutions
:
transformers
=
transformer_depth
[
i
]
transformers_per_layer
.
append
(
transformers
)
res
*=
2
transformers_mid
=
unet_config
.
get
(
"transformer_depth_middle"
,
transformer_depth
[
-
1
])
transformers_mid
=
unet_config
.
get
(
"transformer_depth_middle"
,
None
)
diffusers_unet_map
=
{}
for
x
in
range
(
num_blocks
):
...
...
@@ -196,10 +183,11 @@ def unet_to_diffusers(unet_config):
for
i
in
range
(
num_res_blocks
[
x
]):
for
b
in
UNET_MAP_RESNET
:
diffusers_unet_map
[
"down_blocks.{}.resnets.{}.{}"
.
format
(
x
,
i
,
UNET_MAP_RESNET
[
b
])]
=
"input_blocks.{}.0.{}"
.
format
(
n
,
b
)
if
transformers_per_layer
[
x
]
>
0
:
num_transformers
=
transformer_depth
.
pop
(
0
)
if
num_transformers
>
0
:
for
b
in
UNET_MAP_ATTENTIONS
:
diffusers_unet_map
[
"down_blocks.{}.attentions.{}.{}"
.
format
(
x
,
i
,
b
)]
=
"input_blocks.{}.1.{}"
.
format
(
n
,
b
)
for
t
in
range
(
transformers
_per_layer
[
x
]
):
for
t
in
range
(
num_
transformers
):
for
b
in
TRANSFORMER_BLOCKS
:
diffusers_unet_map
[
"down_blocks.{}.attentions.{}.transformer_blocks.{}.{}"
.
format
(
x
,
i
,
t
,
b
)]
=
"input_blocks.{}.1.transformer_blocks.{}.{}"
.
format
(
n
,
t
,
b
)
n
+=
1
...
...
@@ -218,7 +206,6 @@ def unet_to_diffusers(unet_config):
diffusers_unet_map
[
"mid_block.resnets.{}.{}"
.
format
(
i
,
UNET_MAP_RESNET
[
b
])]
=
"middle_block.{}.{}"
.
format
(
n
,
b
)
num_res_blocks
=
list
(
reversed
(
num_res_blocks
))
transformers_per_layer
=
list
(
reversed
(
transformers_per_layer
))
for
x
in
range
(
num_blocks
):
n
=
(
num_res_blocks
[
x
]
+
1
)
*
x
l
=
num_res_blocks
[
x
]
+
1
...
...
@@ -227,11 +214,12 @@ def unet_to_diffusers(unet_config):
for
b
in
UNET_MAP_RESNET
:
diffusers_unet_map
[
"up_blocks.{}.resnets.{}.{}"
.
format
(
x
,
i
,
UNET_MAP_RESNET
[
b
])]
=
"output_blocks.{}.0.{}"
.
format
(
n
,
b
)
c
+=
1
if
transformers_per_layer
[
x
]
>
0
:
num_transformers
=
transformer_depth_output
.
pop
()
if
num_transformers
>
0
:
c
+=
1
for
b
in
UNET_MAP_ATTENTIONS
:
diffusers_unet_map
[
"up_blocks.{}.attentions.{}.{}"
.
format
(
x
,
i
,
b
)]
=
"output_blocks.{}.1.{}"
.
format
(
n
,
b
)
for
t
in
range
(
transformers
_per_layer
[
x
]
):
for
t
in
range
(
num_
transformers
):
for
b
in
TRANSFORMER_BLOCKS
:
diffusers_unet_map
[
"up_blocks.{}.attentions.{}.transformer_blocks.{}.{}"
.
format
(
x
,
i
,
t
,
b
)]
=
"output_blocks.{}.1.transformer_blocks.{}.{}"
.
format
(
n
,
t
,
b
)
if
i
==
l
-
1
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment