"nndet/preprocessing/crop.py" did not exist on "ede95851422d2e71bf625d536798640627a11151"
Commit c07946d8 authored by hepj's avatar hepj
Browse files

dit & video

parents
from dataclasses import dataclass, field
from typing import Optional, Tuple
import torch
from fastvideo.v1.configs.models.dits.base import DiTArchConfig, DiTConfig
def is_double_block(n: str, m) -> bool:
return "double" in n and str.isdigit(n.split(".")[-1])
def is_single_block(n: str, m) -> bool:
return "single" in n and str.isdigit(n.split(".")[-1])
def is_refiner_block(n: str, m) -> bool:
return "refiner" in n and str.isdigit(n.split(".")[-1])
@dataclass
class HunyuanVideoArchConfig(DiTArchConfig):
_fsdp_shard_conditions: list = field(
default_factory=lambda:
[is_double_block, is_single_block, is_refiner_block])
_param_names_mapping: dict = field(
default_factory=lambda: {
# 1. context_embedder.time_text_embed submodules (specific rules, applied first):
r"^context_embedder\.time_text_embed\.timestep_embedder\.linear_1\.(.*)$":
r"txt_in.t_embedder.mlp.fc_in.\1",
r"^context_embedder\.time_text_embed\.timestep_embedder\.linear_2\.(.*)$":
r"txt_in.t_embedder.mlp.fc_out.\1",
r"^context_embedder\.proj_in\.(.*)$":
r"txt_in.input_embedder.\1",
r"^context_embedder\.time_text_embed\.text_embedder\.linear_1\.(.*)$":
r"txt_in.c_embedder.fc_in.\1",
r"^context_embedder\.time_text_embed\.text_embedder\.linear_2\.(.*)$":
r"txt_in.c_embedder.fc_out.\1",
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.norm1\.(.*)$":
r"txt_in.refiner_blocks.\1.norm1.\2",
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.norm2\.(.*)$":
r"txt_in.refiner_blocks.\1.norm2.\2",
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_q\.(.*)$":
(r"txt_in.refiner_blocks.\1.self_attn_qkv.\2", 0, 3),
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_k\.(.*)$":
(r"txt_in.refiner_blocks.\1.self_attn_qkv.\2", 1, 3),
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_v\.(.*)$":
(r"txt_in.refiner_blocks.\1.self_attn_qkv.\2", 2, 3),
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_out\.0\.(.*)$":
r"txt_in.refiner_blocks.\1.self_attn_proj.\2",
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.ff\.net\.0(?:\.proj)?\.(.*)$":
r"txt_in.refiner_blocks.\1.mlp.fc_in.\2",
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.ff\.net\.2(?:\.proj)?\.(.*)$":
r"txt_in.refiner_blocks.\1.mlp.fc_out.\2",
r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.norm_out\.linear\.(.*)$":
r"txt_in.refiner_blocks.\1.adaLN_modulation.linear.\2",
# 3. x_embedder mapping:
r"^x_embedder\.proj\.(.*)$":
r"img_in.proj.\1",
# 4. Top-level time_text_embed mappings:
r"^time_text_embed\.timestep_embedder\.linear_1\.(.*)$":
r"time_in.mlp.fc_in.\1",
r"^time_text_embed\.timestep_embedder\.linear_2\.(.*)$":
r"time_in.mlp.fc_out.\1",
r"^time_text_embed\.guidance_embedder\.linear_1\.(.*)$":
r"guidance_in.mlp.fc_in.\1",
r"^time_text_embed\.guidance_embedder\.linear_2\.(.*)$":
r"guidance_in.mlp.fc_out.\1",
r"^time_text_embed\.text_embedder\.linear_1\.(.*)$":
r"vector_in.fc_in.\1",
r"^time_text_embed\.text_embedder\.linear_2\.(.*)$":
r"vector_in.fc_out.\1",
# 5. transformer_blocks mapping:
r"^transformer_blocks\.(\d+)\.norm1\.linear\.(.*)$":
r"double_blocks.\1.img_mod.linear.\2",
r"^transformer_blocks\.(\d+)\.norm1_context\.linear\.(.*)$":
r"double_blocks.\1.txt_mod.linear.\2",
r"^transformer_blocks\.(\d+)\.attn\.norm_q\.(.*)$":
r"double_blocks.\1.img_attn_q_norm.\2",
r"^transformer_blocks\.(\d+)\.attn\.norm_k\.(.*)$":
r"double_blocks.\1.img_attn_k_norm.\2",
r"^transformer_blocks\.(\d+)\.attn\.to_q\.(.*)$":
(r"double_blocks.\1.img_attn_qkv.\2", 0, 3),
r"^transformer_blocks\.(\d+)\.attn\.to_k\.(.*)$":
(r"double_blocks.\1.img_attn_qkv.\2", 1, 3),
r"^transformer_blocks\.(\d+)\.attn\.to_v\.(.*)$":
(r"double_blocks.\1.img_attn_qkv.\2", 2, 3),
r"^transformer_blocks\.(\d+)\.attn\.add_q_proj\.(.*)$":
(r"double_blocks.\1.txt_attn_qkv.\2", 0, 3),
r"^transformer_blocks\.(\d+)\.attn\.add_k_proj\.(.*)$":
(r"double_blocks.\1.txt_attn_qkv.\2", 1, 3),
r"^transformer_blocks\.(\d+)\.attn\.add_v_proj\.(.*)$":
(r"double_blocks.\1.txt_attn_qkv.\2", 2, 3),
r"^transformer_blocks\.(\d+)\.attn\.to_out\.0\.(.*)$":
r"double_blocks.\1.img_attn_proj.\2",
            # attn.to_add_out is the text-branch counterpart of attn.to_out.0 and maps
            # to the separate txt_attn_proj projection:
r"^transformer_blocks\.(\d+)\.attn\.to_add_out\.(.*)$":
r"double_blocks.\1.txt_attn_proj.\2",
r"^transformer_blocks\.(\d+)\.attn\.norm_added_q\.(.*)$":
r"double_blocks.\1.txt_attn_q_norm.\2",
r"^transformer_blocks\.(\d+)\.attn\.norm_added_k\.(.*)$":
r"double_blocks.\1.txt_attn_k_norm.\2",
r"^transformer_blocks\.(\d+)\.ff\.net\.0(?:\.proj)?\.(.*)$":
r"double_blocks.\1.img_mlp.fc_in.\2",
r"^transformer_blocks\.(\d+)\.ff\.net\.2(?:\.proj)?\.(.*)$":
r"double_blocks.\1.img_mlp.fc_out.\2",
r"^transformer_blocks\.(\d+)\.ff_context\.net\.0(?:\.proj)?\.(.*)$":
r"double_blocks.\1.txt_mlp.fc_in.\2",
r"^transformer_blocks\.(\d+)\.ff_context\.net\.2(?:\.proj)?\.(.*)$":
r"double_blocks.\1.txt_mlp.fc_out.\2",
# 6. single_transformer_blocks mapping:
r"^single_transformer_blocks\.(\d+)\.attn\.norm_q\.(.*)$":
r"single_blocks.\1.q_norm.\2",
r"^single_transformer_blocks\.(\d+)\.attn\.norm_k\.(.*)$":
r"single_blocks.\1.k_norm.\2",
r"^single_transformer_blocks\.(\d+)\.attn\.to_q\.(.*)$":
(r"single_blocks.\1.linear1.\2", 0, 4),
r"^single_transformer_blocks\.(\d+)\.attn\.to_k\.(.*)$":
(r"single_blocks.\1.linear1.\2", 1, 4),
r"^single_transformer_blocks\.(\d+)\.attn\.to_v\.(.*)$":
(r"single_blocks.\1.linear1.\2", 2, 4),
r"^single_transformer_blocks\.(\d+)\.proj_mlp\.(.*)$":
(r"single_blocks.\1.linear1.\2", 3, 4),
            # proj_out is the fused output projection (linear2); the modulation
            # parameters come from norm.linear, mapped to modulation.linear below:
r"^single_transformer_blocks\.(\d+)\.proj_out\.(.*)$":
r"single_blocks.\1.linear2.\2",
r"^single_transformer_blocks\.(\d+)\.norm\.linear\.(.*)$":
r"single_blocks.\1.modulation.linear.\2",
# 7. Final layers mapping:
r"^norm_out\.linear\.(.*)$":
r"final_layer.adaLN_modulation.linear.\1",
r"^proj_out\.(.*)$":
r"final_layer.linear.\1",
})
patch_size: int = 2
patch_size_t: int = 1
in_channels: int = 16
out_channels: int = 16
num_attention_heads: int = 24
attention_head_dim: int = 128
mlp_ratio: float = 4.0
num_layers: int = 20
num_single_layers: int = 40
num_refiner_layers: int = 2
rope_axes_dim: Tuple[int, int, int] = (16, 56, 56)
guidance_embeds: bool = False
dtype: Optional[torch.dtype] = None
text_embed_dim: int = 4096
pooled_projection_dim: int = 768
rope_theta: int = 256
qk_norm: str = "rms_norm"
def __post_init__(self):
self.hidden_size: int = self.attention_head_dim * self.num_attention_heads
self.num_channels_latents: int = self.in_channels
@dataclass
class HunyuanVideoConfig(DiTConfig):
arch_config: DiTArchConfig = field(default_factory=HunyuanVideoArchConfig)
prefix: str = "Hunyuan"
from dataclasses import dataclass, field
from typing import Optional, Tuple
from fastvideo.v1.configs.models.dits.base import DiTArchConfig, DiTConfig
def is_blocks(n: str, m) -> bool:
return "blocks" in n and str.isdigit(n.split(".")[-1])
@dataclass
class WanVideoArchConfig(DiTArchConfig):
_fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks])
_param_names_mapping: dict = field(
default_factory=lambda: {
r"^patch_embedding\.(.*)$":
r"patch_embedding.proj.\1",
r"^condition_embedder\.text_embedder\.linear_1\.(.*)$":
r"condition_embedder.text_embedder.fc_in.\1",
r"^condition_embedder\.text_embedder\.linear_2\.(.*)$":
r"condition_embedder.text_embedder.fc_out.\1",
r"^condition_embedder\.time_embedder\.linear_1\.(.*)$":
r"condition_embedder.time_embedder.mlp.fc_in.\1",
r"^condition_embedder\.time_embedder\.linear_2\.(.*)$":
r"condition_embedder.time_embedder.mlp.fc_out.\1",
r"^condition_embedder\.time_proj\.(.*)$":
r"condition_embedder.time_modulation.linear.\1",
r"^condition_embedder\.image_embedder\.ff\.net\.0\.proj\.(.*)$":
r"condition_embedder.image_embedder.ff.fc_in.\1",
r"^condition_embedder\.image_embedder\.ff\.net\.2\.(.*)$":
r"condition_embedder.image_embedder.ff.fc_out.\1",
r"^blocks\.(\d+)\.attn1\.to_q\.(.*)$":
r"blocks.\1.to_q.\2",
r"^blocks\.(\d+)\.attn1\.to_k\.(.*)$":
r"blocks.\1.to_k.\2",
r"^blocks\.(\d+)\.attn1\.to_v\.(.*)$":
r"blocks.\1.to_v.\2",
r"^blocks\.(\d+)\.attn1\.to_out\.0\.(.*)$":
r"blocks.\1.to_out.\2",
r"^blocks\.(\d+)\.attn1\.norm_q\.(.*)$":
r"blocks.\1.norm_q.\2",
r"^blocks\.(\d+)\.attn1\.norm_k\.(.*)$":
r"blocks.\1.norm_k.\2",
r"^blocks\.(\d+)\.attn2\.to_out\.0\.(.*)$":
r"blocks.\1.attn2.to_out.\2",
r"^blocks\.(\d+)\.ffn\.net\.0\.proj\.(.*)$":
r"blocks.\1.ffn.fc_in.\2",
r"^blocks\.(\d+)\.ffn\.net\.2\.(.*)$":
r"blocks.\1.ffn.fc_out.\2",
r"blocks\.(\d+)\.norm2\.(.*)$":
r"blocks.\1.self_attn_residual_norm.norm.\2",
})
patch_size: Tuple[int, int, int] = (1, 2, 2)
    text_len: int = 512
num_attention_heads: int = 40
attention_head_dim: int = 128
in_channels: int = 16
out_channels: int = 16
text_dim: int = 4096
freq_dim: int = 256
ffn_dim: int = 13824
num_layers: int = 40
cross_attn_norm: bool = True
qk_norm: str = "rms_norm_across_heads"
eps: float = 1e-6
image_dim: Optional[int] = None
added_kv_proj_dim: Optional[int] = None
rope_max_seq_len: int = 1024
def __post_init__(self):
self.out_channels = self.out_channels or self.in_channels
self.hidden_size = self.num_attention_heads * self.attention_head_dim
self.num_channels_latents = self.in_channels if self.added_kv_proj_dim is None else self.out_channels
@dataclass
class WanVideoConfig(DiTConfig):
arch_config: DiTArchConfig = field(default_factory=WanVideoArchConfig)
prefix: str = "Wan"
from fastvideo.v1.configs.models.encoders.base import (BaseEncoderOutput,
EncoderConfig,
ImageEncoderConfig,
TextEncoderConfig)
from fastvideo.v1.configs.models.encoders.clip import (CLIPTextConfig,
CLIPVisionConfig)
from fastvideo.v1.configs.models.encoders.llama import LlamaConfig
from fastvideo.v1.configs.models.encoders.t5 import T5Config
__all__ = [
"EncoderConfig", "TextEncoderConfig", "ImageEncoderConfig",
"BaseEncoderOutput", "CLIPTextConfig", "CLIPVisionConfig", "LlamaConfig",
"T5Config"
]
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import torch
from fastvideo.v1.configs.models.base import ArchConfig, ModelConfig
from fastvideo.v1.configs.quantization import QuantizationConfig
from fastvideo.v1.platforms import _Backend
@dataclass
class EncoderArchConfig(ArchConfig):
architectures: List[str] = field(default_factory=lambda: [])
_supported_attention_backends: Tuple[_Backend, ...] = (_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA)
output_hidden_states: bool = False
use_return_dict: bool = True
@dataclass
class TextEncoderArchConfig(EncoderArchConfig):
vocab_size: int = 0
hidden_size: int = 0
num_hidden_layers: int = 0
num_attention_heads: int = 0
pad_token_id: int = 0
eos_token_id: int = 0
text_len: int = 0
hidden_state_skip_layer: int = 0
decoder_start_token_id: int = 0
output_past: bool = True
scalable_attention: bool = True
tie_word_embeddings: bool = False
tokenizer_kwargs: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
self.tokenizer_kwargs = {
"truncation": True,
"max_length": self.text_len,
"return_tensors": "pt",
}
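# --- Illustrative note (not part of the original files): `tokenizer_kwargs` is
# shaped so it can be splatted straight into a Hugging Face tokenizer call, e.g.
# (hypothetical usage) `tokenizer(prompts, padding="max_length", **arch.tokenizer_kwargs)`,
# truncating or padding every prompt to `text_len` tokens and returning PyTorch tensors.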
@dataclass
class ImageEncoderArchConfig(EncoderArchConfig):
pass
@dataclass
class BaseEncoderOutput:
last_hidden_state: Optional[torch.FloatTensor] = None
pooler_output: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
attention_mask: Optional[torch.Tensor] = None
@dataclass
class EncoderConfig(ModelConfig):
arch_config: ArchConfig = field(default_factory=EncoderArchConfig)
prefix: str = ""
quant_config: Optional[QuantizationConfig] = None
lora_config: Optional[Any] = None
@dataclass
class TextEncoderConfig(EncoderConfig):
arch_config: ArchConfig = field(default_factory=TextEncoderArchConfig)
@dataclass
class ImageEncoderConfig(EncoderConfig):
arch_config: ArchConfig = field(default_factory=ImageEncoderArchConfig)
from dataclasses import dataclass, field
from typing import Optional
from fastvideo.v1.configs.models.encoders.base import (ImageEncoderArchConfig,
ImageEncoderConfig,
TextEncoderArchConfig,
TextEncoderConfig)
@dataclass
class CLIPTextArchConfig(TextEncoderArchConfig):
vocab_size: int = 49408
hidden_size: int = 512
intermediate_size: int = 2048
projection_dim: int = 512
num_hidden_layers: int = 12
num_attention_heads: int = 8
max_position_embeddings: int = 77
hidden_act: str = "quick_gelu"
layer_norm_eps: float = 1e-5
dropout: float = 0.0
attention_dropout: float = 0.0
initializer_range: float = 0.02
initializer_factor: float = 1.0
pad_token_id: int = 1
bos_token_id: int = 49406
eos_token_id: int = 49407
text_len: int = 77
@dataclass
class CLIPVisionArchConfig(ImageEncoderArchConfig):
hidden_size: int = 768
intermediate_size: int = 3072
projection_dim: int = 512
num_hidden_layers: int = 12
num_attention_heads: int = 12
num_channels: int = 3
image_size: int = 224
patch_size: int = 32
hidden_act: str = "quick_gelu"
layer_norm_eps: float = 1e-5
dropout: float = 0.0
attention_dropout: float = 0.0
initializer_range: float = 0.02
initializer_factor: float = 1.0
@dataclass
class CLIPTextConfig(TextEncoderConfig):
arch_config: TextEncoderArchConfig = field(
default_factory=CLIPTextArchConfig)
num_hidden_layers_override: Optional[int] = None
require_post_norm: Optional[bool] = None
prefix: str = "clip"
@dataclass
class CLIPVisionConfig(ImageEncoderConfig):
arch_config: ImageEncoderArchConfig = field(
default_factory=CLIPVisionArchConfig)
num_hidden_layers_override: Optional[int] = None
require_post_norm: Optional[bool] = None
prefix: str = "clip"
from dataclasses import dataclass, field
from typing import Optional
from fastvideo.v1.configs.models.encoders.base import (TextEncoderArchConfig,
TextEncoderConfig)
@dataclass
class LlamaArchConfig(TextEncoderArchConfig):
vocab_size: int = 32000
hidden_size: int = 4096
intermediate_size: int = 11008
num_hidden_layers: int = 32
num_attention_heads: int = 32
num_key_value_heads: Optional[int] = None
hidden_act: str = "silu"
max_position_embeddings: int = 2048
initializer_range: float = 0.02
rms_norm_eps: float = 1e-6
use_cache: bool = True
pad_token_id: int = 0
bos_token_id: int = 1
eos_token_id: int = 2
pretraining_tp: int = 1
tie_word_embeddings: bool = False
rope_theta: float = 10000.0
rope_scaling: Optional[float] = None
attention_bias: bool = False
attention_dropout: float = 0.0
mlp_bias: bool = False
head_dim: Optional[int] = None
hidden_state_skip_layer: int = 2
text_len: int = 256
@dataclass
class LlamaConfig(TextEncoderConfig):
arch_config: TextEncoderArchConfig = field(default_factory=LlamaArchConfig)
prefix: str = "llama"
from dataclasses import dataclass, field
from typing import Optional
from fastvideo.v1.configs.models.encoders.base import (TextEncoderArchConfig,
TextEncoderConfig)
@dataclass
class T5ArchConfig(TextEncoderArchConfig):
vocab_size: int = 32128
d_model: int = 512
d_kv: int = 64
d_ff: int = 2048
num_layers: int = 6
num_decoder_layers: Optional[int] = None
num_heads: int = 8
relative_attention_num_buckets: int = 32
relative_attention_max_distance: int = 128
dropout_rate: float = 0.1
layer_norm_epsilon: float = 1e-6
initializer_factor: float = 1.0
feed_forward_proj: str = "relu"
dense_act_fn: str = ""
is_gated_act: bool = False
is_encoder_decoder: bool = True
use_cache: bool = True
pad_token_id: int = 0
eos_token_id: int = 1
classifier_dropout: float = 0.0
text_len: int = 512
# Referenced from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/configuration_t5.py
def __post_init__(self):
super().__post_init__()
act_info = self.feed_forward_proj.split("-")
self.dense_act_fn: str = act_info[-1]
self.is_gated_act: bool = act_info[0] == "gated"
if self.feed_forward_proj == "gated-gelu":
self.dense_act_fn = "gelu_new"
self.tokenizer_kwargs = {
"padding": "max_length",
"truncation": True,
"max_length": self.text_len,
"add_special_tokens": True,
"return_attention_mask": True,
"return_tensors": "pt",
}
@dataclass
class T5Config(TextEncoderConfig):
arch_config: TextEncoderArchConfig = field(default_factory=T5ArchConfig)
prefix: str = "t5"
from fastvideo.v1.configs.models.vaes.hunyuanvae import HunyuanVAEConfig
from fastvideo.v1.configs.models.vaes.wanvae import WanVAEConfig
__all__ = [
"HunyuanVAEConfig",
"WanVAEConfig",
]
from dataclasses import dataclass, field
from typing import Union
import torch
from fastvideo.v1.configs.models.base import ArchConfig, ModelConfig
@dataclass
class VAEArchConfig(ArchConfig):
scaling_factor: Union[float, torch.tensor] = 0
temporal_compression_ratio: int = 4
spatial_compression_ratio: int = 8
@dataclass
class VAEConfig(ModelConfig):
arch_config: VAEArchConfig = field(default_factory=VAEArchConfig)
# FastVideoVAE-specific parameters
load_encoder: bool = True
load_decoder: bool = True
tile_sample_min_height: int = 256
tile_sample_min_width: int = 256
tile_sample_min_num_frames: int = 16
tile_sample_stride_height: int = 192
tile_sample_stride_width: int = 192
tile_sample_stride_num_frames: int = 12
blend_num_frames: int = 0
use_tiling: bool = True
use_temporal_tiling: bool = True
use_parallel_tiling: bool = True
def __post_init__(self):
self.blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
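# --- Illustrative note (not part of the original files): with the defaults above
# this works out to blend_num_frames = 16 - 12 = 4 overlapping frames between
# temporal tiles; WanVAEConfig further below overrides __post_init__ to double it.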
from dataclasses import dataclass, field
from typing import Tuple
from fastvideo.v1.configs.models.vaes.base import VAEArchConfig, VAEConfig
@dataclass
class HunyuanVAEArchConfig(VAEArchConfig):
in_channels: int = 3
out_channels: int = 3
latent_channels: int = 16
down_block_types: Tuple[str, ...] = (
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
)
up_block_types: Tuple[str, ...] = (
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
)
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512)
layers_per_block: int = 2
act_fn: str = "silu"
norm_num_groups: int = 32
scaling_factor: float = 0.476986
spatial_compression_ratio: int = 8
temporal_compression_ratio: int = 4
mid_block_add_attention: bool = True
def __post_init__(self):
self.spatial_compression_ratio: int = 2**(len(self.block_out_channels) -
1)
@dataclass
class HunyuanVAEConfig(VAEConfig):
arch_config: VAEArchConfig = field(default_factory=HunyuanVAEArchConfig)
from dataclasses import dataclass, field
from typing import Tuple
import torch
from fastvideo.v1.configs.models.vaes.base import VAEArchConfig, VAEConfig
@dataclass
class WanVAEArchConfig(VAEArchConfig):
base_dim: int = 96
z_dim: int = 16
dim_mult: Tuple[int, ...] = (1, 2, 4, 4)
num_res_blocks: int = 2
attn_scales: Tuple[float, ...] = ()
temperal_downsample: Tuple[bool, ...] = (False, True, True)
dropout: float = 0.0
latents_mean: Tuple[float, ...] = (
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
)
latents_std: Tuple[float, ...] = (
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
)
    temporal_compression_ratio: int = 4
    spatial_compression_ratio: int = 8
def __post_init__(self):
self.scaling_factor: torch.tensor = 1.0 / torch.tensor(
self.latents_std).view(1, self.z_dim, 1, 1, 1)
self.shift_factor: torch.tensor = torch.tensor(self.latents_mean).view(
1, self.z_dim, 1, 1, 1)
@dataclass
class WanVAEConfig(VAEConfig):
arch_config: VAEArchConfig = field(default_factory=WanVAEArchConfig)
use_feature_cache: bool = True
use_tiling: bool = False
use_temporal_tiling: bool = False
use_parallel_tiling: bool = False
def __post_init__(self):
self.blend_num_frames = (self.tile_sample_min_num_frames -
self.tile_sample_stride_num_frames) * 2
from fastvideo.v1.configs.pipelines.base import (PipelineConfig,
SlidingTileAttnConfig)
from fastvideo.v1.configs.pipelines.hunyuan import (FastHunyuanConfig,
HunyuanConfig)
from fastvideo.v1.configs.pipelines.registry import (
get_pipeline_config_cls_for_name)
from fastvideo.v1.configs.pipelines.wan import (WanI2V480PConfig,
WanT2V480PConfig)
__all__ = [
"HunyuanConfig", "FastHunyuanConfig", "PipelineConfig",
"SlidingTileAttnConfig", "WanT2V480PConfig", "WanI2V480PConfig",
"get_pipeline_config_cls_for_name"
]
import json
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Callable, Dict, Optional, Tuple
import torch
from fastvideo.v1.configs.models import (DiTConfig, EncoderConfig, ModelConfig,
VAEConfig)
from fastvideo.v1.configs.models.encoders import BaseEncoderOutput
from fastvideo.v1.logger import init_logger
from fastvideo.v1.utils import shallow_asdict
logger = init_logger(__name__)
def preprocess_text(prompt: str) -> str:
return prompt
def postprocess_text(output: BaseEncoderOutput) -> torch.tensor:
raise NotImplementedError
@dataclass
class PipelineConfig:
"""Base configuration for all pipeline architectures."""
# Video generation parameters
embedded_cfg_scale: float = 6.0
flow_shift: Optional[float] = None
use_cpu_offload: bool = False
disable_autocast: bool = False
# Model configuration
precision: str = "bf16"
# VAE configuration
vae_precision: str = "fp16"
vae_tiling: bool = True
vae_sp: bool = True
vae_config: VAEConfig = field(default_factory=VAEConfig)
# DiT configuration
dit_config: DiTConfig = field(default_factory=DiTConfig)
# Text encoder configuration
text_encoder_precisions: Tuple[str, ...] = field(
default_factory=lambda: ("fp16", ))
text_encoder_configs: Tuple[EncoderConfig, ...] = field(
default_factory=lambda: (EncoderConfig(), ))
preprocess_text_funcs: Tuple[Callable[[str], str], ...] = field(
default_factory=lambda: (preprocess_text, ))
postprocess_text_funcs: Tuple[Callable[[BaseEncoderOutput], torch.tensor],
...] = field(default_factory=lambda:
(postprocess_text, ))
# STA (Spatial-Temporal Attention) parameters
mask_strategy_file_path: Optional[str] = None
enable_torch_compile: bool = False
@classmethod
def from_pretrained(cls, model_path: str) -> "PipelineConfig":
from fastvideo.v1.configs.pipelines.registry import (
get_pipeline_config_cls_for_name)
pipeline_config_cls = get_pipeline_config_cls_for_name(model_path)
if pipeline_config_cls is not None:
pipeline_config = pipeline_config_cls()
else:
            logger.warning(
                "Couldn't find an optimal pipeline config for %s. Using the default pipeline config.",
                model_path)
pipeline_config = cls()
return pipeline_config
def dump_to_json(self, file_path: str):
output_dict = shallow_asdict(self)
del_keys = []
for key, value in output_dict.items():
if isinstance(value, ModelConfig):
model_dict = asdict(value)
# Model Arch Config should be hidden away from the users
model_dict.pop("arch_config")
output_dict[key] = model_dict
elif isinstance(value, tuple) and all(
isinstance(v, ModelConfig) for v in value):
model_dicts = []
for v in value:
model_dict = asdict(v)
# Model Arch Config should be hidden away from the users
model_dict.pop("arch_config")
model_dicts.append(model_dict)
output_dict[key] = model_dicts
elif isinstance(value, tuple) and all(callable(f) for f in value):
# Skip dumping functions
del_keys.append(key)
for key in del_keys:
output_dict.pop(key, None)
with open(file_path, "w") as f:
json.dump(output_dict, f, indent=2)
def load_from_json(self, file_path: str):
with open(file_path) as f:
input_pipeline_dict = json.load(f)
self.update_pipeline_config(input_pipeline_dict)
def update_pipeline_config(self, source_pipeline_dict: Dict[str,
Any]) -> None:
for f in fields(self):
key = f.name
if key in source_pipeline_dict:
current_value = getattr(self, key)
new_value = source_pipeline_dict[key]
# If it's a nested ModelConfig, update it recursively
if isinstance(current_value, ModelConfig):
current_value.update_model_config(new_value)
elif isinstance(current_value, tuple) and all(
isinstance(v, ModelConfig) for v in current_value):
                    assert len(current_value) == len(new_value), (
                        "The JSON should not add or remove nested model config "
                        "entries; it must match the pipeline's configs one-to-one")
for target_config, source_config in zip(
current_value, new_value):
target_config.update_model_config(source_config)
else:
setattr(self, key, new_value)
if hasattr(self, "__post_init__"):
self.__post_init__()
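# --- Illustrative sketch (not part of the original files): typical round trip.
# The model name matches an entry in the pipeline registry below; the file path
# is hypothetical. Nested ModelConfig objects are updated field-by-field on load
# rather than replaced, and callables/arch_config fields are never serialized.
#
#   config = PipelineConfig.from_pretrained("hunyuanvideo-community/HunyuanVideo")
#   config.flow_shift = 7.0
#   config.dump_to_json("pipeline_config.json")
#   config.load_from_json("pipeline_config.json")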
@dataclass
class SlidingTileAttnConfig(PipelineConfig):
"""Configuration for sliding tile attention."""
# Override any BaseConfig defaults as needed
# Add sliding tile specific parameters
window_size: int = 16
stride: int = 8
# You can provide custom defaults for inherited fields
height: int = 576
width: int = 1024
# Additional configuration specific to sliding tile attention
pad_to_square: bool = False
use_overlap_optimization: bool = True
from dataclasses import dataclass, field
from typing import Callable, Tuple, TypedDict
import torch
from fastvideo.v1.configs.models import DiTConfig, EncoderConfig, VAEConfig
from fastvideo.v1.configs.models.dits import HunyuanVideoConfig
from fastvideo.v1.configs.models.encoders import (BaseEncoderOutput,
CLIPTextConfig, LlamaConfig)
from fastvideo.v1.configs.models.vaes import HunyuanVAEConfig
from fastvideo.v1.configs.pipelines.base import PipelineConfig
PROMPT_TEMPLATE_ENCODE_VIDEO = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
class PromptTemplate(TypedDict):
template: str
crop_start: int
prompt_template_video: PromptTemplate = {
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
"crop_start": 95,
}
def llama_preprocess_text(prompt: str) -> str:
return prompt_template_video["template"].format(prompt)
def llama_postprocess_text(outputs: BaseEncoderOutput) -> torch.tensor:
hidden_state_skip_layer = 2
assert outputs.hidden_states is not None
hidden_states: Tuple[torch.Tensor, ...] = outputs.hidden_states
last_hidden_state: torch.tensor = hidden_states[-(hidden_state_skip_layer +
1)]
crop_start = prompt_template_video.get("crop_start", -1)
last_hidden_state = last_hidden_state[:, crop_start:]
return last_hidden_state
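# --- Illustrative note (not part of the original files): with
# hidden_state_skip_layer = 2 the indexing above selects hidden_states[-3] (two
# layers before the final hidden state), and crop_start = 95 drops the tokens
# contributed by the system/prompt template so only the user prompt's embeddings
# are returned.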
def clip_preprocess_text(prompt: str) -> str:
return prompt
def clip_postprocess_text(outputs: BaseEncoderOutput) -> torch.tensor:
pooler_output: torch.tensor = outputs.pooler_output
return pooler_output
@dataclass
class HunyuanConfig(PipelineConfig):
"""Base configuration for HunYuan pipeline architecture."""
# HunyuanConfig-specific parameters with defaults
# DiT
dit_config: DiTConfig = field(default_factory=HunyuanVideoConfig)
# VAE
vae_config: VAEConfig = field(default_factory=HunyuanVAEConfig)
# Denoising stage
embedded_cfg_scale: int = 6
flow_shift: int = 7
# Text encoding stage
text_encoder_configs: Tuple[EncoderConfig, ...] = field(
default_factory=lambda: (LlamaConfig(), CLIPTextConfig()))
preprocess_text_funcs: Tuple[Callable[[str], str], ...] = field(
default_factory=lambda: (llama_preprocess_text, clip_preprocess_text))
postprocess_text_funcs: Tuple[
Callable[[BaseEncoderOutput], torch.tensor],
...] = field(default_factory=lambda:
(llama_postprocess_text, clip_postprocess_text))
# Precision for each component
precision: str = "bf16"
vae_precision: str = "fp16"
text_encoder_precisions: Tuple[str, ...] = field(
default_factory=lambda: ("fp16", "fp16"))
def __post_init__(self):
self.vae_config.load_encoder = False
self.vae_config.load_decoder = True
@dataclass
class FastHunyuanConfig(HunyuanConfig):
"""Configuration specifically optimized for FastHunyuan weights."""
# Override HunyuanConfig defaults
flow_shift: int = 17
# No need to re-specify guidance_scale or embedded_cfg_scale as they
# already have the desired values from HunyuanConfig
"""Registry for pipeline weight-specific configurations."""
import os
from typing import Callable, Dict, Optional, Type
from fastvideo.v1.configs.pipelines.base import PipelineConfig
from fastvideo.v1.configs.pipelines.hunyuan import (FastHunyuanConfig,
HunyuanConfig)
from fastvideo.v1.configs.pipelines.wan import (WanI2V480PConfig,
WanT2V480PConfig)
from fastvideo.v1.logger import init_logger
from fastvideo.v1.utils import (maybe_download_model_index,
verify_model_config_and_directory)
logger = init_logger(__name__)
# Registry maps specific model weights to their config classes
WEIGHT_CONFIG_REGISTRY: Dict[str, Type[PipelineConfig]] = {
"FastVideo/FastHunyuan-diffusers": FastHunyuanConfig,
"hunyuanvideo-community/HunyuanVideo": HunyuanConfig,
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers": WanT2V480PConfig,
"Wan-AI/Wan2.1-I2V-14B-480P-Diffusers": WanI2V480PConfig
# Add other specific weight variants
}
# For determining pipeline type from model ID
PIPELINE_DETECTOR: Dict[str, Callable[[str], bool]] = {
"hunyuan": lambda id: "hunyuan" in id.lower(),
"wanpipeline": lambda id: "wanpipeline" in id.lower(),
"wanimagetovideo": lambda id: "wanimagetovideo" in id.lower(),
# Add other pipeline architecture detectors
}
# Fallback configs when exact match isn't found but architecture is detected
PIPELINE_FALLBACK_CONFIG: Dict[str, Type[PipelineConfig]] = {
"hunyuan":
HunyuanConfig, # Base Hunyuan config as fallback for any Hunyuan variant
"wanpipeline":
WanT2V480PConfig, # Base Wan config as fallback for any Wan variant
"wanimagetovideo": WanI2V480PConfig,
# Other fallbacks by architecture
}
def get_pipeline_config_cls_for_name(
pipeline_name_or_path: str) -> Optional[type[PipelineConfig]]:
"""Get the appropriate config class for specific pretrained weights."""
if os.path.exists(pipeline_name_or_path):
config = verify_model_config_and_directory(pipeline_name_or_path)
logger.warning(
"FastVideo may not correctly identify the optimal config for this model, as the local directory may have been renamed."
)
else:
config = maybe_download_model_index(pipeline_name_or_path)
pipeline_name = config["_class_name"]
# First try exact match for specific weights
if pipeline_name_or_path in WEIGHT_CONFIG_REGISTRY:
return WEIGHT_CONFIG_REGISTRY[pipeline_name_or_path]
# Try partial matches (for local paths that might include the weight ID)
for registered_id, config_class in WEIGHT_CONFIG_REGISTRY.items():
if registered_id in pipeline_name_or_path:
return config_class
# If no match, try to use the fallback config
fallback_config = None
# Try to determine pipeline architecture for fallback
for pipeline_type, detector in PIPELINE_DETECTOR.items():
if detector(pipeline_name.lower()):
fallback_config = PIPELINE_FALLBACK_CONFIG.get(pipeline_type)
break
logger.warning("No match found for pipeline %s, using fallback config %s.",
pipeline_name_or_path, fallback_config)
return fallback_config
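# --- Illustrative note (not part of the original files): resolution order is
# exact weight ID -> substring match (covers local paths that still contain the
# weight ID) -> architecture fallback inferred from the model index's
# "_class_name". E.g. "FastVideo/FastHunyuan-diffusers" resolves directly to
# FastHunyuanConfig, while an unrecognized Hunyuan checkpoint falls back to
# HunyuanConfig.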
from dataclasses import dataclass, field
from typing import Callable, Tuple
import torch
from fastvideo.v1.configs.models import DiTConfig, EncoderConfig, VAEConfig
from fastvideo.v1.configs.models.dits import WanVideoConfig
from fastvideo.v1.configs.models.encoders import (BaseEncoderOutput,
CLIPVisionConfig, T5Config)
from fastvideo.v1.configs.models.vaes import WanVAEConfig
from fastvideo.v1.configs.pipelines.base import PipelineConfig
def t5_postprocess_text(outputs: BaseEncoderOutput) -> torch.tensor:
mask: torch.tensor = outputs.attention_mask
hidden_state: torch.tensor = outputs.last_hidden_state
seq_lens = mask.gt(0).sum(dim=1).long()
assert torch.isnan(hidden_state).sum() == 0
prompt_embeds = [u[:v] for u, v in zip(hidden_state, seq_lens)]
prompt_embeds_tensor: torch.tensor = torch.stack([
torch.cat([u, u.new_zeros(512 - u.size(0), u.size(1))])
for u in prompt_embeds
],
dim=0)
return prompt_embeds_tensor
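# --- Illustrative note (not part of the original files): t5_postprocess_text trims
# each sequence to its attention-mask length, then zero-pads it back to 512 tokens
# (matching T5ArchConfig.text_len), so the result has shape [batch, 512, hidden_dim]
# with exact zeros in the padded positions rather than whatever the encoder emitted.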
@dataclass
class WanT2V480PConfig(PipelineConfig):
"""Base configuration for Wan T2V 1.3B pipeline architecture."""
# WanConfig-specific parameters with defaults
# DiT
dit_config: DiTConfig = field(default_factory=WanVideoConfig)
# VAE
vae_config: VAEConfig = field(default_factory=WanVAEConfig)
vae_tiling: bool = False
vae_sp: bool = False
# Video parameters
use_cpu_offload: bool = True
# Denoising stage
flow_shift: int = 3
# Text encoding stage
text_encoder_configs: Tuple[EncoderConfig, ...] = field(
default_factory=lambda: (T5Config(), ))
postprocess_text_funcs: Tuple[Callable[[BaseEncoderOutput], torch.tensor],
...] = field(default_factory=lambda:
(t5_postprocess_text, ))
# Precision for each component
precision: str = "bf16"
vae_precision: str = "fp16"
text_encoder_precisions: Tuple[str, ...] = field(
default_factory=lambda: ("fp32", ))
# WanConfig-specific added parameters
def __post_init__(self):
self.vae_config.load_encoder = False
self.vae_config.load_decoder = True
@dataclass
class WanI2V480PConfig(WanT2V480PConfig):
"""Base configuration for Wan I2V 14B 480P pipeline architecture."""
# WanConfig-specific parameters with defaults
# Precision for each component
image_encoder_config: EncoderConfig = field(
default_factory=CLIPVisionConfig)
image_encoder_precision: str = "fp32"
def __post_init__(self):
self.vae_config.load_encoder = True
self.vae_config.load_decoder = True
from fastvideo.v1.configs.quantization.base import QuantizationConfig
__all__ = ["QuantizationConfig"]
from dataclasses import dataclass
@dataclass
class QuantizationConfig:
pass
from fastvideo.v1.configs.sample.base import SamplingParam
__all__ = ["SamplingParam"]
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from fastvideo.v1.logger import init_logger
logger = init_logger(__name__)
@dataclass
class SamplingParam:
# All fields below are copied from ForwardBatch
data_type: str = "video"
# Image inputs
image_path: Optional[str] = None
# Text inputs
prompt: Optional[Union[str, List[str]]] = None
negative_prompt: Optional[str] = None
prompt_path: Optional[str] = None
output_path: str = "outputs/"
# Batch info
num_videos_per_prompt: int = 1
seed: int = 1024
# Original dimensions (before VAE scaling)
num_frames: int = 125
height: int = 720
width: int = 1280
fps: int = 24
# Denoising parameters
num_inference_steps: int = 50
guidance_scale: float = 1.0
guidance_rescale: float = 0.0
# Misc
save_video: bool = True
return_frames: bool = False
def __post_init__(self) -> None:
self.data_type = "video" if self.num_frames > 1 else "image"
def check_sampling_param(self):
if self.prompt_path and not self.prompt_path.endswith(".txt"):
raise ValueError("prompt_path must be a txt file")
def update(self, source_dict: Dict[str, Any]) -> None:
for key, value in source_dict.items():
if hasattr(self, key):
setattr(self, key, value)
else:
logger.exception("%s has no attribute %s",
type(self).__name__, key)
self.__post_init__()
@classmethod
def from_pretrained(cls, model_path: str) -> "SamplingParam":
from fastvideo.v1.configs.sample.registry import (
get_sampling_param_cls_for_name)
sampling_cls = get_sampling_param_cls_for_name(model_path)
if sampling_cls is not None:
sampling_param: SamplingParam = sampling_cls()
else:
logger.warning(
"Couldn't find an optimal sampling param for %s. Using the default sampling param.",
model_path)
sampling_param = cls()
return sampling_param
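# --- Illustrative sketch (not part of the original files): typical usage, with a
# hypothetical model name.
#
#   param = SamplingParam.from_pretrained("some-org/some-video-model")
#   param.update({"num_frames": 1, "height": 512, "width": 512})
#   # update() re-runs __post_init__, so data_type becomes "image" for a single frame
#   param.check_sampling_param()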