from dataclasses import dataclass, field
from typing import Optional, Tuple

from fastvideo.v1.configs.models.dits.base import DiTArchConfig, DiTConfig


def is_blocks(n: str, m) -> bool:
    """Return True when ``n`` names a numbered entry under ``blocks``.

    The name qualifies when it contains the substring ``"blocks"`` and its
    last dot-separated component consists only of digits
    (e.g. ``"blocks.12"``).

    Args:
        n: Dotted name to test.
        m: Unused here; accepted so the predicate has the two-argument
            ``(name, module)``-style signature used for the entries of
            ``_fsdp_shard_conditions`` (presumably the module object --
            confirm against the caller).

    Returns:
        Whether ``n`` looks like an individual numbered block.
    """
    return "blocks" in n and str.isdigit(n.split(".")[-1])


@dataclass
class WanVideoArchConfig(DiTArchConfig):
    """Architecture settings for the Wan video DiT.

    Holds the shard predicate list, the regex table used to rename parameter
    names, and the numeric hyper-parameters of the model. Derived attributes
    are filled in by ``__post_init__``.
    """

    # Predicates selecting the sub-modules to shard; by default only the
    # numbered transformer blocks matched by ``is_blocks``.
    _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks])

    # Regex pattern -> replacement template (``\1``/``\2`` are
    # back-references to the captured groups). Presumably applied to
    # checkpoint parameter names to translate them to this project's module
    # layout -- confirm against the weight loader.
    # Every pattern is anchored with ``^`` so it can only match from the
    # start of a parameter name.
    _param_names_mapping: dict = field(
        default_factory=lambda: {
            r"^patch_embedding\.(.*)$":
            r"patch_embedding.proj.\1",
            r"^condition_embedder\.text_embedder\.linear_1\.(.*)$":
            r"condition_embedder.text_embedder.fc_in.\1",
            r"^condition_embedder\.text_embedder\.linear_2\.(.*)$":
            r"condition_embedder.text_embedder.fc_out.\1",
            r"^condition_embedder\.time_embedder\.linear_1\.(.*)$":
            r"condition_embedder.time_embedder.mlp.fc_in.\1",
            r"^condition_embedder\.time_embedder\.linear_2\.(.*)$":
            r"condition_embedder.time_embedder.mlp.fc_out.\1",
            r"^condition_embedder\.time_proj\.(.*)$":
            r"condition_embedder.time_modulation.linear.\1",
            r"^condition_embedder\.image_embedder\.ff\.net\.0\.proj\.(.*)$":
            r"condition_embedder.image_embedder.ff.fc_in.\1",
            r"^condition_embedder\.image_embedder\.ff\.net\.2\.(.*)$":
            r"condition_embedder.image_embedder.ff.fc_out.\1",
            r"^blocks\.(\d+)\.attn1\.to_q\.(.*)$":
            r"blocks.\1.to_q.\2",
            r"^blocks\.(\d+)\.attn1\.to_k\.(.*)$":
            r"blocks.\1.to_k.\2",
            r"^blocks\.(\d+)\.attn1\.to_v\.(.*)$":
            r"blocks.\1.to_v.\2",
            r"^blocks\.(\d+)\.attn1\.to_out\.0\.(.*)$":
            r"blocks.\1.to_out.\2",
            r"^blocks\.(\d+)\.attn1\.norm_q\.(.*)$":
            r"blocks.\1.norm_q.\2",
            r"^blocks\.(\d+)\.attn1\.norm_k\.(.*)$":
            r"blocks.\1.norm_k.\2",
            r"^blocks\.(\d+)\.attn2\.to_out\.0\.(.*)$":
            r"blocks.\1.attn2.to_out.\2",
            r"^blocks\.(\d+)\.ffn\.net\.0\.proj\.(.*)$":
            r"blocks.\1.ffn.fc_in.\2",
            r"^blocks\.(\d+)\.ffn\.net\.2\.(.*)$":
            r"blocks.\1.ffn.fc_out.\2",
            # Previously the only unanchored pattern in this table; the
            # leading ``^`` makes it consistent with the entries above.
            r"^blocks\.(\d+)\.norm2\.(.*)$":
            r"blocks.\1.self_attn_residual_norm.norm.\2",
        })

    patch_size: Tuple[int, int, int] = (1, 2, 2)
    # NOTE: no annotation, so this is a plain class attribute rather than a
    # dataclass field -- it is not an ``__init__`` parameter.
    text_len = 512
    num_attention_heads: int = 40
    attention_head_dim: int = 128
    in_channels: int = 16
    out_channels: int = 16
    text_dim: int = 4096
    freq_dim: int = 256
    ffn_dim: int = 13824
    num_layers: int = 40
    cross_attn_norm: bool = True
    qk_norm: str = "rms_norm_across_heads"
    eps: float = 1e-6
    image_dim: Optional[int] = None
    added_kv_proj_dim: Optional[int] = None
    rope_max_seq_len: int = 1024

    def __post_init__(self) -> None:
        """Fill in the attributes derived from the configured fields.

        * ``out_channels`` falls back to ``in_channels`` when it is falsy
          (``None`` or ``0``).
        * ``hidden_size`` is ``num_attention_heads * attention_head_dim``.
        * ``num_channels_latents`` is ``in_channels`` when
          ``added_kv_proj_dim`` is ``None``, otherwise ``out_channels``.
        """
        self.out_channels = self.out_channels or self.in_channels
        self.hidden_size = self.num_attention_heads * self.attention_head_dim
        if self.added_kv_proj_dim is None:
            self.num_channels_latents = self.in_channels
        else:
            self.num_channels_latents = self.out_channels


@dataclass
class WanVideoConfig(DiTConfig):
    """Top-level Wan DiT config: a ``WanVideoArchConfig`` plus its prefix."""

    # A fresh WanVideoArchConfig is built for every instance.
    arch_config: DiTArchConfig = field(default_factory=WanVideoArchConfig)

    prefix: str = "Wan"