scheduler.py
import math

import numpy as np
import torch
from loguru import logger

from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.utils.envs import *
from lightx2v.utils.utils import masks_like


class EulerScheduler(WanScheduler):
    def __init__(self, config):
        super().__init__(config)
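        # Temporal RoPE frequency count: half the head dim, less d // 6 frequencies for each spatial axis (height, width).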
        d = config["dim"] // config["num_heads"]
        self.rope_t_dim = d // 2 - 2 * (d // 6)

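        # Sequence-parallel world size; the per-token timestep built in step_pre is padded to a multiple of it.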
        if self.config["parallel"]:
            self.sp_size = self.config["parallel"].get("seq_p_size", 1)
        else:
            self.sp_size = 1

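        # Audio-driven (wan2.2_audio) runs carry latents over from the previous clip; both fields are set in reset().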
        if self.config["model_cls"] == "wan2.2_audio":
            self.prev_latents = None
            self.prev_len = 0

    def set_audio_adapter(self, audio_adapter):
        self.audio_adapter = audio_adapter

    def step_pre(self, step_index):
        super().step_pre(step_index)
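        # Audio adapter time embedding for this step, reshaped into 3 chunks; weights are moved to GPU only for this call when CPU offload is enabled.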
        if self.audio_adapter.cpu_offload:
            self.audio_adapter.time_embedding.to("cuda")
        self.audio_adapter_t_emb = self.audio_adapter.time_embedding(self.timestep_input).unflatten(1, (3, -1))
        if self.audio_adapter.cpu_offload:
            self.audio_adapter.time_embedding.to("cpu")

        if self.config["model_cls"] == "wan2.2_audio":
            _, lat_f, lat_h, lat_w = self.latents.shape
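            # Total token count: latent frames x tokens per latent frame, rounded up to a multiple of the sequence-parallel size.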
            F = (lat_f - 1) * self.config["vae_stride"][0] + 1
            per_latent_token_len = lat_h * lat_w // (self.config["patch_size"][1] * self.config["patch_size"][2])
            max_seq_len = ((F - 1) // self.config["vae_stride"][0] + 1) * per_latent_token_len
            max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size

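            # Per-token timesteps: positions zeroed by the mask (frames carried over from prev_latents) stay at 0, the remaining tokens get the current timestep, padded out to max_seq_len.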
            temp_ts = (self.mask[0][:, ::2, ::2] * self.timestep_input).flatten()
            self.timestep_input = torch.cat([temp_ts, temp_ts.new_ones(max_seq_len - temp_ts.size(0)) * self.timestep_input]).unsqueeze(0)

            self.timestep_input = torch.cat(
                [
                    self.timestep_input,
                    torch.zeros(
                        (1, per_latent_token_len),  # padding for reference frame latent
                        dtype=self.timestep_input.dtype,
                        device=self.timestep_input.device,
                    ),
                ],
                dim=1,
            )

    def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
        self.generator = torch.Generator(device=self.device).manual_seed(seed)
        self.latents = torch.randn(
            latent_shape[0],
            latent_shape[1],
            latent_shape[2],
            latent_shape[3],
            dtype=dtype,
            device=self.device,
            generator=self.generator,
        )
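        # In audio mode, keep prev_latents where the mask is 0 (already-generated frames) and fresh noise elsewhere.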
        if self.config["model_cls"] == "wan2.2_audio":
            self.mask = masks_like(self.latents, zero=True, prev_len=self.prev_len)
            if self.prev_latents is not None:
                self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents

    def prepare(self, seed, latent_shape, image_encoder_output=None):
        self.prepare_latents(seed, latent_shape, dtype=torch.float32)
        timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)

        self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=self.device)
        self.timesteps_ori = self.timesteps.clone()

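        # Flow-matching sigma schedule with timestep shift: sigma <- shift * sigma / (1 + (shift - 1) * sigma).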
        self.sigmas = self.timesteps_ori / self.num_train_timesteps
        self.sigmas = self.sample_shift * self.sigmas / (1 + (self.sample_shift - 1) * self.sigmas)

        self.timesteps = self.sigmas * self.num_train_timesteps

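        # Zero the temporal RoPE frequencies for positions past the video frames; the cos/sin tables include one extra temporal position for the reference-frame latent.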
        self.freqs[latent_shape[1] // self.patch_size[0] :, : self.rope_t_dim] = 0
        self.cos_sin = self.prepare_cos_sin((latent_shape[1] // self.patch_size[0] + 1, latent_shape[2] // self.patch_size[1], latent_shape[3] // self.patch_size[2]))

    def step_post(self):
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sigma = self.unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
        sigma_next = self.unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
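        # Euler step of the flow ODE: x_next = x + (sigma_next - sigma) * model_output.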
        x_t_next = sample + (sigma_next - sigma) * model_output
        self.latents = x_t_next
        if self.config["model_cls"] == "wan2.2_audio" and self.prev_latents is not None:
            self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents

    def reset(self, seed, latent_shape, image_encoder_output=None):
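        # Carry the previous clip's latents and length forward before sampling fresh noise.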
        if self.config["model_cls"] == "wan2.2_audio":
            self.prev_latents = image_encoder_output["prev_latents"]
            self.prev_len = image_encoder_output["prev_len"]
        self.prepare_latents(seed, latent_shape, dtype=torch.float32)

    def unsqueeze_to_ndim(self, in_tensor, tgt_n_dim):
        if in_tensor.ndim > tgt_n_dim:
            logger.warning(f"cannot unsqueeze tensor of shape {in_tensor.shape} to {tgt_n_dim} dims because it already has more; returning it unchanged")
            return in_tensor
        if in_tensor.ndim < tgt_n_dim:
            in_tensor = in_tensor[(...,) + (None,) * (tgt_n_dim - in_tensor.ndim)]
        return in_tensor


class ConsistencyModelScheduler(EulerScheduler):
    def step_post(self):
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sigma = self.unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
        sigma_next = self.unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
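        # Recover x0 from the model prediction, then re-noise it to sigma_next with fresh Gaussian noise.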
        x0 = sample - model_output * sigma
        x_t_next = x0 * (1 - sigma_next) + sigma_next * torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, generator=self.generator)
        self.latents = x_t_next
        if self.config["model_cls"] == "wan2.2_audio" and self.prev_latents is not None:
            self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents