Unverified commit 16170c69, authored by Yongsen Mao, committed by GitHub

add sd1.5 compatibility to controlnet-xs and fix unused_parameters error during training (#8606)

* add sd1.5 compatibility to controlnet-xs

* set use_linear_projection by base_block

* refine code style
parent 4408047a
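
Background: SD1.5 UNets build their `Transformer2DModel` blocks with convolutional projections (`use_linear_projection=False`), while SDXL uses linear ones. The ControlNet-XS model module previously hard-coded `use_linear_projection=True`, so adapter blocks could not mirror an SD1.5 base. A minimal sketch of the usage this change enables (the checkpoint id and `size_ratio` value are illustrative, not part of this commit):

```python
from diffusers import ControlNetXSAdapter, UNet2DConditionModel

# SD1.5 UNet: its config has use_linear_projection=False.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# from_unet now forwards unet.config.use_linear_projection (see the
# @@ -489 hunk below) instead of assuming True, so the control blocks
# are built the same way as the base blocks for SD1.5 as well as SDXL.
adapter = ControlNetXSAdapter.from_unet(unet, size_ratio=0.1)
```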
@@ -114,6 +114,7 @@ def get_down_block_adapter(
     cross_attention_dim: Optional[int] = 1024,
     add_downsample: bool = True,
     upcast_attention: Optional[bool] = False,
+    use_linear_projection: Optional[bool] = True,
 ):
     num_layers = 2  # only support sd + sdxl
@@ -152,7 +153,7 @@ def get_down_block_adapter(
                     in_channels=ctrl_out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
                 )
@@ -200,6 +201,7 @@ def get_mid_block_adapter(
     num_attention_heads: Optional[int] = 1,
     cross_attention_dim: Optional[int] = 1024,
     upcast_attention: bool = False,
+    use_linear_projection: bool = True,
 ):
     # Before the midblock application, information is concatted from base to control.
     # Concat doesn't require change in number of channels
@@ -214,7 +216,7 @@ def get_mid_block_adapter(
         resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups),
         cross_attention_dim=cross_attention_dim,
         num_attention_heads=num_attention_heads,
-        use_linear_projection=True,
+        use_linear_projection=use_linear_projection,
         upcast_attention=upcast_attention,
     )
@@ -308,6 +310,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
         transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         upcast_attention: bool = True,
         max_norm_num_groups: int = 32,
+        use_linear_projection: bool = True,
     ):
         super().__init__()
@@ -381,6 +384,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
                     cross_attention_dim=cross_attention_dim[i],
                     add_downsample=not is_final_block,
                     upcast_attention=upcast_attention,
+                    use_linear_projection=use_linear_projection,
                 )
             )
@@ -393,6 +397,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
             num_attention_heads=num_attention_heads[-1],
             cross_attention_dim=cross_attention_dim[-1],
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )

         # up
@@ -489,6 +494,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
             transformer_layers_per_block=unet.config.transformer_layers_per_block,
             upcast_attention=unet.config.upcast_attention,
             max_norm_num_groups=unet.config.norm_num_groups,
+            use_linear_projection=unet.config.use_linear_projection,
         )

         # ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel
@@ -538,6 +544,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
         addition_embed_type: Optional[str] = None,
         addition_time_embed_dim: Optional[int] = None,
         upcast_attention: bool = True,
+        use_linear_projection: bool = True,
         time_cond_proj_dim: Optional[int] = None,
         projection_class_embeddings_input_dim: Optional[int] = None,
         # additional controlnet configs
@@ -595,7 +602,12 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             time_embed_dim,
             cond_proj_dim=time_cond_proj_dim,
         )
-        self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim)
+        if ctrl_learn_time_embedding:
+            self.ctrl_time_embedding = TimestepEmbedding(
+                in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim
+            )
+        else:
+            self.ctrl_time_embedding = None

         if addition_embed_type is None:
             self.base_add_time_proj = None
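
The hunk above is the `unused_parameters` fix from the commit title: previously `ctrl_time_embedding` was always instantiated, even when `ctrl_learn_time_embedding=False` and the module never ran in `forward`. Under `torch.nn.parallel.DistributedDataParallel` with `find_unused_parameters=False`, parameters that never receive gradients trigger a runtime error, so the module is now only created when it will actually be used. A toy illustration of the pattern (not the diffusers code):

```python
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, learn_time_embedding: bool):
        super().__init__()
        # Register the submodule only when forward() will use it; a module
        # that exists but never runs leaves its parameters without gradients,
        # which DDP reports as an error unless find_unused_parameters=True.
        self.time_embedding = nn.Linear(320, 1280) if learn_time_embedding else None

    def forward(self, temb):
        return self.time_embedding(temb) if self.time_embedding is not None else temb
```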
@@ -632,6 +644,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
                     cross_attention_dim=cross_attention_dim[i],
                     add_downsample=not is_final_block,
                     upcast_attention=upcast_attention,
+                    use_linear_projection=use_linear_projection,
                 )
             )
@@ -647,6 +660,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             ctrl_num_attention_heads=ctrl_num_attention_heads[-1],
             cross_attention_dim=cross_attention_dim[-1],
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )

         # # Create up blocks
@@ -690,6 +704,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
                     add_upsample=not is_final_block,
                     upcast_attention=upcast_attention,
                     norm_num_groups=norm_num_groups,
+                    use_linear_projection=use_linear_projection,
                 )
             )
@@ -754,6 +769,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             "addition_embed_type",
             "addition_time_embed_dim",
             "upcast_attention",
+            "use_linear_projection",
             "time_cond_proj_dim",
             "projection_class_embeddings_input_dim",
         ]
@@ -1219,6 +1235,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
         cross_attention_dim: Optional[int] = 1024,
         add_downsample: bool = True,
         upcast_attention: Optional[bool] = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()
         base_resnets = []
@@ -1270,7 +1287,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
                     in_channels=base_out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=norm_num_groups,
                 )
@@ -1282,7 +1299,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
                     in_channels=ctrl_out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
                 )
@@ -1342,6 +1359,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
             ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads
             cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim
             upcast_attention = get_first_cross_attention(base_downblock).upcast_attention
+            use_linear_projection = base_downblock.attentions[0].use_linear_projection
         else:
             has_crossattn = False
             transformer_layers_per_block = None
@@ -1349,6 +1367,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
             ctrl_num_attention_heads = None
             cross_attention_dim = None
             upcast_attention = None
+            use_linear_projection = None
         add_downsample = base_downblock.downsamplers is not None

         # create model
@@ -1367,6 +1386,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
             cross_attention_dim=cross_attention_dim,
             add_downsample=add_downsample,
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )

         # # load weights
@@ -1527,6 +1547,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
         ctrl_num_attention_heads: Optional[int] = 1,
         cross_attention_dim: Optional[int] = 1024,
         upcast_attention: bool = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()
@@ -1541,7 +1562,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
             resnet_groups=norm_num_groups,
             cross_attention_dim=cross_attention_dim,
             num_attention_heads=base_num_attention_heads,
-            use_linear_projection=True,
+            use_linear_projection=use_linear_projection,
             upcast_attention=upcast_attention,
         )
@@ -1556,7 +1577,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
             ),
             cross_attention_dim=cross_attention_dim,
             num_attention_heads=ctrl_num_attention_heads,
-            use_linear_projection=True,
+            use_linear_projection=use_linear_projection,
             upcast_attention=upcast_attention,
         )
@@ -1590,6 +1611,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
         ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
         cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
         upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
+        use_linear_projection = base_midblock.attentions[0].use_linear_projection

         # create model
         model = cls(
@@ -1603,6 +1625,7 @@ class ControlNetXSCrossAttnMidBlock2D(nn.Module):
             ctrl_num_attention_heads=ctrl_num_attention_heads,
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )

         # load weights
@@ -1677,6 +1700,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
         cross_attention_dim: int = 1024,
         add_upsample: bool = True,
         upcast_attention: bool = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()
         resnets = []
@@ -1714,7 +1738,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
                     in_channels=out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=norm_num_groups,
                 )
@@ -1753,12 +1777,14 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
             num_attention_heads = get_first_cross_attention(base_upblock).heads
             cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
             upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
+            use_linear_projection = base_upblock.attentions[0].use_linear_projection
         else:
             has_crossattn = False
             transformer_layers_per_block = None
             num_attention_heads = None
             cross_attention_dim = None
             upcast_attention = None
+            use_linear_projection = None
         add_upsample = base_upblock.upsamplers is not None

         # create model
@@ -1776,6 +1802,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
             cross_attention_dim=cross_attention_dim,
             add_upsample=add_upsample,
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )

         # load weights
...
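
With the adapter and the UNet wrapper both accepting `use_linear_projection`, an SD1.5 base should work end to end through the existing ControlNet-XS pipeline. A hedged sketch (the adapter checkpoint path and conditioning image URL are placeholders):

```python
import torch
from diffusers import ControlNetXSAdapter, StableDiffusionControlNetXSPipeline
from diffusers.utils import load_image

# Placeholder path: any ControlNet-XS adapter trained against an SD1.5 UNet.
controlnet = ControlNetXSAdapter.from_pretrained(
    "path/to/controlnet-xs-sd15-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

canny_image = load_image("https://example.com/canny_edges.png")  # placeholder
image = pipe("a futuristic city at dusk", image=canny_image).images[0]
```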