scheduler.py
import math

import numpy as np
import torch
from loguru import logger

from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.utils.envs import *
from lightx2v.utils.utils import masks_like
from lightx2v_platform.base.global_var import AI_DEVICE


class EulerScheduler(WanScheduler):
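    """Flow-matching Euler scheduler built on WanScheduler.

    Expected call order (a sketch inferred from this file, not a guarantee):
    prepare() or reset() once per clip, then for each denoising step
    step_pre() -> model forward -> step_post(). The wan2.2_audio branches use
    a mask to keep previously generated latents fixed so that only the new
    frames are denoised.
    """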
    def __init__(self, config):
        super().__init__(config)

        if self.config["parallel"]:
            self.sp_size = self.config["parallel"].get("seq_p_size", 1)
        else:
            self.sp_size = 1

        if self.config["model_cls"] == "wan2.2_audio":
            self.prev_latents = None
            self.prev_len = 0

    def set_audio_adapter(self, audio_adapter):
        self.audio_adapter = audio_adapter

    def step_pre(self, step_index):
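        """Prepare per-step inputs: the audio adapter's time embedding (the
        adapter is moved to the accelerator only for this call when CPU
        offload is enabled) and, for wan2.2_audio, the per-token timestep
        vector built below."""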
        super().step_pre(step_index)
        if self.audio_adapter.cpu_offload:
            self.audio_adapter.time_embedding.to(AI_DEVICE)
        self.audio_adapter_t_emb = self.audio_adapter.time_embedding(self.timestep_input).unflatten(1, (3, -1))
        if self.audio_adapter.cpu_offload:
            self.audio_adapter.time_embedding.to("cpu")

        if self.config["model_cls"] == "wan2.2_audio":
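            # Build a per-token timestep vector. Positions where the mask is 1
            # keep the current timestep; masked-out positions (presumably the
            # previously generated frames) get 0; padding up to max_seq_len
            # reuses the current timestep; the trailing reference-frame latent
            # tokens are zero-filled. max_seq_len is rounded up to a multiple
            # of the sequence-parallel size so tokens split evenly across ranks.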
            _, lat_f, lat_h, lat_w = self.latents.shape
            F = (lat_f - 1) * self.config["vae_stride"][0] + 1
            per_latent_token_len = lat_h * lat_w // (self.config["patch_size"][1] * self.config["patch_size"][2])
            max_seq_len = ((F - 1) // self.config["vae_stride"][0] + 1) * per_latent_token_len
            max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size

            temp_ts = (self.mask[0][:, ::2, ::2] * self.timestep_input).flatten()
            self.timestep_input = torch.cat([temp_ts, temp_ts.new_ones(max_seq_len - temp_ts.size(0)) * self.timestep_input]).unsqueeze(0)

            self.timestep_input = torch.cat(
                [
                    self.timestep_input,
                    torch.zeros(
                        (1, per_latent_token_len),  # padding for reference frame latent
                        dtype=self.timestep_input.dtype,
                        device=self.timestep_input.device,
                    ),
                ],
                dim=1,
            )

    def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
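        """Draw seeded Gaussian noise latents; for wan2.2_audio, splice the
        previous clip's latents back in where the mask is zero."""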
        self.generator = torch.Generator(device=AI_DEVICE).manual_seed(seed)
        self.latents = torch.randn(
            latent_shape[0],
            latent_shape[1],
            latent_shape[2],
            latent_shape[3],
            dtype=dtype,
            device=AI_DEVICE,
            generator=self.generator,
        )
        if self.config["model_cls"] == "wan2.2_audio":
            self.mask = masks_like(self.latents, zero=True, prev_len=self.prev_len)
            if self.prev_latents is not None:
                self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents

    def prepare(self, seed, latent_shape, image_encoder_output=None):
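        """Build the denoising schedule: infer_steps + 1 timesteps linearly
        spaced from num_train_timesteps down to 0, converted to sigmas and
        warped by the shift, sigma <- shift * s / (1 + (shift - 1) * s)."""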
        self.prepare_latents(seed, latent_shape, dtype=torch.float32)
        timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)

        self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=AI_DEVICE)
        self.timesteps_ori = self.timesteps.clone()

        self.sigmas = self.timesteps_ori / self.num_train_timesteps
        self.sigmas = self.sample_shift * self.sigmas / (1 + (self.sample_shift - 1) * self.sigmas)

        self.timesteps = self.sigmas * self.num_train_timesteps

    def step_post(self):
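        """One explicit Euler step of the flow ODE:
        x_next = x + (sigma_next - sigma) * v_pred.
        For wan2.2_audio, fixed (previous) latents are restored through the
        mask after the update."""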
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sigma = self.unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
        sigma_next = self.unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
        x_t_next = sample + (sigma_next - sigma) * model_output
        self.latents = x_t_next
        if self.config["model_cls"] == "wan2.2_audio" and self.prev_latents is not None:
            self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents

    def reset(self, seed, latent_shape, image_encoder_output=None):
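        """Re-seed the initial noise; for wan2.2_audio, first pick up the
        previous clip's latents and length from the image encoder output."""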
        if self.config["model_cls"] == "wan2.2_audio":
            self.prev_latents = image_encoder_output["prev_latents"]
            self.prev_len = image_encoder_output["prev_len"]
        self.prepare_latents(seed, latent_shape, dtype=torch.float32)

    def unsqueeze_to_ndim(self, in_tensor, tgt_n_dim):
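        """Append trailing singleton dims so a scalar sigma broadcasts against
        the sample; tensors that already exceed the target ndim are returned
        unchanged with a warning."""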
        if in_tensor.ndim > tgt_n_dim:
            logger.warning(f"tensor of shape {in_tensor.shape} already has more than {tgt_n_dim} dims; returning it unchanged")
            return in_tensor
        if in_tensor.ndim < tgt_n_dim:
            in_tensor = in_tensor[(...,) + (None,) * (tgt_n_dim - in_tensor.ndim)]
        return in_tensor


class ConsistencyModelScheduler(EulerScheduler):
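    """Variant of EulerScheduler whose update predicts x0 and re-noises it to
    the next sigma instead of taking an Euler step."""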
    def step_post(self):
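        """Consistency-style update: x0 = x - sigma * v_pred, then
        x_next = (1 - sigma_next) * x0 + sigma_next * eps, with fresh noise
        drawn from the seeded generator."""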
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sigma = self.unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
        sigma_next = self.unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
        x0 = sample - model_output * sigma
        x_t_next = x0 * (1 - sigma_next) + sigma_next * torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, generator=self.generator)
        self.latents = x_t_next
        if self.config["model_cls"] == "wan2.2_audio" and self.prev_latents is not None:
            self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents