"...composable_kernel_onnx.git" did not exist on "8ee36118be9b19b15c2471bffeeeb624afb14044"
Commit b3e97fc7 authored by comfyanonymous's avatar comfyanonymous
Browse files

Koala 700M and 1B support.

Use the UNET Loader node to load the unet file to use them.
parent 37a86e46
...@@ -708,27 +708,30 @@ class UNetModel(nn.Module): ...@@ -708,27 +708,30 @@ class UNetModel(nn.Module):
device=device, device=device,
operations=operations operations=operations
)] )]
if transformer_depth_middle >= 0:
mid_block += [get_attention_layer( # always uses a self-attn self.middle_block = None
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, if transformer_depth_middle >= -1:
disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint if transformer_depth_middle >= 0:
), mid_block += [get_attention_layer( # always uses a self-attn
get_resblock( ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
merge_factor=merge_factor, disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
merge_strategy=merge_strategy, ),
video_kernel_size=video_kernel_size, get_resblock(
ch=ch, merge_factor=merge_factor,
time_embed_dim=time_embed_dim, merge_strategy=merge_strategy,
dropout=dropout, video_kernel_size=video_kernel_size,
out_channels=None, ch=ch,
dims=dims, time_embed_dim=time_embed_dim,
use_checkpoint=use_checkpoint, dropout=dropout,
use_scale_shift_norm=use_scale_shift_norm, out_channels=None,
dtype=self.dtype, dims=dims,
device=device, use_checkpoint=use_checkpoint,
operations=operations use_scale_shift_norm=use_scale_shift_norm,
)] dtype=self.dtype,
self.middle_block = TimestepEmbedSequential(*mid_block) device=device,
operations=operations
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self._feature_size += ch self._feature_size += ch
self.output_blocks = nn.ModuleList([]) self.output_blocks = nn.ModuleList([])
...@@ -858,7 +861,8 @@ class UNetModel(nn.Module): ...@@ -858,7 +861,8 @@ class UNetModel(nn.Module):
h = p(h, transformer_options) h = p(h, transformer_options)
transformer_options["block"] = ("middle", 0) transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) if self.middle_block is not None:
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle') h = apply_control(h, control, 'middle')
......
...@@ -151,8 +151,10 @@ def detect_unet_config(state_dict, key_prefix): ...@@ -151,8 +151,10 @@ def detect_unet_config(state_dict, key_prefix):
channel_mult.append(last_channel_mult) channel_mult.append(last_channel_mult)
if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys: if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}') transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
else: elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys:
transformer_depth_middle = -1 transformer_depth_middle = -1
else:
transformer_depth_middle = -2
unet_config["in_channels"] = in_channels unet_config["in_channels"] = in_channels
unet_config["out_channels"] = out_channels unet_config["out_channels"] = out_channels
...@@ -242,6 +244,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): ...@@ -242,6 +244,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
down_blocks = count_blocks(state_dict, "down_blocks.{}") down_blocks = count_blocks(state_dict, "down_blocks.{}")
for i in range(down_blocks): for i in range(down_blocks):
attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}') attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
res_blocks = count_blocks(state_dict, "down_blocks.{}.resnets.".format(i) + '{}')
for ab in range(attn_blocks): for ab in range(attn_blocks):
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}') transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
transformer_depth.append(transformer_count) transformer_depth.append(transformer_count)
...@@ -250,8 +253,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): ...@@ -250,8 +253,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
attn_res *= 2 attn_res *= 2
if attn_blocks == 0: if attn_blocks == 0:
transformer_depth.append(0) for i in range(res_blocks):
transformer_depth.append(0) transformer_depth.append(0)
match["transformer_depth"] = transformer_depth match["transformer_depth"] = transformer_depth
...@@ -329,7 +332,19 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): ...@@ -329,7 +332,19 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False} 'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega] KOALA_700M = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 5], 'transformer_depth_output': [0, 0, 2, 2, 5, 5],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
KOALA_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 6], 'transformer_depth_output': [0, 0, 2, 2, 6, 6],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B]
for unet_config in supported_models: for unet_config in supported_models:
matches = True matches = True
......
...@@ -234,6 +234,26 @@ class Segmind_Vega(SDXL): ...@@ -234,6 +234,26 @@ class Segmind_Vega(SDXL):
"use_temporal_attention": False, "use_temporal_attention": False,
} }
class KOALA_700M(SDXL):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 2, 5],
"context_dim": 2048,
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
class KOALA_1B(SDXL):
unet_config = {
"model_channels": 320,
"use_linear_in_transformer": True,
"transformer_depth": [0, 2, 6],
"context_dim": 2048,
"adm_in_channels": 2816,
"use_temporal_attention": False,
}
class SVD_img2vid(supported_models_base.BASE): class SVD_img2vid(supported_models_base.BASE):
unet_config = { unet_config = {
"model_channels": 320, "model_channels": 320,
...@@ -380,5 +400,5 @@ class Stable_Cascade_B(Stable_Cascade_C): ...@@ -380,5 +400,5 @@ class Stable_Cascade_B(Stable_Cascade_C):
return out return out
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B] models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]
models += [SVD_img2vid] models += [SVD_img2vid]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment