Unverified commit 5a47442f, authored by Cesaryuan, committed by GitHub
Browse files

Fix: update type hints for Tuple parameters across multiple files to support...


Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples (#12544)

* Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples

* Apply style fixes

---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 8f6328c4
...@@ -177,16 +177,21 @@ class UNet2DConditionModel( ...@@ -177,16 +177,21 @@ class UNet2DConditionModel(
center_input_sample: bool = False, center_input_sample: bool = False,
flip_sin_to_cos: bool = True, flip_sin_to_cos: bool = True,
freq_shift: int = 0, freq_shift: int = 0,
down_block_types: Tuple[str] = ( down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"DownBlock2D", "DownBlock2D",
), ),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False, only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2, layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1, downsample_padding: int = 1,
mid_block_scale_factor: float = 1, mid_block_scale_factor: float = 1,
...@@ -486,10 +491,10 @@ class UNet2DConditionModel( ...@@ -486,10 +491,10 @@ class UNet2DConditionModel(
def _check_config( def _check_config(
self, self,
down_block_types: Tuple[str], down_block_types: Tuple[str, ...],
up_block_types: Tuple[str], up_block_types: Tuple[str, ...],
only_cross_attention: Union[bool, Tuple[bool]], only_cross_attention: Union[bool, Tuple[bool]],
block_out_channels: Tuple[int], block_out_channels: Tuple[int, ...],
layers_per_block: Union[int, Tuple[int]], layers_per_block: Union[int, Tuple[int]],
cross_attention_dim: Union[int, Tuple[int]], cross_attention_dim: Union[int, Tuple[int]],
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]], transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
......
...@@ -54,7 +54,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin): ...@@ -54,7 +54,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
groups: int = 32, groups: int = 32,
attention_head_dim: int = 64, attention_head_dim: int = 64,
layers_per_block: Union[int, Tuple[int]] = 3, layers_per_block: Union[int, Tuple[int]] = 3,
block_out_channels: Tuple[int] = (384, 768, 1536, 3072), block_out_channels: Tuple[int, ...] = (384, 768, 1536, 3072),
cross_attention_dim: Union[int, Tuple[int]] = 4096, cross_attention_dim: Union[int, Tuple[int]] = 4096,
encoder_hid_dim: int = 4096, encoder_hid_dim: int = 4096,
): ):
......
...@@ -73,25 +73,25 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL ...@@ -73,25 +73,25 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
sample_size: Optional[int] = None, sample_size: Optional[int] = None,
in_channels: int = 8, in_channels: int = 8,
out_channels: int = 4, out_channels: int = 4,
down_block_types: Tuple[str] = ( down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal", "DownBlockSpatioTemporal",
), ),
up_block_types: Tuple[str] = ( up_block_types: Tuple[str, ...] = (
"UpBlockSpatioTemporal", "UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal",
), ),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
addition_time_embed_dim: int = 256, addition_time_embed_dim: int = 256,
projection_class_embeddings_input_dim: int = 768, projection_class_embeddings_input_dim: int = 768,
layers_per_block: Union[int, Tuple[int]] = 2, layers_per_block: Union[int, Tuple[int]] = 2,
cross_attention_dim: Union[int, Tuple[int]] = 1024, cross_attention_dim: Union[int, Tuple[int]] = 1024,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20), num_attention_heads: Union[int, Tuple[int, ...]] = (5, 10, 20, 20),
num_frames: int = 25, num_frames: int = 25,
): ):
super().__init__() super().__init__()
......
...@@ -145,10 +145,10 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -145,10 +145,10 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
timestep_ratio_embedding_dim: int = 64, timestep_ratio_embedding_dim: int = 64,
patch_size: int = 1, patch_size: int = 1,
conditioning_dim: int = 2048, conditioning_dim: int = 2048,
block_out_channels: Tuple[int] = (2048, 2048), block_out_channels: Tuple[int, ...] = (2048, 2048),
num_attention_heads: Tuple[int] = (32, 32), num_attention_heads: Tuple[int, ...] = (32, 32),
down_num_layers_per_block: Tuple[int] = (8, 24), down_num_layers_per_block: Tuple[int, ...] = (8, 24),
up_num_layers_per_block: Tuple[int] = (24, 8), up_num_layers_per_block: Tuple[int, ...] = (24, 8),
down_blocks_repeat_mappers: Optional[Tuple[int]] = ( down_blocks_repeat_mappers: Optional[Tuple[int]] = (
1, 1,
1, 1,
...@@ -167,7 +167,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -167,7 +167,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
kernel_size=3, kernel_size=3,
dropout: Union[float, Tuple[float]] = (0.1, 0.1), dropout: Union[float, Tuple[float]] = (0.1, 0.1),
self_attn: Union[bool, Tuple[bool]] = True, self_attn: Union[bool, Tuple[bool]] = True,
timestep_conditioning_type: Tuple[str] = ("sca", "crp"), timestep_conditioning_type: Tuple[str, ...] = ("sca", "crp"),
switch_level: Optional[Tuple[bool]] = None, switch_level: Optional[Tuple[bool]] = None,
): ):
""" """
......
...@@ -532,8 +532,8 @@ class FlaxEncoder(nn.Module): ...@@ -532,8 +532,8 @@ class FlaxEncoder(nn.Module):
in_channels: int = 3 in_channels: int = 3
out_channels: int = 3 out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",) down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
block_out_channels: Tuple[int] = (64,) block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 2 layers_per_block: int = 2
norm_num_groups: int = 32 norm_num_groups: int = 32
act_fn: str = "silu" act_fn: str = "silu"
...@@ -650,8 +650,8 @@ class FlaxDecoder(nn.Module): ...@@ -650,8 +650,8 @@ class FlaxDecoder(nn.Module):
in_channels: int = 3 in_channels: int = 3
out_channels: int = 3 out_channels: int = 3
up_block_types: Tuple[str] = ("UpDecoderBlock2D",) up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
block_out_channels: int = (64,) block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 2 layers_per_block: int = 2
norm_num_groups: int = 32 norm_num_groups: int = 32
act_fn: str = "silu" act_fn: str = "silu"
...@@ -823,9 +823,9 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): ...@@ -823,9 +823,9 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
in_channels: int = 3 in_channels: int = 3
out_channels: int = 3 out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",) down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
up_block_types: Tuple[str] = ("UpDecoderBlock2D",) up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
block_out_channels: Tuple[int] = (64,) block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 1 layers_per_block: int = 1
act_fn: str = "silu" act_fn: str = "silu"
latent_channels: int = 4 latent_channels: int = 4
......
...@@ -245,16 +245,21 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad ...@@ -245,16 +245,21 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
out_channels: int = 4, out_channels: int = 4,
flip_sin_to_cos: bool = True, flip_sin_to_cos: bool = True,
freq_shift: int = 0, freq_shift: int = 0,
down_block_types: Tuple[str] = ( down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
"DownBlock2D", "DownBlock2D",
), ),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False, only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2, layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1, downsample_padding: int = 1,
mid_block_scale_factor: float = 1, mid_block_scale_factor: float = 1,
......
...@@ -374,21 +374,21 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -374,21 +374,21 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
center_input_sample: bool = False, center_input_sample: bool = False,
flip_sin_to_cos: bool = True, flip_sin_to_cos: bool = True,
freq_shift: int = 0, freq_shift: int = 0,
down_block_types: Tuple[str] = ( down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat",
"CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat",
"CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat",
"DownBlockFlat", "DownBlockFlat",
), ),
mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn", mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn",
up_block_types: Tuple[str] = ( up_block_types: Tuple[str, ...] = (
"UpBlockFlat", "UpBlockFlat",
"CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",
), ),
only_cross_attention: Union[bool, Tuple[bool]] = False, only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280), block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2, layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1, downsample_padding: int = 1,
mid_block_scale_factor: float = 1, mid_block_scale_factor: float = 1,
......
...@@ -742,7 +742,7 @@ class ShapEParamsProjModel(ModelMixin, ConfigMixin): ...@@ -742,7 +742,7 @@ class ShapEParamsProjModel(ModelMixin, ConfigMixin):
def __init__( def __init__(
self, self,
*, *,
param_names: Tuple[str] = ( param_names: Tuple[str, ...] = (
"nerstf.mlp.0.weight", "nerstf.mlp.0.weight",
"nerstf.mlp.1.weight", "nerstf.mlp.1.weight",
"nerstf.mlp.2.weight", "nerstf.mlp.2.weight",
...@@ -786,13 +786,13 @@ class ShapERenderer(ModelMixin, ConfigMixin): ...@@ -786,13 +786,13 @@ class ShapERenderer(ModelMixin, ConfigMixin):
def __init__( def __init__(
self, self,
*, *,
param_names: Tuple[str] = ( param_names: Tuple[str, ...] = (
"nerstf.mlp.0.weight", "nerstf.mlp.0.weight",
"nerstf.mlp.1.weight", "nerstf.mlp.1.weight",
"nerstf.mlp.2.weight", "nerstf.mlp.2.weight",
"nerstf.mlp.3.weight", "nerstf.mlp.3.weight",
), ),
param_shapes: Tuple[Tuple[int]] = ( param_shapes: Tuple[Tuple[int, int], ...] = (
(256, 93), (256, 93),
(256, 256), (256, 256),
(256, 256), (256, 256),
...@@ -804,7 +804,7 @@ class ShapERenderer(ModelMixin, ConfigMixin): ...@@ -804,7 +804,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
n_hidden_layers: int = 6, n_hidden_layers: int = 6,
act_fn: str = "swish", act_fn: str = "swish",
insert_direction_at: int = 4, insert_direction_at: int = 4,
background: Tuple[float] = ( background: Tuple[float, ...] = (
255.0, 255.0,
255.0, 255.0,
255.0, 255.0,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment