scheduler.py 28.6 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
"""
LTX2 Scheduler for LightX2V.

This scheduler integrates LTX-2's diffusion sampling logic including:
- Sigma schedule generation with token-count-dependent shifting
- Euler diffusion stepping
- Classifier-free guidance (CFG)
- Audio-video joint denoising
"""

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import einops
import torch

from lightx2v.models.schedulers.scheduler import BaseScheduler
from lightx2v.utils.envs import *
from lightx2v_platform.base.global_var import AI_DEVICE


def get_pixel_coords(
    latent_coords: torch.Tensor,
    scale_factors: Tuple[int, int, int],
    causal_fix: bool = False,
) -> torch.Tensor:
    """
    Scale latent-space [start, end) coordinates into pixel-space coordinates.

    Args:
        latent_coords: Latent bounds tensor of shape [3, num_patches, 2].
        scale_factors: (temporal, height, width) integer upscale factors.
        causal_fix: When True, shifts the temporal axis so the first frame's
            coordinates account for a causal VAE (clamped at 0).

    Returns:
        Pixel-space coordinates, shape [3, num_patches, 2].
    """
    factors = torch.tensor(scale_factors, device=latent_coords.device)
    pixel_coords = latent_coords * factors.view(-1, 1, 1)

    if causal_fix:
        # Temporal axis: pull coordinates back by (scale - 1), never below 0.
        temporal_offset = 1 - scale_factors[0]
        pixel_coords[0, ...] = torch.clamp(pixel_coords[0, ...] + temporal_offset, min=0)

    return pixel_coords


@dataclass
class LatentState:
    """
    Mutable bundle of tensors tracking one modality (video or audio) through
    the diffusion denoising loop. All tensors are in patchified layout.
    """

    # Current noisy latent being denoised, shape [T, D].
    latent: torch.Tensor
    # Per-token denoising strength, shape [T, 1]; 0 keeps the clean value.
    denoise_mask: torch.Tensor
    # Positional coordinates for each token (produced by the patchifiers'
    # get_patch_grid_bounds; leading dim differs between video and audio).
    positions: torch.Tensor
    # Snapshot of the latent taken before noise injection, shape [T, D].
    clean_latent: torch.Tensor


class VideoLatentPatchifier:
    """
    Converts video latents between [C, F, H, W] layout and the patchified
    token layout [T, D]. Adapted from ltx_core; operates on single samples
    (no batch dimension). Temporal patch depth is fixed at 1 frame.
    """

    def __init__(self, patch_size: int):
        # (temporal, height, width) patch dimensions; spatial patches are square.
        self._patch_size = (1, patch_size, patch_size)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """Patch dimensions as (temporal, height, width)."""
        return self._patch_size

    def patchify(self, latents: torch.Tensor) -> torch.Tensor:
        """Fold [C, F, H, W] video latents into token layout [T, D]."""
        p1, p2, p3 = self._patch_size
        return einops.rearrange(
            latents,
            "c (f p1) (h p2) (w p3) -> (f h w) (c p1 p2 p3)",
            p1=p1,
            p2=p2,
            p3=p3,
        )

    def unpatchify(self, latents: torch.Tensor, frames: int, height: int, width: int) -> torch.Tensor:
        """Unfold token layout [T, D] back into [C, F, H, W] video latents."""
        p1, p2, p3 = self._patch_size
        return einops.rearrange(
            latents,
            "(f h w) (c p1 p2 p3) -> c (f p1) (h p2) (w p3)",
            f=frames // p1,
            h=height // p2,
            w=width // p3,
            p1=p1,
            p2=p2,
            p3=p3,
        )

    def get_patch_grid_bounds(
        self,
        frames: int,
        height: int,
        width: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Compute per-dimension [inclusive start, exclusive end) bounds for every patch.

        Returns:
            Tensor of shape [3, num_patches, 2]: axis 0 is (frame, height, width),
            last axis is [start, end).
        """
        # One arange of patch start offsets per axis, stepping by patch size.
        axis_starts = [
            torch.arange(start=0, end=extent, step=step, device=device)
            for extent, step in zip((frames, height, width), self._patch_size)
        ]
        starts = torch.stack(torch.meshgrid(*axis_starts, indexing="ij"), dim=0)

        # End bound of each patch is start + patch size along that axis.
        deltas = torch.tensor(
            self._patch_size,
            device=starts.device,
            dtype=starts.dtype,
        ).view(3, 1, 1, 1)
        bounds = torch.stack((starts, starts + deltas), dim=-1)

        # Flatten the (f, h, w) grid into one patch axis: [3, num_patches, 2].
        return einops.rearrange(
            bounds,
            "c f h w bounds -> c (f h w) bounds",
            bounds=2,
        )


class AudioPatchifier:
    """
    Converts audio latents between [C, T, F] layout and the patchified token
    layout [T, D], and maps latent frame indices to timestamps in seconds for
    positional encoding. Adapted from ltx_core; operates on single samples
    (no batch dimension).
    """

    def __init__(
        self,
        patch_size: int,
        sample_rate: int = 16000,
        hop_length: int = 160,
        audio_latent_downsample_factor: int = 4,
        is_causal: bool = True,
        shift: int = 0,
    ):
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.audio_latent_downsample_factor = audio_latent_downsample_factor
        self.is_causal = is_causal
        self.shift = shift
        self._patch_size = (1, patch_size, patch_size)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
        """Patch dimensions as (temporal, height, width)."""
        return self._patch_size

    def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor:
        """Fold [C, T, F] (or batched [B, C, T, F]) audio latents into token layout."""
        if audio_latents.ndim == 4:
            pattern = "b c t f -> b t (c f)"
        else:
            pattern = "c t f -> t (c f)"
        return einops.rearrange(audio_latents, pattern)

    def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor:
        """Unfold token layout back into [C, T, F] (or batched [B, C, T, F]) audio latents."""
        if audio_latents.ndim == 3:
            pattern = "b t (c f) ->b c t f"
        else:
            pattern = "t (c f) -> c t f"
        return einops.rearrange(audio_latents, pattern, c=channels, f=mel_bins)

    def _get_audio_latent_time_in_sec(
        self,
        start_latent: int,
        end_latent: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Map latent frame indices [start_latent, end_latent) to seconds."""
        target_device = device if device is not None else torch.device("cpu")

        latent_frames = torch.arange(start_latent, end_latent, dtype=dtype, device=target_device)
        mel_frames = latent_frames * self.audio_latent_downsample_factor

        if self.is_causal:
            # Causal layout: the first latent frame covers a single mel frame,
            # so shift indices back by (downsample - 1), clamped at 0.
            mel_frames = (mel_frames + 1 - self.audio_latent_downsample_factor).clip(min=0)

        return mel_frames * self.hop_length / self.sample_rate

    def _compute_audio_timings(
        self,
        batch_size: int,
        num_steps: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Build a [B, 1, T, 2] tensor holding [start, end) timestamps in seconds
        for each latent frame; backs get_patch_grid_bounds.

        Args:
            batch_size: Number of sequences to broadcast the timings over (1 here).
            num_steps: Number of latent frames (time steps) to convert.
            device: Device on which the resulting tensor should reside.
        """
        target_device = device if device is not None else torch.device("cpu")

        def broadcast(timings: torch.Tensor) -> torch.Tensor:
            # [T] -> [B, 1, T]
            return timings.unsqueeze(0).expand(batch_size, -1).unsqueeze(1)

        starts = broadcast(
            self._get_audio_latent_time_in_sec(self.shift, num_steps + self.shift, torch.float32, target_device)
        )
        ends = broadcast(
            self._get_audio_latent_time_in_sec(self.shift + 1, num_steps + self.shift + 1, torch.float32, target_device)
        )
        return torch.stack([starts, ends], dim=-1)

    def get_patch_grid_bounds(
        self,
        frames: int,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """
        Return the temporal [inclusive start, exclusive end) bounds (in seconds)
        for every latent frame.

        Returns:
            Tensor of shape [1, 1, frames, 2] (batch_size is fixed at 1 since
            the batch dimension was removed).
        """
        return self._compute_audio_timings(batch_size=1, num_steps=frames, device=device)


class LTX2Scheduler(BaseScheduler):
    """
    Scheduler for LTX-2 diffusion model.

    Handles sigma schedule generation, Euler stepping, and CFG guidance
    for joint audio-video denoising. Video and audio latents are held as
    patchified LatentState objects throughout denoising and unpatchified
    on the final step (see step_post).
    """

    def __init__(self, config):
        """
        Initialize LTX2 scheduler.

        Args:
            config: Configuration dictionary containing:
                - distilled_sigma_values: Optional explicit sigma schedule; when
                  present, infer_steps is derived from it (len(sigmas) - 1)
                - infer_steps: Number of inference steps (used only when no
                  distilled_sigma_values is given)
                - cfg_guidance_scale: CFG guidance scale (default: 1.0)
                - sample_shift: Two-element sequence [max_shift, base_shift]
                  for token-count-dependent sigma shifting (required)
                - stretch: Whether to stretch sigmas toward `terminal` (default: True)
                - terminal: Terminal sigma value for stretching (default: 0.1)
                - sample_guide_scale: Guidance scale used at sampling time (required)
                - video_patch_size, fps, video_scale_factors, audio_patch_size,
                  audio_sampling_rate, audio_hop_length, audio_scale_factor,
                  audio_is_causal, audio_shift: patchifier configuration
                - keep_latents_dtype_in_scheduler: dtype handling flag (default: False)
        """
        super().__init__(config)
        # A distilled model ships a fixed sigma schedule; otherwise sigmas are
        # generated later via set_timesteps() in prepare().
        if config.get("distilled_sigma_values", None) is not None:
            self.sigmas = torch.tensor(config["distilled_sigma_values"])
            self.infer_steps = len(self.sigmas) - 1
        else:
            self.sigmas = None
            self.infer_steps = config["infer_steps"]
        self.cfg_guidance_scale = config.get("cfg_guidance_scale", 1.0)

        # Sigma scheduler configuration (sample_shift = [max_shift, base_shift])
        self.max_shift = config["sample_shift"][0]
        self.base_shift = config["sample_shift"][1]
        self.stretch = config.get("stretch", True)
        self.terminal = config.get("terminal", 0.1)

        # Token-count anchors between which the sigma shift is linearly
        # interpolated in set_timesteps()
        self.base_shift_anchor = 1024
        self.max_shift_anchor = 4096

        # Patchifier configuration
        self.video_patch_size = config.get("video_patch_size", 1)
        self.fps = config.get("fps", 24)  # Frames per second for position calculation
        self.video_scale_factors = config.get("video_scale_factors", (8, 32, 32))  # (time, height, width)

        # Initialize patchifiers
        self.video_patchifier = VideoLatentPatchifier(patch_size=self.video_patch_size)
        self.audio_patchifier = AudioPatchifier(
            patch_size=config.get("audio_patch_size", 1),
            sample_rate=config.get("audio_sampling_rate", 16000),
            hop_length=config.get("audio_hop_length", 160),
            audio_latent_downsample_factor=config.get("audio_scale_factor", 4),
            is_causal=config.get("audio_is_causal", True),
            shift=config.get("audio_shift", 0),
        )

        # Diffusion components
        # self.stepper = EulerDiffusionStep()
        self.sample_guide_scale = self.config["sample_guide_scale"]

        # State (LatentState objects, created in prepare_latents)
        self.video_latent_state = None
        self.audio_latent_state = None

        # Step state (model noise predictions, consumed by step_post)
        self.v_noise_pred = None
        self.a_noise_pred = None
        self.keep_latents_dtype_in_scheduler = config.get("keep_latents_dtype_in_scheduler", False)

    def step_pre(self, step_index):
        """Record the current denoising step index before the model forward pass."""
        self.step_index = step_index

    def prepare(
        self,
        seed: int,
        video_latent_shape: tuple,
        audio_latent_shape: tuple,
        initial_video_latent: Optional[torch.Tensor] = None,
        initial_audio_latent: Optional[torch.Tensor] = None,
        noise_scale: float = 1.0,
        video_denoise_mask: Optional[torch.Tensor] = None,
        audio_denoise_mask: Optional[torch.Tensor] = None,
    ):
        """
        Prepare scheduler for inference: reset step state, seed the RNG,
        build initial latents, and (if needed) generate the sigma schedule.

        Args:
            seed: Random seed for noise generation
            video_latent_shape: Shape of video latents (C, F, H, W)
            audio_latent_shape: Shape of audio latents (C, F, M)
            initial_video_latent: Optional initial video latent (for conditioning)
            initial_audio_latent: Optional initial audio latent (for conditioning)
            noise_scale: Scale factor for noise
            video_denoise_mask: Optional denoise mask for video (unpatchified)
            audio_denoise_mask: Optional denoise mask for audio (unpatchified)
        """
        # Reset step state (important for stage 2 after stage 1)
        self.step_index = 0
        self.v_noise_pred = None
        self.a_noise_pred = None

        # Initialize generator
        self.generator = torch.Generator(device=AI_DEVICE).manual_seed(seed)

        # Prepare latents
        self.prepare_latents(
            video_latent_shape=video_latent_shape,
            audio_latent_shape=audio_latent_shape,
            initial_video_latent=initial_video_latent,
            initial_audio_latent=initial_audio_latent,
            noise_scale=noise_scale,
            video_denoise_mask=video_denoise_mask,
            audio_denoise_mask=audio_denoise_mask,
        )

        # Distilled schedules are fixed at construction; only generate sigmas
        # when none were provided.
        if self.sigmas is None:
            self.set_timesteps(infer_steps=self.infer_steps)

    def prepare_latents(
        self,
        video_latent_shape: tuple,
        audio_latent_shape: tuple,
        initial_video_latent: Optional[torch.Tensor] = None,
        initial_audio_latent: Optional[torch.Tensor] = None,
        noise_scale: float = 1.0,
        video_denoise_mask: Optional[torch.Tensor] = None,
        audio_denoise_mask: Optional[torch.Tensor] = None,
    ):
        """
        Prepare initial latents for denoising and patchify them.

        This method follows LTX2's create_initial_state logic:
        1. Create unpatchified latents [C, F, H, W] for video, [C, F, M] for audio
        2. Add noise following GaussianNoiser formula
        3. Patchify to [T, D] format
        4. Create positions and denoise_mask in patchified format
        5. Store LatentState objects on self (video_latent_state / audio_latent_state)

        Args:
            video_latent_shape: Shape of video latents (C, F, H, W) - batch dimension removed
            audio_latent_shape: Shape of audio latents (C, F, M) - batch dimension removed
            initial_video_latent: Optional initial video latent (for conditioning)
            initial_audio_latent: Optional initial audio latent (for conditioning)
            noise_scale: Scale factor for noise
            video_denoise_mask: Optional denoise mask for video (unpatchified)
            audio_denoise_mask: Optional denoise mask for audio (unpatchified)
        """

        # Prepare video latents
        self._prepare_video_latents(
            video_latent_shape=video_latent_shape,
            initial_video_latent=initial_video_latent,
            noise_scale=noise_scale,
            video_denoise_mask=video_denoise_mask,
            dtype=GET_DTYPE(),
        )

        # Prepare audio latents
        self._prepare_audio_latents(
            audio_latent_shape=audio_latent_shape,
            initial_audio_latent=initial_audio_latent,
            noise_scale=noise_scale,
            audio_denoise_mask=audio_denoise_mask,
            dtype=GET_DTYPE(),
        )

    def _prepare_video_latents(
        self,
        video_latent_shape: tuple,
        initial_video_latent: Optional[torch.Tensor] = None,
        noise_scale: float = 1.0,
        video_denoise_mask: Optional[torch.Tensor] = None,
        dtype: torch.dtype = None,
    ):
        """
        Prepare video latents for denoising and store them as a LatentState.

        Args:
            video_latent_shape: Shape of video latents (C, F, H, W) - batch dimension removed
            initial_video_latent: Optional initial video latent (for conditioning)
            noise_scale: Scale factor for noise
            video_denoise_mask: Optional denoise mask for video (unpatchified)
            dtype: Data type for latents
        """
        _, frames_v, height_v, width_v = video_latent_shape

        # Save shape information for unpatchify
        self.video_latent_shape_orig = video_latent_shape

        # Initialize video latents (unpatchified); zeros when no conditioning
        # latent is supplied (fully replaced by noise below when mask is 1).
        if initial_video_latent is not None:
            video_latent = initial_video_latent.to(dtype=dtype, device=AI_DEVICE)
        else:
            video_latent = torch.zeros(
                *video_latent_shape,
                dtype=dtype,
                device=AI_DEVICE,
            )

        clean_video_latent = video_latent.clone()

        # Create denoise mask (unpatchified); default is all-ones (denoise everything)
        if video_denoise_mask is None:
            video_denoise_mask = torch.ones(
                1,
                frames_v,
                height_v,
                width_v,
                dtype=torch.float32,
                device=AI_DEVICE,
            )
        else:
            video_denoise_mask = video_denoise_mask.to(dtype=torch.float32, device=AI_DEVICE)
            if video_denoise_mask.shape[0] != 1:
                video_denoise_mask = video_denoise_mask[:1, ...]  # Take first channel

        # Patchify video latents first (aligned with source code: create_initial_state -> patchify)
        patchified_video_latent = self.video_patchifier.patchify(video_latent)
        patchified_clean_video_latent = self.video_patchifier.patchify(clean_video_latent)
        patchified_video_mask = self.video_patchifier.patchify(video_denoise_mask)

        # Process denoise_mask: ensure float32 and reduce to [T, 1] (aligned with source code)
        patchified_video_mask = patchified_video_mask.to(torch.float32)
        # Reduce patch dimension to 1 (aligned with source code behavior)
        if patchified_video_mask.shape[-1] > 1:
            # Take mean across patch dimension to get [T, 1]
            patchified_video_mask = patchified_video_mask.mean(dim=-1, keepdim=True)
        # Ensure shape is [T, 1]
        if patchified_video_mask.ndim == 1:
            patchified_video_mask = patchified_video_mask.unsqueeze(-1)

        # Add noise after patchify (aligned with source code: GaussianNoiser operates on patchified latent)
        noise_video = torch.randn(
            *patchified_video_latent.shape,
            dtype=patchified_video_latent.dtype,
            device=AI_DEVICE,
            generator=self.generator,
        )

        # Lerp between noise and the clean latent per token, weighted by the mask
        scaled_mask_video = patchified_video_mask * noise_scale
        patchified_video_latent = (noise_video * scaled_mask_video + patchified_video_latent * (1 - scaled_mask_video)).to(patchified_video_latent.dtype)

        # Get positions for video: pixel-space [start, end) bounds per patch
        latent_coords_video = self.video_patchifier.get_patch_grid_bounds(frames_v, height_v, width_v, AI_DEVICE)
        positions_video = get_pixel_coords(latent_coords_video, self.video_scale_factors, causal_fix=True)
        # Convert to float first, then divide by fps, then convert to dtype (aligned with source code)
        positions_video = positions_video.float()
        positions_video[0, ...] = positions_video[0, ...] / self.fps
        positions_video = positions_video.to(dtype)

        # Create video LatentState
        self.video_latent_state = LatentState(
            latent=patchified_video_latent,
            denoise_mask=patchified_video_mask,
            positions=positions_video,
            clean_latent=patchified_clean_video_latent,
        )

    def _prepare_audio_latents(
        self,
        audio_latent_shape: tuple,
        initial_audio_latent: Optional[torch.Tensor] = None,
        noise_scale: float = 1.0,
        audio_denoise_mask: Optional[torch.Tensor] = None,
        dtype: torch.dtype = None,
    ):
        """
        Prepare audio latents for denoising and store them as a LatentState.

        Args:
            audio_latent_shape: Shape of audio latents (C, F, M) - batch dimension removed
            initial_audio_latent: Optional initial audio latent (for conditioning)
            noise_scale: Scale factor for noise
            audio_denoise_mask: Optional denoise mask for audio (unpatchified)
            dtype: Data type for latents
        """
        _, frames_a, mel_bins_a = audio_latent_shape

        # Save shape information for unpatchify
        self.audio_latent_shape_orig = audio_latent_shape

        # Initialize audio latents (unpatchified); zeros when no conditioning
        # latent is supplied
        if initial_audio_latent is not None:
            audio_latent = initial_audio_latent.to(dtype=dtype, device=AI_DEVICE)
        else:
            audio_latent = torch.zeros(
                *audio_latent_shape,
                dtype=dtype,
                device=AI_DEVICE,
            )

        clean_audio_latent = audio_latent.clone()

        # Create denoise mask (unpatchified); default is all-ones (denoise everything)
        if audio_denoise_mask is None:
            audio_denoise_mask = torch.ones(
                1,
                frames_a,
                mel_bins_a,
                dtype=torch.float32,
                device=AI_DEVICE,
            )
        else:
            audio_denoise_mask = audio_denoise_mask.to(dtype=torch.float32, device=AI_DEVICE)
            if audio_denoise_mask.shape[0] != 1:
                audio_denoise_mask = audio_denoise_mask[:1, ...]

        # Patchify audio latents first (aligned with source code: create_initial_state -> patchify)
        patchified_audio_latent = self.audio_patchifier.patchify(audio_latent)
        patchified_clean_audio_latent = self.audio_patchifier.patchify(clean_audio_latent)
        patchified_audio_mask = self.audio_patchifier.patchify(audio_denoise_mask)

        # Process denoise_mask: ensure float32 and reduce to [T, 1] (aligned with source code)
        patchified_audio_mask = patchified_audio_mask.to(torch.float32)
        # Reduce patch dimension to 1 (aligned with source code behavior)
        if patchified_audio_mask.shape[-1] > 1:
            # Take mean across patch dimension to get [T, 1]
            patchified_audio_mask = patchified_audio_mask.mean(dim=-1, keepdim=True)
        # Ensure shape is [T, 1]
        if patchified_audio_mask.ndim == 1:
            patchified_audio_mask = patchified_audio_mask.unsqueeze(-1)

        # Add noise after patchify (aligned with source code: GaussianNoiser operates on patchified latent)
        noise_audio = torch.randn(
            *patchified_audio_latent.shape,
            dtype=patchified_audio_latent.dtype,
            device=AI_DEVICE,
            generator=self.generator,
        )
        scaled_mask_audio = patchified_audio_mask * noise_scale
        patchified_audio_latent = (noise_audio * scaled_mask_audio + patchified_audio_latent * (1 - scaled_mask_audio)).to(patchified_audio_latent.dtype)

        # Get positions for audio (just time coordinates, in seconds)
        positions_audio = self.audio_patchifier.get_patch_grid_bounds(frames_a, AI_DEVICE)

        # Create audio LatentState
        self.audio_latent_state = LatentState(
            latent=patchified_audio_latent,
            denoise_mask=patchified_audio_mask,
            positions=positions_audio,
            clean_latent=patchified_clean_audio_latent,
        )

    def set_timesteps(
        self,
        infer_steps: Optional[int] = None,
        latent: Optional[torch.Tensor] = None,
    ):
        """
        Set timesteps and generate sigma schedule.

        This method generates a sigma schedule with token-count-dependent shifting,
        similar to WAN scheduler's set_timesteps method.

        Args:
            infer_steps: Number of inference steps (defaults to self.infer_steps)
            latent: Optional latent tensor for token count calculation; when absent,
                the token count falls back to max_shift_anchor (i.e. max shift)
        """
        if infer_steps is None:
            infer_steps = self.infer_steps

        # Calculate token count for shift adjustment
        if latent is not None:
            tokens = math.prod(latent.shape[1:])
        else:
            tokens = self.max_shift_anchor

        # Generate linear sigma schedule (aligned with source code)
        sigmas = torch.linspace(1.0, 0.0, infer_steps + 1)

        # Apply token-count-dependent shift: linearly interpolate the shift
        # between base_shift (at base_shift_anchor tokens) and max_shift
        # (at max_shift_anchor tokens)
        x1 = self.base_shift_anchor
        x2 = self.max_shift_anchor
        mm = (self.max_shift - self.base_shift) / (x2 - x1)
        b = self.base_shift - mm * x1
        sigma_shift = (tokens) * mm + b

        # Apply shift transformation (aligned with source code); the where()
        # keeps the terminal sigma at exactly 0
        power = 1
        sigmas = torch.where(
            sigmas != 0,
            math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
            0,
        )

        # Optionally stretch the non-zero sigmas so the last one lands on
        # self.terminal instead of its natural value
        if self.stretch:
            non_zero_mask = sigmas != 0
            non_zero_sigmas = sigmas[non_zero_mask]
            one_minus_z = 1.0 - non_zero_sigmas
            scale_factor = one_minus_z[-1] / (1.0 - self.terminal)
            stretched = 1.0 - (one_minus_z / scale_factor)
            sigmas[non_zero_mask] = stretched

        self.sigmas = sigmas.to(torch.float32).to(AI_DEVICE)

    def post_process_latent(self, denoised: torch.Tensor, denoise_mask: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
        """Blend denoised output with clean state based on mask (mask=0 keeps clean)."""
        return (denoised * denoise_mask + clean.float() * (1 - denoise_mask)).to(denoised.dtype)

    def to_velocity(
        self,
        sample: torch.Tensor,
        sigma: float | torch.Tensor,
        denoised_sample: torch.Tensor,
        calc_dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """
        Convert the sample and its denoised version to velocity:
        v = (sample - denoised) / sigma, computed in calc_dtype and cast back.

        Raises:
            ValueError: If sigma is 0 (division would be undefined).

        Returns:
            Velocity tensor with the same dtype as `sample`.
        """
        if isinstance(sigma, torch.Tensor):
            sigma = sigma.to(calc_dtype).item()
        if sigma == 0:
            raise ValueError("Sigma can't be 0.0")
        return ((sample.to(calc_dtype) - denoised_sample.to(calc_dtype)) / sigma).to(sample.dtype)

    def step_post(self):
        """
        Apply one Euler step after the model forward: blend predictions with
        the clean latent via the denoise mask, convert to velocity, advance
        both latents by dt = sigma_next - sigma, and unpatchify on the final step.
        """
        self.v_noise_pred = self.post_process_latent(self.v_noise_pred, self.video_latent_state.denoise_mask, self.video_latent_state.clean_latent)
        self.a_noise_pred = self.post_process_latent(self.a_noise_pred, self.audio_latent_state.denoise_mask, self.audio_latent_state.clean_latent)

        # dt is negative: sigmas decrease toward 0 over the schedule
        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]
        dt = sigma_next - sigma

        v_velocity = self.to_velocity(self.video_latent_state.latent, sigma, self.v_noise_pred)
        a_velocity = self.to_velocity(self.audio_latent_state.latent, sigma, self.a_noise_pred)

        # Euler update in float32, then cast back to the stored latent dtype
        v_latent = self.video_latent_state.latent.to(torch.float32) + v_velocity.to(torch.float32) * dt
        a_latent = self.audio_latent_state.latent.to(torch.float32) + a_velocity.to(torch.float32) * dt
        self.video_latent_state.latent = v_latent.to(self.video_latent_state.latent.dtype)
        self.audio_latent_state.latent = a_latent.to(self.audio_latent_state.latent.dtype)

        # Unpatchify latents on the final step (aligned with source code)
        if self.step_index == self.infer_steps - 1:
            channels_v, frames_v, height_v, width_v = self.video_latent_shape_orig
            channels_a, frames_a, mel_bins_a = self.audio_latent_shape_orig

            self.video_latent_state.latent = self.video_patchifier.unpatchify(self.video_latent_state.latent, frames_v, height_v, width_v)
            self.audio_latent_state.latent = self.audio_patchifier.unpatchify(self.audio_latent_state.latent, channels=channels_a, mel_bins=mel_bins_a)

    def clear(self):
        """Clear scheduler state."""
        # NOTE(review): self.audio_latents is not assigned anywhere else in this
        # file — possibly a stale attribute; verify before removing.
        self.audio_latents = None
        self.video_latent_state = None
        self.audio_latent_state = None
        self.v_noise_pred = None
        self.a_noise_pred = None
        self.sigmas = None

    def video_timesteps_from_mask(self) -> torch.Tensor:
        """Compute timesteps from a denoise mask and sigma value.
        Multiplies the video denoise mask by the current sigma to produce
        per-token timesteps. Areas where the mask is 0 get zero timesteps.
        """
        return self.video_latent_state.denoise_mask * self.sigmas[self.step_index]

    def audio_timesteps_from_mask(self) -> torch.Tensor:
        """Compute timesteps from a denoise mask and sigma value.
        Multiplies the audio denoise mask by the current sigma to produce
        per-token timesteps. Areas where the mask is 0 get zero timesteps.
        """
        return self.audio_latent_state.denoise_mask * self.sigmas[self.step_index]

    def reset_sigmas(self, sigmas: torch.Tensor):
        """Replace the sigma schedule and re-derive infer_steps (len - 1)."""
        self.sigmas = sigmas.to(torch.float32).to(AI_DEVICE)
        self.infer_steps = len(sigmas) - 1

    def reset_latents(self, video_latent: torch.Tensor):
        """Overwrite the current (patchified) video latent in place."""
        self.video_latent_state.latent = video_latent