Unverified commit 98c5e5da, authored by Will Berman, committed by GitHub

Attention processor cross attention norm group norm (#3021)

add group norm type to attention processor cross attention norm

This lets the cross attention norm use either a group norm block or a layer norm block.

The group norm operates along the channels dimension and requires an input of shape
(batch size, channels, *), whereas a layer norm with a single `normalized_shape`
dimension only operates over the last dimension, i.e. (*, channels).
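
As a quick illustration of that shape requirement, here is a minimal sketch; the dimensions (batch 2, 77 tokens, hidden size 768) are arbitrary examples, not values taken from this diff:

```python
import torch
import torch.nn as nn

# Arbitrary example shapes: batch of 2 sequences, 77 tokens each, hidden size 768.
x = torch.randn(2, 77, 768)

# LayerNorm(768) normalizes over the last dimension, so (batch, seq_len, hidden) works as-is.
print(nn.LayerNorm(768)(x).shape)  # torch.Size([2, 77, 768])

# GroupNorm normalizes over the channels dimension and expects (batch, channels, *),
# so the 768 channels have to sit in dimension 1.
print(nn.GroupNorm(num_groups=32, num_channels=768)(x.transpose(1, 2)).shape)  # torch.Size([2, 768, 77])
```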

The channels we want to normalize correspond to the hidden dimension of the encoder hidden states.

By convention, the encoder hidden states are always passed as (batch size, sequence
length, hidden size).

This means the layer norm can operate on the tensor without modification, but the group
norm requires transposing the last two dimensions so it operates on (batch size, hidden size, sequence length).
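
In code, the group norm path therefore wraps the norm in a transpose round-trip. A minimal sketch, using the same arbitrary shapes as above:

```python
import torch
import torch.nn as nn

encoder_hidden_states = torch.randn(2, 77, 768)  # (batch, sequence_length, hidden_size)

# Layer norm: applied directly to the last dimension.
layer_normed = nn.LayerNorm(768)(encoder_hidden_states)

# Group norm: move the hidden dimension into the channels slot, normalize, then move it back.
group_norm = nn.GroupNorm(num_groups=32, num_channels=768)
group_normed = group_norm(encoder_hidden_states.transpose(1, 2)).transpose(1, 2)

print(layer_normed.shape, group_normed.shape)  # both torch.Size([2, 77, 768])
```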

All existing attention processors need the same logic, so we can consolidate it in a
helper function, `prepare_encoder_hidden_states`.

Rename `prepare_encoder_hidden_states` -> `norm_encoder_hidden_states` (re: @patrickvonplaten)

Move the check that `norm_cross` is defined outside of `norm_encoder_hidden_states`

Add the missing `attn.norm_cross` check
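
For reference, this is roughly how the new arguments are used on the `Attention` module. A hedged sketch with illustrative dimensions; the import path (`diffusers.models.attention_processor.Attention`) matches this diff but may differ across diffusers versions:

```python
import torch
from diffusers.models.attention_processor import Attention

attn = Attention(
    query_dim=320,                       # hidden size of the image features (example value)
    cross_attention_dim=768,             # hidden size of the encoder hidden states (example value)
    heads=8,
    dim_head=40,
    cross_attention_norm="group_norm",   # None, "layer_norm", or "group_norm"
    cross_attention_norm_num_groups=32,  # only used for "group_norm"
)

hidden_states = torch.randn(2, 64, 320)          # (batch, image tokens, query_dim)
encoder_hidden_states = torch.randn(2, 77, 768)  # (batch, text tokens, cross_attention_dim)

out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
print(out.shape)  # torch.Size([2, 64, 320])
```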
parent 2d52e81c
@@ -56,7 +56,8 @@ class Attention(nn.Module):
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        out_bias: bool = True,
@@ -69,7 +70,6 @@ class Attention(nn.Module):
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
-        self.cross_attention_norm = cross_attention_norm

        self.scale = dim_head**-0.5 if scale_qk else 1.0
@@ -92,8 +92,28 @@ class Attention(nn.Module):
        else:
            self.group_norm = None

-        if cross_attention_norm:
+        if cross_attention_norm is None:
+            self.norm_cross = None
+        elif cross_attention_norm == "layer_norm":
            self.norm_cross = nn.LayerNorm(cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+            )

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
@@ -304,6 +324,25 @@ class Attention(nn.Module):
            attention_mask = attention_mask.repeat_interleave(head_size, dim=0)

        return attention_mask

+    def norm_encoder_hidden_states(self, encoder_hidden_states):
+        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            # Group norm norms along the channels dimension and expects
+            # input to be in the shape of (N, C, *). In this case, we want
+            # to norm along the hidden dimension, so we need to move
+            # (batch_size, sequence_length, hidden_size) ->
+            # (batch_size, hidden_size, sequence_length)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            assert False
+
+        return encoder_hidden_states
+

class AttnProcessor:
    def __call__(
@@ -321,8 +360,8 @@ class AttnProcessor:
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
@@ -388,7 +427,10 @@ class LoRAAttnProcessor(nn.Module):
        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -416,6 +458,11 @@ class AttnAddedKVProcessor:
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
@@ -467,8 +514,8 @@ class XFormersAttnProcessor:
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
@@ -511,8 +558,8 @@ class AttnProcessor2_0:
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
@@ -561,7 +608,10 @@ class LoRAXFormersAttnProcessor(nn.Module):
        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query).contiguous()

-        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
@@ -598,8 +648,8 @@ class SlicedAttnProcessor:
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
@@ -647,6 +697,11 @@ class SlicedAttnAddedKVProcessor:
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
...
@@ -44,6 +44,7 @@ def get_down_block(
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
+    cross_attention_norm=None,
):
    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "DownBlock2D":
@@ -126,6 +127,7 @@ def get_down_block(
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
            only_cross_attention=only_cross_attention,
+            cross_attention_norm=cross_attention_norm,
        )
    elif down_block_type == "SkipDownBlock2D":
        return SkipDownBlock2D(
@@ -223,6 +225,7 @@ def get_up_block(
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
+    cross_attention_norm=None,
):
    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpBlock2D":
@@ -293,6 +296,7 @@ def get_up_block(
            skip_time_act=resnet_skip_time_act,
            output_scale_factor=resnet_out_scale_factor,
            only_cross_attention=only_cross_attention,
+            cross_attention_norm=cross_attention_norm,
        )
    elif up_block_type == "AttnUpBlock2D":
        return AttnUpBlock2D(
@@ -578,6 +582,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
        cross_attention_dim=1280,
        skip_time_act=False,
        only_cross_attention=False,
+        cross_attention_norm=None,
    ):
        super().__init__()
@@ -618,6 +623,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
                    bias=True,
                    upcast_softmax=True,
                    only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
                    processor=AttnAddedKVProcessor(),
                )
            )
@@ -1361,6 +1367,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
        add_downsample=True,
        skip_time_act=False,
        only_cross_attention=False,
+        cross_attention_norm=None,
    ):
        super().__init__()
@@ -1400,6 +1407,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                    bias=True,
                    upcast_softmax=True,
                    only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
                    processor=AttnAddedKVProcessor(),
                )
            )
@@ -1580,7 +1588,7 @@ class KCrossAttnDownBlock2D(nn.Module):
                    temb_channels=temb_channels,
                    attention_bias=True,
                    add_self_attention=add_self_attention,
-                    cross_attention_norm=True,
+                    cross_attention_norm="layer_norm",
                    group_size=resnet_group_size,
                )
            )
@@ -2361,6 +2369,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
        add_upsample=True,
        skip_time_act=False,
        only_cross_attention=False,
+        cross_attention_norm=None,
    ):
        super().__init__()

        resnets = []
@@ -2401,6 +2410,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                    bias=True,
                    upcast_softmax=True,
                    only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
                    processor=AttnAddedKVProcessor(),
                )
            )
@@ -2608,7 +2618,7 @@ class KCrossAttnUpBlock2D(nn.Module):
                    temb_channels=temb_channels,
                    attention_bias=True,
                    add_self_attention=add_self_attention,
-                    cross_attention_norm=True,
+                    cross_attention_norm="layer_norm",
                    upcast_attention=upcast_attention,
                )
            )
@@ -2703,7 +2713,7 @@ class KAttentionBlock(nn.Module):
        upcast_attention: bool = False,
        temb_channels: int = 768,  # for ada_group_norm
        add_self_attention: bool = False,
-        cross_attention_norm: bool = False,
+        cross_attention_norm: Optional[str] = None,
        group_size: int = 32,
    ):
        super().__init__()
@@ -2719,7 +2729,7 @@ class KAttentionBlock(nn.Module):
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=None,
-            cross_attention_norm=False,
+            cross_attention_norm=None,
        )

        # 2. Cross-Attn
...
@@ -169,6 +169,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        projection_class_embeddings_input_dim: Optional[int] = None,
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
+        cross_attention_norm: Optional[str] = None,
    ):
        super().__init__()
@@ -341,6 +342,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                resnet_time_scale_shift=resnet_time_scale_shift,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
            )
            self.down_blocks.append(down_block)
@@ -373,6 +375,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                resnet_time_scale_shift=resnet_time_scale_shift,
                skip_time_act=resnet_skip_time_act,
                only_cross_attention=mid_block_only_cross_attention,
+                cross_attention_norm=cross_attention_norm,
            )
        elif mid_block_type is None:
            self.mid_block = None
@@ -424,6 +427,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                resnet_time_scale_shift=resnet_time_scale_shift,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel
...
@@ -243,8 +243,8 @@ class Pix2PixZeroAttnProcessor:
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
...
@@ -65,8 +65,8 @@ class CrossAttnStoreProcessor:
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
-        elif attn.cross_attention_norm:
-            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
...
@@ -255,6 +255,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
        projection_class_embeddings_input_dim: Optional[int] = None,
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
+        cross_attention_norm: Optional[str] = None,
    ):
        super().__init__()
@@ -433,6 +434,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                resnet_time_scale_shift=resnet_time_scale_shift,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
            )
            self.down_blocks.append(down_block)
@@ -465,6 +467,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                resnet_time_scale_shift=resnet_time_scale_shift,
                skip_time_act=resnet_skip_time_act,
                only_cross_attention=mid_block_only_cross_attention,
+                cross_attention_norm=cross_attention_norm,
            )
        elif mid_block_type is None:
            self.mid_block = None
@@ -516,6 +519,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                resnet_time_scale_shift=resnet_time_scale_shift,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
+                cross_attention_norm=cross_attention_norm,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel
@@ -1511,6 +1515,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
        cross_attention_dim=1280,
        skip_time_act=False,
        only_cross_attention=False,
+        cross_attention_norm=None,
    ):
        super().__init__()
@@ -1551,6 +1556,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
                    bias=True,
                    upcast_softmax=True,
                    only_cross_attention=only_cross_attention,
+                    cross_attention_norm=cross_attention_norm,
                    processor=AttnAddedKVProcessor(),
                )
            )
...