chenpangpang / ComfyUI

Commit b3e97fc7, authored Feb 28, 2024 by comfyanonymous

Koala 700M and 1B support.

Use the UNET Loader node to load the unet file to use them.
parent 37a86e46

Showing 3 changed files with 66 additions and 27 deletions:

comfy/ldm/modules/diffusionmodules/openaimodel.py  (+26, -22)
comfy/model_detection.py  (+19, -4)
comfy/supported_models.py  (+21, -1)
comfy/ldm/modules/diffusionmodules/openaimodel.py
```diff
@@ -708,27 +708,30 @@ class UNetModel(nn.Module):
                 device=device,
                 operations=operations
             )]
-        if transformer_depth_middle >= 0:
-            mid_block += [get_attention_layer(  # always uses a self-attn
-                            ch, num_heads, dim_head, depth=transformer_depth_middle,
-                            context_dim=context_dim,
-                            disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
-                        ),
-            get_resblock(
-                        merge_factor=merge_factor,
-                        merge_strategy=merge_strategy,
-                        video_kernel_size=video_kernel_size,
-                        ch=ch,
-                        time_embed_dim=time_embed_dim,
-                        dropout=dropout,
-                        out_channels=None,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                        dtype=self.dtype,
-                        device=device,
-                        operations=operations
-                    )]
-        self.middle_block = TimestepEmbedSequential(*mid_block)
+        self.middle_block = None
+        if transformer_depth_middle >= -1:
+            if transformer_depth_middle >= 0:
+                mid_block += [get_attention_layer(  # always uses a self-attn
+                                ch, num_heads, dim_head, depth=transformer_depth_middle,
+                                context_dim=context_dim,
+                                disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
+                            ),
+                get_resblock(
+                            merge_factor=merge_factor,
+                            merge_strategy=merge_strategy,
+                            video_kernel_size=video_kernel_size,
+                            ch=ch,
+                            time_embed_dim=time_embed_dim,
+                            dropout=dropout,
+                            out_channels=None,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            dtype=self.dtype,
+                            device=device,
+                            operations=operations
+                        )]
+            self.middle_block = TimestepEmbedSequential(*mid_block)
         self._feature_size += ch
 
         self.output_blocks = nn.ModuleList([])
@@ -858,7 +861,8 @@ class UNetModel(nn.Module):
                 h = p(h, transformer_options)
 
         transformer_options["block"] = ("middle", 0)
-        h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
+        if self.middle_block is not None:
+            h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
         h = apply_control(h, control, 'middle')
```
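Taken together, the two hunks make the UNet middle block optional: a `transformer_depth_middle` of 0 or more keeps the usual resblock / attention / resblock stack, -1 keeps a resblock-only middle block, and anything below -1 (the new -2 case used for KOALA-700M) builds no middle block at all, which is why the forward pass now checks for `None`. A minimal sketch of that convention, simplified and not the actual constructor code:

```python
# Simplified sketch of the branching introduced here; the real constructor builds
# ComfyUI resblock/attention modules, this only mirrors the decision logic.
def build_middle_block(transformer_depth_middle, make_resblock, make_attention):
    if transformer_depth_middle < -1:
        return None                                   # new: no middle block (KOALA-700M)
    mid = [make_resblock()]
    if transformer_depth_middle >= 0:
        mid += [make_attention(depth=transformer_depth_middle), make_resblock()]
    return mid                                        # -1: resblock-only middle block
```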
comfy/model_detection.py
```diff
@@ -151,8 +151,10 @@ def detect_unet_config(state_dict, key_prefix):
             channel_mult.append(last_channel_mult)
 
     if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
         transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
-    else:
+    elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys:
         transformer_depth_middle = -1
+    else:
+        transformer_depth_middle = -2
     unet_config["in_channels"] = in_channels
     unet_config["out_channels"] = out_channels
```
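The new `elif` distinguishes a checkpoint whose middle block is a lone resblock (the `middle_block.0.in_layers.0.weight` key is present, depth -1) from one that has no middle block at all (depth -2, which is what a KOALA-700M unet hits). A self-contained approximation of that rule over a plain list of keys, using the same key patterns as the hunk:

```python
# Approximation of the detection rule above; only the key patterns are taken from the diff.
def middle_depth_from_keys(keys, prefix=""):
    if prefix + "middle_block.1.proj_in.weight" in keys:
        # a spatial transformer is present; count its transformer_blocks like count_blocks does
        n = 0
        while any(k.startswith("{}middle_block.1.transformer_blocks.{}.".format(prefix, n)) for k in keys):
            n += 1
        return n
    if prefix + "middle_block.0.in_layers.0.weight" in keys:
        return -1    # resblock-only middle block
    return -2        # no middle block at all

assert middle_depth_from_keys(["middle_block.0.in_layers.0.weight"]) == -1
assert middle_depth_from_keys(["input_blocks.0.0.weight"]) == -2
```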
```diff
@@ -242,6 +244,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
     down_blocks = count_blocks(state_dict, "down_blocks.{}")
     for i in range(down_blocks):
         attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
+        res_blocks = count_blocks(state_dict, "down_blocks.{}.resnets.".format(i) + '{}')
         for ab in range(attn_blocks):
             transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
             transformer_depth.append(transformer_count)
@@ -250,8 +253,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
         attn_res *= 2
 
         if attn_blocks == 0:
-            transformer_depth.append(0)
-            transformer_depth.append(0)
+            for i in range(res_blocks):
+                transformer_depth.append(0)
 
     match["transformer_depth"] = transformer_depth
```
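Previously the code appended exactly two zeros for an attention-free down block, which baked in SDXL's two resnets per level; counting the resnets lets the same loop handle KOALA's single-resnet levels. A toy illustration of what gets built, with the key counting elided and the depths taken from the KOALA configs added below:

```python
# Toy illustration: per_block holds (number of resnets, transformer depth per attention block)
# for each down block, i.e. what count_blocks would report for the key patterns above.
def build_transformer_depth(per_block):
    depth = []
    for res_blocks, attn_depths in per_block:
        if attn_depths:
            depth.extend(attn_depths)
        else:
            depth.extend([0] * res_blocks)   # one zero per resnet, not a hard-coded two
    return depth

print(build_transformer_depth([(1, []), (1, [2]), (1, [5])]))          # KOALA-700M-like -> [0, 2, 5]
print(build_transformer_depth([(2, []), (2, [2, 2]), (2, [10, 10])]))  # SDXL-like -> [0, 0, 2, 2, 10, 10]
```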
```diff
@@ -329,7 +332,19 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
         'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
         'use_temporal_attention': False, 'use_temporal_resblock': False}
 
-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega]
+    KOALA_700M = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+        'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
+        'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 5], 'transformer_depth_output': [0, 0, 2, 2, 5, 5],
+        'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
+        'use_temporal_attention': False, 'use_temporal_resblock': False}
+
+    KOALA_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+        'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
+        'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 6], 'transformer_depth_output': [0, 0, 2, 2, 6, 6],
+        'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
+        'use_temporal_attention': False, 'use_temporal_resblock': False}
+
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B]
 
     for unet_config in supported_models:
         matches = True
```
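The only differences between the two new entries are the per-level transformer depths ([0, 2, 5] for 700M, [0, 2, 6] for 1B) and the middle block (-2, i.e. none, versus a depth-6 transformer). The `transformer_depth_output` lists follow directly from those per-level depths, consistent with each level contributing `num_res_blocks + 1` output blocks in this UNet layout. A quick illustrative check of that relationship using only the values from the two dicts above:

```python
# Sanity check (illustrative): transformer_depth_output is the per-level depth repeated
# num_res_blocks + 1 times, with the values taken from the KOALA dicts in the hunk.
def output_depths(per_level_depth, num_res_blocks):
    out = []
    for d, n in zip(per_level_depth, num_res_blocks):
        out.extend([d] * (n + 1))
    return out

assert output_depths([0, 2, 5], [1, 1, 1]) == [0, 0, 2, 2, 5, 5]   # KOALA_700M
assert output_depths([0, 2, 6], [1, 1, 1]) == [0, 0, 2, 2, 6, 6]   # KOALA_1B
```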
comfy/supported_models.py
```diff
@@ -234,6 +234,26 @@ class Segmind_Vega(SDXL):
         "use_temporal_attention": False,
     }
 
+class KOALA_700M(SDXL):
+    unet_config = {
+        "model_channels": 320,
+        "use_linear_in_transformer": True,
+        "transformer_depth": [0, 2, 5],
+        "context_dim": 2048,
+        "adm_in_channels": 2816,
+        "use_temporal_attention": False,
+    }
+
+class KOALA_1B(SDXL):
+    unet_config = {
+        "model_channels": 320,
+        "use_linear_in_transformer": True,
+        "transformer_depth": [0, 2, 6],
+        "context_dim": 2048,
+        "adm_in_channels": 2816,
+        "use_temporal_attention": False,
+    }
+
 class SVD_img2vid(supported_models_base.BASE):
     unet_config = {
         "model_channels": 320,
```
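These classes inherit everything else (latent format, conditioning, sampling configuration) from SDXL; their `unet_config` dicts only need enough keys to be told apart from SDXL and from each other when a checkpoint's detected config is matched against the model list. A rough sketch of that subset-style match; the real logic lives in supported_models_base, so treat this as an approximation:

```python
# Approximate sketch: a model class is selected when every key/value pair in its
# unet_config also appears in the config detected from the checkpoint. For the two
# KOALA classes the deciding key is transformer_depth ([0, 2, 5] vs [0, 2, 6]).
def class_matches(class_unet_config, detected_unet_config):
    return all(detected_unet_config.get(k) == v for k, v in class_unet_config.items())
```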
```diff
@@ -380,5 +400,5 @@ class Stable_Cascade_B(Stable_Cascade_C):
         return out
 
-models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]
+models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]
 
 models += [SVD_img2vid]
```
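As the commit message notes, the KOALA checkpoints are plain unet files and are loaded through the UNET Loader node rather than the usual checkpoint loader. A rough Python equivalent, assuming the file has been placed under models/unet/ and that comfy.sd.load_unet is still the helper backing that node; the filename is hypothetical:

```python
# Minimal usage sketch (not part of this commit): Python equivalent of the
# UNET Loader node mentioned in the commit message.
import comfy.sd

model = comfy.sd.load_unet("models/unet/koala_700m.safetensors")  # hypothetical filename
# detect_unet_config() finds no middle block keys, reports transformer_depth_middle = -2,
# and the checkpoint resolves to the new KOALA_700M model class. CLIP and VAE still come
# from an ordinary SDXL checkpoint, since KOALA_700M subclasses SDXL.
```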