Unverified Commit 5a47442f authored by Cesaryuan's avatar Cesaryuan Committed by GitHub
Browse files

Fix: update type hints for Tuple parameters across multiple files to support...


Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples (#12544)

* Fix: update type hints for Tuple parameters across multiple files to support variable-length tuples

* Apply style fixes

---------
Co-authored-by: default avatargithub-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 8f6328c4
......@@ -177,16 +177,21 @@ class UNet2DConditionModel(
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
......@@ -486,10 +491,10 @@ class UNet2DConditionModel(
def _check_config(
self,
down_block_types: Tuple[str],
up_block_types: Tuple[str],
down_block_types: Tuple[str, ...],
up_block_types: Tuple[str, ...],
only_cross_attention: Union[bool, Tuple[bool]],
block_out_channels: Tuple[int],
block_out_channels: Tuple[int, ...],
layers_per_block: Union[int, Tuple[int]],
cross_attention_dim: Union[int, Tuple[int]],
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
......
......@@ -54,7 +54,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
groups: int = 32,
attention_head_dim: int = 64,
layers_per_block: Union[int, Tuple[int]] = 3,
block_out_channels: Tuple[int] = (384, 768, 1536, 3072),
block_out_channels: Tuple[int, ...] = (384, 768, 1536, 3072),
cross_attention_dim: Union[int, Tuple[int]] = 4096,
encoder_hid_dim: int = 4096,
):
......
......@@ -73,25 +73,25 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
sample_size: Optional[int] = None,
in_channels: int = 8,
out_channels: int = 4,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
),
up_block_types: Tuple[str] = (
up_block_types: Tuple[str, ...] = (
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
addition_time_embed_dim: int = 256,
projection_class_embeddings_input_dim: int = 768,
layers_per_block: Union[int, Tuple[int]] = 2,
cross_attention_dim: Union[int, Tuple[int]] = 1024,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
num_attention_heads: Union[int, Tuple[int, ...]] = (5, 10, 20, 20),
num_frames: int = 25,
):
super().__init__()
......
......@@ -145,10 +145,10 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
timestep_ratio_embedding_dim: int = 64,
patch_size: int = 1,
conditioning_dim: int = 2048,
block_out_channels: Tuple[int] = (2048, 2048),
num_attention_heads: Tuple[int] = (32, 32),
down_num_layers_per_block: Tuple[int] = (8, 24),
up_num_layers_per_block: Tuple[int] = (24, 8),
block_out_channels: Tuple[int, ...] = (2048, 2048),
num_attention_heads: Tuple[int, ...] = (32, 32),
down_num_layers_per_block: Tuple[int, ...] = (8, 24),
up_num_layers_per_block: Tuple[int, ...] = (24, 8),
down_blocks_repeat_mappers: Optional[Tuple[int]] = (
1,
1,
......@@ -167,7 +167,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
kernel_size=3,
dropout: Union[float, Tuple[float]] = (0.1, 0.1),
self_attn: Union[bool, Tuple[bool]] = True,
timestep_conditioning_type: Tuple[str] = ("sca", "crp"),
timestep_conditioning_type: Tuple[str, ...] = ("sca", "crp"),
switch_level: Optional[Tuple[bool]] = None,
):
"""
......
......@@ -532,8 +532,8 @@ class FlaxEncoder(nn.Module):
in_channels: int = 3
out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
block_out_channels: Tuple[int] = (64,)
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 2
norm_num_groups: int = 32
act_fn: str = "silu"
......@@ -650,8 +650,8 @@ class FlaxDecoder(nn.Module):
in_channels: int = 3
out_channels: int = 3
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
block_out_channels: int = (64,)
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 2
norm_num_groups: int = 32
act_fn: str = "silu"
......@@ -823,9 +823,9 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
in_channels: int = 3
out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
block_out_channels: Tuple[int] = (64,)
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 1
act_fn: str = "silu"
latent_channels: int = 4
......
......@@ -245,16 +245,21 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
out_channels: int = 4,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
......
......@@ -374,21 +374,21 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockFlat",
"CrossAttnDownBlockFlat",
"CrossAttnDownBlockFlat",
"DownBlockFlat",
),
mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn",
up_block_types: Tuple[str] = (
up_block_types: Tuple[str, ...] = (
"UpBlockFlat",
"CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
......
......@@ -742,7 +742,7 @@ class ShapEParamsProjModel(ModelMixin, ConfigMixin):
def __init__(
self,
*,
param_names: Tuple[str] = (
param_names: Tuple[str, ...] = (
"nerstf.mlp.0.weight",
"nerstf.mlp.1.weight",
"nerstf.mlp.2.weight",
......@@ -786,13 +786,13 @@ class ShapERenderer(ModelMixin, ConfigMixin):
def __init__(
self,
*,
param_names: Tuple[str] = (
param_names: Tuple[str, ...] = (
"nerstf.mlp.0.weight",
"nerstf.mlp.1.weight",
"nerstf.mlp.2.weight",
"nerstf.mlp.3.weight",
),
param_shapes: Tuple[Tuple[int]] = (
param_shapes: Tuple[Tuple[int, int], ...] = (
(256, 93),
(256, 256),
(256, 256),
......@@ -804,7 +804,7 @@ class ShapERenderer(ModelMixin, ConfigMixin):
n_hidden_layers: int = 6,
act_fn: str = "swish",
insert_direction_at: int = 4,
background: Tuple[float] = (
background: Tuple[float, ...] = (
255.0,
255.0,
255.0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment