scheduler.py 4.64 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
from loguru import logger

from lightx2v.models.schedulers.hunyuan_video.scheduler import HunyuanVideo15Scheduler
from lightx2v_platform.base.global_var import AI_DEVICE


class WorldPlayDistillScheduler(HunyuanVideo15Scheduler):
    """Flow-match scheduler for the WorldPlay distill model.

    Builds on HunyuanVideo15Scheduler and adds:
    - few-step inference (4 steps) driven by a fixed distill timestep table
    - hooks for autoregressive chunk-based generation
    - storage for camera-pose / action conditioning tensors
    """

    def __init__(self, config):
        super().__init__(config)

        # Fixed timestep table taken verbatim from HY-WorldPlay for the
        # distilled model: 5 entries = 4 inference steps plus the final 0.
        self.distill_timesteps = [1000.0, 960.0, 888.8889, 727.2728, 0.0]
        self.infer_steps = len(self.distill_timesteps) - 1

        self.num_train_timesteps = 1000
        self.sigma_max = 1.0
        self.sigma_min = 0.0

        # Autoregressive-generation knobs.
        self.chunk_latent_frames = config.get("chunk_latent_frames", 4)
        self.model_type = config.get("model_type", "ar")

        # Camera / action conditioning state; populated later by prepare().
        self.viewmats = None
        self.Ks = None
        self.action = None

        # Toggled by pre_infer when action conditioning yields a
        # per-token conditioning vector.
        self.vec_is_per_token = False

    def set_timesteps(self, num_inference_steps, device, shift):
        """Install the fixed distill timestep schedule.

        The distilled model always uses the exact HY-WorldPlay table
        [1000.0, 960.0, 888.8889, 727.2728, 0.0]; `num_inference_steps`
        and `shift` are accepted for interface compatibility only.
        Sigmas are simply timesteps / 1000 — no shift is applied.
        """
        schedule = torch.tensor(self.distill_timesteps, dtype=torch.float32, device=device)
        self.timesteps = schedule

        # Distill schedule: sigma = t / T; kept on CPU since step_post
        # reads individual entries via .item().
        self.sigmas = schedule.div(self.num_train_timesteps).cpu()

        logger.info(f"[WorldPlayDistillScheduler] Timesteps: {self.timesteps.tolist()}")
        logger.info(f"[WorldPlayDistillScheduler] Sigmas: {self.sigmas.tolist()}")

    def step_post(self):
        """Advance latents by one Euler step of the flow-match ODE.

        Mirrors the HY-WorldPlay reference exactly:
            dt = sigma_next - sigma
            x  = x + v * dt
        """
        velocity = self.noise_pred.to(torch.float32)

        sigma_cur = self.sigmas[self.step_index].item()
        sigma_nxt = self.sigmas[self.step_index + 1].item()

        # Euler update in float32, then cast back to the latents' dtype.
        stepped = self.latents.to(torch.float32) + (sigma_nxt - sigma_cur) * velocity
        self.latents = stepped.to(self.latents.dtype)

    def prepare(self, seed, latent_shape, image_encoder_output=None, pose_output=None):
        """Set up latents, timesteps, masks, and optional pose conditioning.

        Args:
            seed: RNG seed for latent initialization.
            latent_shape: Latent tensor shape [C, T, H, W].
            image_encoder_output: Optional dict with siglip_output,
                siglip_mask, cond_latents.
            pose_output: Optional dict with viewmats, Ks, action tensors.
        """
        self.prepare_latents(seed, latent_shape, dtype=torch.bfloat16)
        self.set_timesteps(self.infer_steps, device=AI_DEVICE, shift=self.sample_shift)

        task = self.config["task"]
        self.multitask_mask = self.get_task_mask(task, latent_shape[-3])

        if image_encoder_output:
            cond_latents = image_encoder_output.get("cond_latents")
        else:
            cond_latents = None
        self.cond_latents_concat, self.mask_concat = self._prepare_cond_latents_and_mask(
            task, cond_latents, self.latents, self.multitask_mask, self.reorg_token
        )

        self.cos_sin = self.prepare_cos_sin((latent_shape[1], latent_shape[2], latent_shape[3]))

        # Stash pose conditioning tensors when the caller supplies them.
        if pose_output is not None:
            self.viewmats = pose_output.get("viewmats")
            self.Ks = pose_output.get("Ks")
            self.action = pose_output.get("action")

    def get_chunk_timesteps(self, chunk_idx):
        """Return the timesteps for AR chunk `chunk_idx` (shared by all chunks)."""
        return self.timesteps

    def get_chunk_sigmas(self, chunk_idx):
        """Return the sigmas for AR chunk `chunk_idx` (shared by all chunks)."""
        return self.sigmas

    def clear(self):
        """Reset scheduler state, including pose/action conditioning."""
        super().clear()
        self.viewmats = None
        self.Ks = None
        self.action = None
        self.vec_is_per_token = False