model.py 13.3 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
from typing import Iterator

import torch

from lightx2v.models.video_encoders.hf.ltx2.audio_vae.audio_vae import AudioDecoder, AudioEncoder, decode_audio
from lightx2v.models.video_encoders.hf.ltx2.audio_vae.model_configurator import (
    AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
    AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
    VOCODER_COMFY_KEYS_FILTER,
    AudioDecoderConfigurator,
    AudioEncoderConfigurator,
    VocoderConfigurator,
)
from lightx2v.models.video_encoders.hf.ltx2.audio_vae.vocoder import Vocoder
from lightx2v.models.video_encoders.hf.ltx2.upsampler.model import LatentUpsamplerConfigurator
from lightx2v.models.video_encoders.hf.ltx2.video_vae.model_configurator import (
    VAE_DECODER_COMFY_KEYS_FILTER,
    VAE_ENCODER_COMFY_KEYS_FILTER,
    VideoDecoderConfigurator,
    VideoEncoderConfigurator,
)
from lightx2v.models.video_encoders.hf.ltx2.video_vae.tiling import TilingConfig
from lightx2v.models.video_encoders.hf.ltx2.video_vae.video_vae import VideoDecoder, VideoEncoder, decode_video
from lightx2v.utils.ltx2_media_io import *
from lightx2v.utils.ltx2_utils import *
from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)


class LTX2VideoVAE:
    """Loads the LTX-2 video VAE (encoder and decoder) from a safetensors checkpoint
    and exposes encode/decode helpers with optional CPU offloading and tiled decoding."""

    def __init__(
        self,
        checkpoint_path: str,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        load_encoder: bool = True,
        use_tiling: bool = False,
        cpu_offload: bool = False,
    ):
        """
        Args:
            checkpoint_path: Path to the safetensors checkpoint with the VAE weights.
            device: Target device the modules are loaded onto.
            dtype: Dtype the state dict is cast to; None keeps the checkpoint dtype.
            load_encoder: If False, only the decoder is loaded (saves memory).
            use_tiling: If True, decode() falls back to TilingConfig.default() when
                no explicit tiling_config is supplied.
            cpu_offload: If True, modules are moved to AI_DEVICE only while in use
                and parked on CPU otherwise.
        """
        self.checkpoint_path = checkpoint_path
        self.device = device
        self.dtype = dtype
        self.load_encoder_flag = load_encoder
        self.use_tiling = use_tiling
        self.loader = SafetensorsModelStateDictLoader()
        self.encoder = None
        self.decoder = None
        self.cpu_offload = cpu_offload
        self.grid_table = {}  # Cache for 2D grid calculations
        self.load()

    def _load_state_dict(self, sd_ops) -> dict:
        """Load a key-filtered state dict from the checkpoint, cast to self.dtype if set."""
        state_dict = self.loader.load(
            self.checkpoint_path,
            sd_ops=sd_ops,
            device=self.device,
        ).sd
        if self.dtype is not None:
            state_dict = {key: value.to(dtype=self.dtype) for key, value in state_dict.items()}
        return state_dict

    def load(self) -> tuple[VideoEncoder | None, VideoDecoder | None]:
        """Instantiate encoder/decoder from checkpoint metadata and load their weights.

        Returns:
            (encoder, decoder); encoder is None when load_encoder was False.
        """
        config = self.loader.metadata(self.checkpoint_path)

        if self.load_encoder_flag:
            encoder = VideoEncoderConfigurator.from_config(config)
            encoder.load_state_dict(self._load_state_dict(VAE_ENCODER_COMFY_KEYS_FILTER), strict=False, assign=True)
            self.encoder = encoder.to(self.device).eval()

        decoder = VideoDecoderConfigurator.from_config(config)
        decoder.load_state_dict(self._load_state_dict(VAE_DECODER_COMFY_KEYS_FILTER), strict=False, assign=True)
        self.decoder = decoder.to(self.device).eval()

        # BUGFIX: the signature promises (encoder, decoder) but the method previously
        # fell off the end and returned None.
        return self.encoder, self.decoder

    def encode(self, video_frames: torch.Tensor) -> torch.Tensor:
        """
        Encode video frames to latent space.
        Args:
            video_frames: Input video tensor [1, C, T, H, W] or [C, T, H, W]
        Returns:
            Encoded latent tensor [C, F, H_latent, W_latent]
        """
        # Ensure video has batch dimension
        if video_frames.dim() == 4:
            video_frames = video_frames.unsqueeze(0)

        if self.cpu_offload:
            self.encoder = self.encoder.to(AI_DEVICE)

        out = self.encoder(video_frames)
        if out.dim() == 5:
            out = out.squeeze(0)

        if self.cpu_offload:
            self.encoder = self.encoder.to("cpu")

        return out

    def decode(
        self,
        latent: torch.Tensor,
        tiling_config: TilingConfig | None = None,
        generator: torch.Generator | None = None,
    ) -> Iterator[torch.Tensor]:
        """Decode a latent into video frames, yielding decoded chunks lazily.

        Args:
            latent: Latent tensor to decode.
            tiling_config: Optional tiling configuration for memory-bounded decoding.
            generator: Optional RNG passed through to decode_video.
        Yields:
            Decoded video tensors produced by decode_video.
        """
        # If tiling is enabled but no config was provided, use the default config.
        if self.use_tiling and tiling_config is None:
            tiling_config = TilingConfig.default()

        if self.cpu_offload:
            self.decoder = self.decoder.to(AI_DEVICE)
        try:
            yield from decode_video(latent, self.decoder, tiling_config, generator)
        finally:
            # finally-block guarantees the decoder is offloaded even if the
            # consumer abandons the generator mid-stream.
            if self.cpu_offload:
                self.decoder = self.decoder.to("cpu")


class LTX2AudioVAE:
    """Loads the LTX-2 audio VAE (encoder, decoder, and vocoder) from a safetensors
    checkpoint and exposes encode/decode helpers with optional CPU offloading."""

    def __init__(
        self,
        checkpoint_path: str,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        cpu_offload: bool = False,
    ):
        """
        Args:
            checkpoint_path: Path to the safetensors checkpoint with the audio VAE weights.
            device: Target device the modules are loaded onto.
            dtype: Dtype the state dicts are cast to; None keeps the checkpoint dtype.
            cpu_offload: If True, modules are moved to AI_DEVICE only while in use
                and parked on CPU otherwise.
        """
        self.checkpoint_path = checkpoint_path
        self.device = device
        self.dtype = dtype
        self.cpu_offload = cpu_offload
        self.loader = SafetensorsModelStateDictLoader()
        self.load()

    def _load_module(self, module, sd_ops):
        """Load a key-filtered state dict into `module`, cast to self.dtype if set,
        then move the module to the target device in eval mode."""
        state_dict = self.loader.load(
            self.checkpoint_path,
            sd_ops=sd_ops,
            device=self.device,
        ).sd
        if self.dtype is not None:
            state_dict = {key: value.to(dtype=self.dtype) for key, value in state_dict.items()}
        module.load_state_dict(state_dict, strict=False, assign=True)
        return module.to(self.device).eval()

    def load(self) -> tuple[AudioEncoder | None, AudioDecoder | None, Vocoder | None]:
        """Instantiate and load encoder, decoder, and vocoder from the checkpoint.

        Returns:
            (encoder, decoder, vocoder) as eval-mode modules on self.device.
        """
        config = self.loader.metadata(self.checkpoint_path)

        # The three components share one checkpoint file; only the key filter differs.
        self.encoder = self._load_module(AudioEncoderConfigurator.from_config(config), AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER)
        self.decoder = self._load_module(AudioDecoderConfigurator.from_config(config), AUDIO_VAE_DECODER_COMFY_KEYS_FILTER)
        self.vocoder = self._load_module(VocoderConfigurator.from_config(config), VOCODER_COMFY_KEYS_FILTER)

        return self.encoder, self.decoder, self.vocoder

    def encode(self, audio_spectrogram: torch.Tensor) -> torch.Tensor:
        """Encode an audio spectrogram into the audio latent space."""
        if self.cpu_offload:
            self.encoder = self.encoder.to(AI_DEVICE)
        out = self.encoder(audio_spectrogram)
        if self.cpu_offload:
            self.encoder = self.encoder.to("cpu")
        return out

    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        """Decode an audio latent via the decoder and vocoder (see decode_audio)."""
        if self.cpu_offload:
            self.decoder = self.decoder.to(AI_DEVICE)
            self.vocoder = self.vocoder.to(AI_DEVICE)
        out = decode_audio(latent, self.decoder, self.vocoder)
        if self.cpu_offload:
            self.decoder = self.decoder.to("cpu")
            self.vocoder = self.vocoder.to("cpu")
        return out


class LTX2Upsampler:
    """
    Loader/wrapper around the LatentUpsampler model, mirroring LTX2VideoVAE.
    """

    def __init__(
        self,
        checkpoint_path: str,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
        cpu_offload: bool = False,
    ):
        self.checkpoint_path = checkpoint_path
        self.device = device
        self.dtype = dtype
        self.cpu_offload = cpu_offload
        self.loader = None
        self.upsampler = None
        self.load()

    def load(self):
        """
        Load the upsampler weights from the checkpoint.

        Mirrors Builder.build() in ltx_core.loader.single_gpu_model_builder:
        instantiate on the meta device, load the (optionally dtype-cast) state
        dict with assign=True, then move to the target device only. No explicit
        .to(dtype) is performed — assign=True takes each parameter's dtype
        straight from the state dict.
        """
        self.loader = SafetensorsModelStateDictLoader()
        config = self.loader.metadata(self.checkpoint_path)

        # Some checkpoints store the scale under "rational_spatial_scale" instead.
        if "spatial_scale" not in config and "rational_spatial_scale" in config:
            config["spatial_scale"] = config["rational_spatial_scale"]

        # Parameters start out on the meta device; real tensors arrive via assign=True.
        with torch.device("meta"):
            model = LatentUpsamplerConfigurator.from_config(config)

        # No key filtering; tensors are read directly onto the target device.
        weights = self.loader.load(self.checkpoint_path, sd_ops=None, device=self.device).sd
        if self.dtype is not None:
            weights = {name: tensor.to(dtype=self.dtype) for name, tensor in weights.items()}

        # assign=True swaps the meta parameters for the loaded tensors in place.
        model.load_state_dict(weights, strict=False, assign=True)

        # Only .to(device) here — dtype already comes from the state dict entries.
        self.upsampler = model.to(self.device).eval()
        return self.upsampler

    @torch.no_grad()
    def upsample(
        self,
        latent: torch.Tensor,
        video_encoder: VideoEncoder,
    ) -> torch.Tensor:
        """
        Upsample a video latent, normalizing with the encoder's channel statistics.

        Delegates to the static upsample_video so the numerics stay aligned with
        ltx_core.model.upsampler.model.upsample_video.

        Args:
            latent: Latent tensor of shape [B, C, F, H, W] or [C, F, H, W].
            video_encoder: Provides per_channel_statistics for (un)normalization.

        Returns:
            Upsampled latent of shape [B, C, F, H*2, W*2] or [C, F, H*2, W*2].
        """
        if self.cpu_offload:
            self.upsampler = self.upsampler.to(AI_DEVICE)
        result = self.upsample_video(latent, video_encoder, self.upsampler)
        if self.cpu_offload:
            self.upsampler = self.upsampler.to("cpu")
        return result

    @staticmethod
    def upsample_video(latent: torch.Tensor, video_encoder: VideoEncoder, upsampler) -> torch.Tensor:
        """
        Un-normalize, upsample, then re-normalize a latent tensor.

        Matches ltx_core.model.upsampler.model.upsample_video. The upsampler is
        expected to already be in eval mode, on the right device, and with the
        right dtype — this function does not touch its state.

        Args:
            latent: Input latent tensor of shape [B, C, F, H, W].
            video_encoder: VideoEncoder whose per_channel_statistics drive the
                normalization round-trip.
            upsampler: LatentUpsampler module performing the spatial upsample.

        Returns:
            torch.Tensor: Upsampled and re-normalized latent tensor.
        """
        stats = video_encoder.per_channel_statistics
        return stats.normalize(upsampler(stats.un_normalize(latent)))


if __name__ == "__main__":
    # Smoke test: round-trip pre-computed video/audio latents through the VAEs
    # and mux the decoded streams into an mp4.
    dev = "cuda"
    dtype = torch.bfloat16

    # Single source of truth for the shared checkpoint path.
    checkpoint = "/data/nvme0/gushiqiao/models/official_models/LTX-2/ltx-2-19b-distilled-fp8.safetensors"

    video_vae = LTX2VideoVAE(
        checkpoint_path=checkpoint,
        device=dev,
        dtype=dtype,
    )

    audio_vae = LTX2AudioVAE(
        checkpoint_path=checkpoint,
        device=dev,
        dtype=dtype,
    )

    # BUGFIX: map_location keeps the load working when the latents were saved
    # from a different device than the one we run on.
    vid_enc = torch.load("/data/nvme0/gushiqiao/models/code/LightX2V/scripts/v.pth", map_location=dev).unsqueeze(0)
    vid_dec = video_vae.decode(vid_enc)

    audio_enc = torch.load("/data/nvme0/gushiqiao/models/code/LightX2V/scripts/a.pth", map_location=dev).unsqueeze(0)
    audio_dec = audio_vae.decode(audio_enc)

    encode_video(
        video=vid_dec,
        fps=24,
        audio=audio_dec,
        audio_sample_rate=24000,
        output_path="reconstructed_1.mp4",  # BUGFIX: was an f-string with no placeholders
        video_chunks_number=1,
    )