Unverified commit f07a16e0, authored by Suraj Patil, committed by GitHub

update unet2d (#1376)

* boom boom

* remove duplicate arg

* add use_linear_proj arg

* fix copies

* style

* add fast tests

* use_linear_proj -> use_linear_projection
parent 16a32c9d
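
For context: the two arguments this PR threads through the UNet can be exercised together as in the sketch below. The tiny layer sizes are illustrative only, not taken from the PR; `attention_head_dim` now also accepts one value per block, and `use_linear_projection=True` swaps the transformer's 1x1 convolutional projections for linear layers.

    import torch
    from diffusers import UNet2DConditionModel

    # Hypothetical smoke test of the new arguments on a deliberately tiny model.
    unet = UNet2DConditionModel(
        sample_size=32,
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=32,
        attention_head_dim=(8, 16),    # new: a per-block tuple instead of a single int
        use_linear_projection=True,    # new: Linear instead of 1x1 Conv2d projections
    )

    sample = torch.randn(1, 4, 32, 32)               # (batch, channel, height, width)
    encoder_hidden_states = torch.randn(1, 77, 32)   # (batch, seq_len, cross_attention_dim)
    out = unet(sample, timestep=1, encoder_hidden_states=encoder_hidden_states).sample
    assert out.shape == sample.shape
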
@@ -99,8 +99,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         num_vector_embeds: Optional[int] = None,
         activation_fn: str = "geglu",
         num_embeds_ada_norm: Optional[int] = None,
+        use_linear_projection: bool = False,
     ):
         super().__init__()
+        self.use_linear_projection = use_linear_projection
         self.num_attention_heads = num_attention_heads
         self.attention_head_dim = attention_head_dim
         inner_dim = num_attention_heads * attention_head_dim
@@ -126,6 +128,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             self.in_channels = in_channels

             self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+            if use_linear_projection:
+                self.proj_in = nn.Linear(in_channels, inner_dim)
+            else:
                 self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
         elif self.is_input_vectorized:
             assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
@@ -159,6 +164,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):

         # 4. Define output layers
         if self.is_input_continuous:
+            if use_linear_projection:
+                self.proj_out = nn.Linear(in_channels, inner_dim)
+            else:
                 self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
         elif self.is_input_vectorized:
             self.norm_out = nn.LayerNorm(inner_dim)
@@ -191,10 +199,18 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         if self.is_input_continuous:
             batch, channel, height, weight = hidden_states.shape
             residual = hidden_states

+            if not self.use_linear_projection:
                 hidden_states = self.norm(hidden_states)
                 hidden_states = self.proj_in(hidden_states)
                 inner_dim = hidden_states.shape[1]
                 hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+            else:
+                hidden_states = self.norm(hidden_states)
+                inner_dim = hidden_states.shape[1]
+                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+                hidden_states = self.proj_in(hidden_states)
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
@@ -204,8 +220,13 @@ class Transformer2DModel(ModelMixin, ConfigMixin):

         # 3. Output
         if self.is_input_continuous:
+            if not self.use_linear_projection:
                 hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
                 hidden_states = self.proj_out(hidden_states)
+            else:
+                hidden_states = self.proj_out(hidden_states)
+                hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)

             output = hidden_states + residual
         elif self.is_input_vectorized:
             hidden_states = self.norm_out(hidden_states)
......
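
Why the flag is safe: for continuous inputs, a 1x1 convolution and a linear layer over the flattened height*width token axis compute the same projection, so `use_linear_projection` only changes where the reshape happens relative to the projection (presumably to match checkpoints trained with linear projections). A minimal standalone sketch of that equivalence, with all names local to the example:

    import torch
    from torch import nn

    batch, channels, height, width, inner_dim = 2, 8, 4, 4, 16

    conv = nn.Conv2d(channels, inner_dim, kernel_size=1)
    linear = nn.Linear(channels, inner_dim)
    # Copy the conv weights into the linear layer so both paths are comparable.
    linear.weight.data = conv.weight.data.view(inner_dim, channels)
    linear.bias.data = conv.bias.data.clone()

    x = torch.randn(batch, channels, height, width)

    # Conv path (use_linear_projection=False): project, then flatten to tokens.
    out_conv = conv(x).permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
    # Linear path (use_linear_projection=True): flatten to tokens, then project.
    out_linear = linear(x.permute(0, 2, 3, 1).reshape(batch, height * width, channels))

    assert torch.allclose(out_conv, out_linear, atol=1e-6)
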
@@ -33,6 +33,7 @@ def get_down_block(
     cross_attention_dim=None,
     downsample_padding=None,
     dual_cross_attention=False,
+    use_linear_projection=False,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -76,6 +77,7 @@ def get_down_block(
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
             dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -140,6 +142,7 @@ def get_up_block(
     resnet_groups=None,
     cross_attention_dim=None,
     dual_cross_attention=False,
+    use_linear_projection=False,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -170,6 +173,7 @@ def get_up_block(
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
             dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -327,6 +331,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         dual_cross_attention=False,
+        use_linear_projection=False,
         **kwargs,
     ):
         super().__init__()
@@ -362,6 +367,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
@@ -523,6 +529,7 @@ class CrossAttnDownBlock2D(nn.Module):
         downsample_padding=1,
         add_downsample=True,
         dual_cross_attention=False,
+        use_linear_projection=False,
     ):
         super().__init__()
         resnets = []
@@ -556,6 +563,7 @@ class CrossAttnDownBlock2D(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
@@ -1120,6 +1128,7 @@ class CrossAttnUpBlock2D(nn.Module):
         output_scale_factor=1.0,
         add_upsample=True,
         dual_cross_attention=False,
+        use_linear_projection=False,
     ):
         super().__init__()
         resnets = []
@@ -1155,6 +1164,7 @@ class CrossAttnUpBlock2D(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
......
@@ -61,7 +61,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
         out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
-        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
             Whether to flip the sin to cos in the time embedding.
         freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
@@ -106,8 +106,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         norm_num_groups: int = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
-        attention_head_dim: int = 8,
+        attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
     ):
         super().__init__()
@@ -127,6 +128,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         self.mid_block = None
         self.up_blocks = nn.ModuleList([])

+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
         # down
         output_channel = block_out_channels[0]
         for i, down_block_type in enumerate(down_block_types):
@@ -145,9 +149,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim,
+                attn_num_head_channels=attention_head_dim[i],
                 downsample_padding=downsample_padding,
                 dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
             )
             self.down_blocks.append(down_block)
@@ -160,9 +165,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             output_scale_factor=mid_block_scale_factor,
             resnet_time_scale_shift="default",
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attention_head_dim,
+            attn_num_head_channels=attention_head_dim[-1],
             resnet_groups=norm_num_groups,
             dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
         )

         # count how many layers upsample the images
@@ -170,6 +176,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
+        reversed_attention_head_dim = list(reversed(attention_head_dim))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
             is_final_block = i == len(block_out_channels) - 1
@@ -197,8 +204,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim,
+                attn_num_head_channels=reversed_attention_head_dim[i],
                 dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -256,8 +264,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         Args:
             sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
             timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
-            encoder_hidden_states (`torch.FloatTensor`):
-                (batch_size, sequence_length, hidden_size) encoder hidden states
+            encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
......
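
The per-block head-dim plumbing above is small enough to restate on its own; a rough sketch of the normalization and mirroring logic, with illustrative values:

    # An int fans out to one entry per down block; the up path mirrors the
    # encoder, so it indexes a reversed copy (values here are illustrative).
    down_block_types = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
    attention_head_dim = 8  # or, e.g., (5, 10, 20)

    if isinstance(attention_head_dim, int):
        attention_head_dim = (attention_head_dim,) * len(down_block_types)

    reversed_attention_head_dim = list(reversed(attention_head_dim))
    # down block i -> attention_head_dim[i]
    # mid block    -> attention_head_dim[-1]
    # up block i   -> reversed_attention_head_dim[i]
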
@@ -124,7 +124,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
         out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
-        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
             Whether to flip the sin to cos in the time embedding.
         freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
@@ -174,8 +174,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         norm_num_groups: int = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
-        attention_head_dim: int = 8,
+        attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
     ):
         super().__init__()
@@ -195,6 +196,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         self.mid_block = None
         self.up_blocks = nn.ModuleList([])

+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
         # down
         output_channel = block_out_channels[0]
         for i, down_block_type in enumerate(down_block_types):
@@ -213,9 +217,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim,
+                attn_num_head_channels=attention_head_dim[i],
                 downsample_padding=downsample_padding,
                 dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
             )
             self.down_blocks.append(down_block)
@@ -228,9 +233,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             output_scale_factor=mid_block_scale_factor,
             resnet_time_scale_shift="default",
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attention_head_dim,
+            attn_num_head_channels=attention_head_dim[-1],
             resnet_groups=norm_num_groups,
             dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
         )

         # count how many layers upsample the images
@@ -238,6 +244,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
+        reversed_attention_head_dim = list(reversed(attention_head_dim))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
             is_final_block = i == len(block_out_channels) - 1
@@ -265,8 +272,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim,
+                attn_num_head_channels=reversed_attention_head_dim[i],
                 dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -324,8 +332,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         Args:
             sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
             timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
-            encoder_hidden_states (`torch.FloatTensor`):
-                (batch_size, sequence_length, hidden_size) encoder hidden states
+            encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
@@ -640,6 +647,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         downsample_padding=1,
         add_downsample=True,
         dual_cross_attention=False,
+        use_linear_projection=False,
     ):
         super().__init__()
         resnets = []
@@ -673,6 +681,7 @@ class CrossAttnDownBlockFlat(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
@@ -851,6 +860,7 @@ class CrossAttnUpBlockFlat(nn.Module):
         output_scale_factor=1.0,
         add_upsample=True,
         dual_cross_attention=False,
+        use_linear_projection=False,
     ):
         super().__init__()
         resnets = []
@@ -886,6 +896,7 @@ class CrossAttnUpBlockFlat(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
@@ -988,6 +999,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         dual_cross_attention=False,
+        use_linear_projection=False,
         **kwargs,
     ):
         super().__init__()
@@ -1023,6 +1035,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
......
@@ -296,6 +296,44 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         for name, param in named_params.items():
             self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))

+    def test_model_with_attention_head_dim_tuple(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.sample
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+    def test_model_with_use_linear_projection(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["use_linear_projection"] = True
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.sample
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["sample"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
 class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = UNet2DModel
......
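
To run just the two new fast tests, something like the following should work; the test module path is assumed from the repository layout at the time, so adjust it if it differs:

    import unittest

    # Assumed module path; the tests above live in the UNet2DConditionModelTests class.
    from tests.test_models_unet_2d import UNet2DConditionModelTests

    suite = unittest.TestSuite()
    for name in ("test_model_with_attention_head_dim_tuple", "test_model_with_use_linear_projection"):
        suite.addTest(UNet2DConditionModelTests(name))
    unittest.TextTestRunner(verbosity=2).run(suite)
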