from typing import Iterator import torch from lightx2v.models.video_encoders.hf.ltx2.audio_vae.audio_vae import AudioDecoder, AudioEncoder, decode_audio from lightx2v.models.video_encoders.hf.ltx2.audio_vae.model_configurator import ( AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER, VOCODER_COMFY_KEYS_FILTER, AudioDecoderConfigurator, AudioEncoderConfigurator, VocoderConfigurator, ) from lightx2v.models.video_encoders.hf.ltx2.audio_vae.vocoder import Vocoder from lightx2v.models.video_encoders.hf.ltx2.upsampler.model import LatentUpsamplerConfigurator from lightx2v.models.video_encoders.hf.ltx2.video_vae.model_configurator import ( VAE_DECODER_COMFY_KEYS_FILTER, VAE_ENCODER_COMFY_KEYS_FILTER, VideoDecoderConfigurator, VideoEncoderConfigurator, ) from lightx2v.models.video_encoders.hf.ltx2.video_vae.tiling import TilingConfig from lightx2v.models.video_encoders.hf.ltx2.video_vae.video_vae import VideoDecoder, VideoEncoder, decode_video from lightx2v.utils.ltx2_media_io import * from lightx2v.utils.ltx2_utils import * from lightx2v_platform.base.global_var import AI_DEVICE torch_device_module = getattr(torch, AI_DEVICE) class LTX2VideoVAE: def __init__( self, checkpoint_path: str, device: torch.device, dtype: torch.dtype = torch.bfloat16, load_encoder: bool = True, use_tiling: bool = False, cpu_offload: bool = False, ): self.checkpoint_path = checkpoint_path self.device = device self.dtype = dtype self.load_encoder_flag = load_encoder self.use_tiling = use_tiling self.loader = SafetensorsModelStateDictLoader() self.encoder = None self.decoder = None self.cpu_offload = cpu_offload self.grid_table = {} # Cache for 2D grid calculations self.load() def load(self) -> tuple[VideoEncoder | None, VideoDecoder | None]: config = self.loader.metadata(self.checkpoint_path) if self.load_encoder_flag: encoder = VideoEncoderConfigurator.from_config(config) state_dict_obj = self.loader.load( self.checkpoint_path, sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER, device=self.device, ) state_dict = state_dict_obj.sd if self.dtype is not None: state_dict = {key: value.to(dtype=self.dtype) for key, value in state_dict.items()} encoder.load_state_dict(state_dict, strict=False, assign=True) self.encoder = encoder.to(self.device).eval() decoder = VideoDecoderConfigurator.from_config(config) state_dict_obj = self.loader.load( self.checkpoint_path, sd_ops=VAE_DECODER_COMFY_KEYS_FILTER, device=self.device, ) state_dict = state_dict_obj.sd if self.dtype is not None: state_dict = {key: value.to(dtype=self.dtype) for key, value in state_dict.items()} decoder.load_state_dict(state_dict, strict=False, assign=True) self.decoder = decoder.to(self.device).eval() def encode(self, video_frames: torch.Tensor) -> torch.Tensor: """ Encode video frames to latent space. Args: video_frames: Input video tensor [1, C, T, H, W] or [C, T, H, W] Returns: Encoded latent tensor [C, F, H_latent, W_latent] """ # Ensure video has batch dimension if video_frames.dim() == 4: video_frames = video_frames.unsqueeze(0) if self.cpu_offload: self.encoder = self.encoder.to(AI_DEVICE) out = self.encoder(video_frames) if out.dim() == 5: out = out.squeeze(0) if self.cpu_offload: self.encoder = self.encoder.to("cpu") return out def decode( self, latent: torch.Tensor, tiling_config: TilingConfig | None = None, generator: torch.Generator | None = None, ) -> Iterator[torch.Tensor]: # 如果启用了tiling但没有提供配置,使用默认配置 if self.use_tiling and tiling_config is None: tiling_config = TilingConfig.default() if self.cpu_offload: self.decoder = self.decoder.to(AI_DEVICE) try: yield from decode_video(latent, self.decoder, tiling_config, generator) finally: if self.cpu_offload: self.decoder = self.decoder.to("cpu") class LTX2AudioVAE: def __init__( self, checkpoint_path: str, device: torch.device, dtype: torch.dtype = torch.bfloat16, cpu_offload: bool = False, ): self.checkpoint_path = checkpoint_path self.device = device self.dtype = dtype self.cpu_offload = cpu_offload self.loader = SafetensorsModelStateDictLoader() self.load() def load(self) -> tuple[AudioEncoder | None, AudioDecoder | None, Vocoder | None]: config = self.loader.metadata(self.checkpoint_path) encoder = AudioEncoderConfigurator.from_config(config) state_dict_obj = self.loader.load( self.checkpoint_path, sd_ops=AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER, device=self.device, ) state_dict = state_dict_obj.sd if self.dtype is not None: state_dict = {key: value.to(dtype=self.dtype) for key, value in state_dict.items()} encoder.load_state_dict(state_dict, strict=False, assign=True) self.encoder = encoder.to(self.device).eval() decoder = AudioDecoderConfigurator.from_config(config) state_dict_obj = self.loader.load( self.checkpoint_path, sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, device=self.device, ) state_dict = state_dict_obj.sd if self.dtype is not None: state_dict = {key: value.to(dtype=self.dtype) for key, value in state_dict.items()} decoder.load_state_dict(state_dict, strict=False, assign=True) self.decoder = decoder.to(self.device).eval() vocoder = VocoderConfigurator.from_config(config) state_dict_obj = self.loader.load( self.checkpoint_path, sd_ops=VOCODER_COMFY_KEYS_FILTER, device=self.device, ) state_dict = state_dict_obj.sd if self.dtype is not None: state_dict = {key: value.to(dtype=self.dtype) for key, value in state_dict.items()} vocoder.load_state_dict(state_dict, strict=False, assign=True) self.vocoder = vocoder.to(self.device).eval() return encoder, decoder, vocoder def encode(self, audio_spectrogram: torch.Tensor) -> torch.Tensor: if self.cpu_offload: self.encoder = self.encoder.to(AI_DEVICE) out = self.encoder(audio_spectrogram) if self.cpu_offload: self.encoder = self.encoder.to("cpu") return out def decode(self, latent: torch.Tensor) -> torch.Tensor: if self.cpu_offload: self.decoder = self.decoder.to(AI_DEVICE) self.vocoder = self.vocoder.to(AI_DEVICE) out = decode_audio(latent, self.decoder, self.vocoder) if self.cpu_offload: self.decoder = self.decoder.to("cpu") self.vocoder = self.vocoder.to("cpu") return out class LTX2Upsampler: """ Wrapper class for loading and using LatentUpsampler model, similar to LTX2VideoVAE. """ def __init__( self, checkpoint_path: str, device: torch.device, dtype: torch.dtype = torch.bfloat16, cpu_offload: bool = False, ): self.checkpoint_path = checkpoint_path self.device = device self.dtype = dtype self.cpu_offload = cpu_offload self.loader = None self.upsampler = None self.load() def load(self): """ Load upsampler model from checkpoint. Aligned exactly with Builder.build() in ltx_core.loader.single_gpu_model_builder: 1. Create model on meta device (aligned with Builder.meta_model) 2. Load state_dict and convert dtype if needed (aligned with Builder.build line 82-83) 3. Load state_dict with assign=True (aligned with Builder.build line 84) 4. Move to device only (aligned with Builder._return_model line 69) Key point: _return_model only does .to(device), NOT .to(dtype). This means we rely on assign=True to set correct dtype from state_dict. """ self.loader = SafetensorsModelStateDictLoader() config = self.loader.metadata(self.checkpoint_path) # Handle config format: may have rational_spatial_scale instead of spatial_scale if "rational_spatial_scale" in config and "spatial_scale" not in config: config["spatial_scale"] = config["rational_spatial_scale"] # Create model on meta device (aligned with Builder.meta_model line 47-48) with torch.device("meta"): upsampler = LatentUpsamplerConfigurator.from_config(config) # Load state_dict (aligned with Builder.load_sd) state_dict_obj = self.loader.load( self.checkpoint_path, sd_ops=None, # No key filtering, aligned with source code device=self.device, # Directly to target device (aligned with DummyRegistry case) ) state_dict = state_dict_obj.sd # Convert state_dict dtype if needed (aligned with Builder.build line 82-83) if self.dtype is not None: state_dict = {key: value.to(dtype=self.dtype) for key, value in state_dict.items()} # Load state_dict with assign=True (aligned with Builder.build line 84) # assign=True directly replaces parameters, so dtype should match state_dict upsampler.load_state_dict(state_dict, strict=False, assign=True) # Move to device only (aligned with Builder._return_model line 69) # CRITICAL: _return_model only does .to(device), NOT .to(dtype) # This means we rely on assign=True to have set correct dtype from state_dict # If state_dict contains all parameters, they should already have correct dtype self.upsampler = upsampler.to(self.device).eval() return self.upsampler @torch.no_grad() def upsample( self, latent: torch.Tensor, video_encoder: VideoEncoder, ) -> torch.Tensor: """ Upsample video latent using the upsampler with proper normalization. Aligned with ltx_core.model.upsampler.model.upsample_video. This method directly calls the static upsample_video method to ensure exact alignment with source code implementation. Args: latent: Input latent tensor of shape [B, C, F, H, W] or [C, F, H, W]. video_encoder: VideoEncoder with per_channel_statistics for normalization. Returns: Upsampled latent tensor of shape [B, C, F, H*2, W*2] or [C, F, H*2, W*2]. """ if self.cpu_offload: self.upsampler = self.upsampler.to(AI_DEVICE) upsampled = self.upsample_video(latent, video_encoder, self.upsampler) if self.cpu_offload: self.upsampler = self.upsampler.to("cpu") return upsampled @staticmethod def upsample_video(latent: torch.Tensor, video_encoder: VideoEncoder, upsampler) -> torch.Tensor: """ Apply upsampling to the latent representation using the provided upsampler, with normalization and un-normalization based on the video encoder's per-channel statistics. This is a static method that can be used with any upsampler instance and video encoder. Aligned with ltx_core.model.upsampler.model.upsample_video. Args: latent: Input latent tensor of shape [B, C, F, H, W]. video_encoder: VideoEncoder with per_channel_statistics for normalization. upsampler: LatentUpsampler module to perform upsampling. Note: upsampler should already be in eval mode, on correct device, and with correct dtype. Returns: torch.Tensor: Upsampled and re-normalized latent tensor. """ # Aligned with source code: un_normalize -> upsampler -> normalize # Source code does not modify upsampler state, so we call it directly latent = video_encoder.per_channel_statistics.un_normalize(latent) latent = upsampler(latent) latent = video_encoder.per_channel_statistics.normalize(latent) return latent if __name__ == "__main__": dev = "cuda" dtype = torch.bfloat16 video_vae = LTX2VideoVAE( checkpoint_path="/data/nvme0/gushiqiao/models/official_models/LTX-2/ltx-2-19b-distilled-fp8.safetensors", device=dev, dtype=dtype, ) audio_vae = LTX2AudioVAE( checkpoint_path="/data/nvme0/gushiqiao/models/official_models/LTX-2/ltx-2-19b-distilled-fp8.safetensors", device=dev, dtype=dtype, ) vid_enc = torch.load("/data/nvme0/gushiqiao/models/code/LightX2V/scripts/v.pth").unsqueeze(0) vid_dec = video_vae.decode(vid_enc) audio_enc = torch.load("/data/nvme0/gushiqiao/models/code/LightX2V/scripts/a.pth").unsqueeze(0) audio_dec = audio_vae.decode(audio_enc) encode_video( video=vid_dec, fps=24, audio=audio_dec, audio_sample_rate=24000, output_path=f"reconstructed_1.mp4", video_chunks_number=1, )