scheduler.py 14.5 KB
Newer Older
PengGao's avatar
PengGao committed
1
2
from typing import List, Optional, Union

helloyongyang's avatar
helloyongyang committed
3
4
import numpy as np
import torch
PengGao's avatar
PengGao committed
5

6
from lightx2v.models.schedulers.scheduler import BaseScheduler
7
from lightx2v.utils.utils import masks_like
helloyongyang's avatar
helloyongyang committed
8
9
10


class WanScheduler(BaseScheduler):
    """Flow-matching UniPC multistep scheduler used for Wan video models.

    Responsibilities visible in this class:

    * build the inference sigma / timestep schedule (``prepare`` and
      ``set_timesteps``);
    * sample the initial noise latents and, for Wan2.2 i2v/s2v, blend them with
      ``vae_encoder_out`` through ``self.mask`` (``prepare_latents``);
    * precompute the rotary-embedding table ``self.cos_sin`` for the latent grid
      (``prepare_cos_sin``);
    * advance ``self.latents`` by one step with a UniPC predictor/corrector pair
      operating on x0 predictions (``step_pre`` / ``step_post``).
    """

    def __init__(self, config):
        super().__init__(config)
        self.run_device = torch.device(self.config.get("run_device", "cuda"))
        self.infer_steps = self.config["infer_steps"]
        self.target_video_length = self.config["target_video_length"]
        self.sample_shift = self.config["sample_shift"]
        # (frame, height, width) patch size; the latent grid is divided by it
        # to obtain the token grid (see prepare / step_pre).
        self.patch_size = (1, 2, 2)
        self.shift = 1
        self.num_train_timesteps = 1000
        # Step indices whose *following* step skips the UniPC corrector.
        self.disable_corrector = []
        self.solver_order = 2
        self.noise_pred = None
        self.sample_guide_scale = self.config["sample_guide_scale"]
        self.caching_records_2 = [True] * self.config["infer_steps"]

        # Rotary tables for the frame / height / width axes, concatenated along
        # the frequency axis; prepare_cos_sin splits them back apart.
        self.head_size = self.config["dim"] // self.config["num_heads"]
        self.freqs = torch.cat(
            [
                self.rope_params(1024, self.head_size - 4 * (self.head_size // 6)),
                self.rope_params(1024, 2 * (self.head_size // 6)),
                self.rope_params(1024, 2 * (self.head_size // 6)),
            ],
            dim=1,
        ).to(self.run_device)

    def rope_params(self, max_seq_len, dim, theta=10000):
        """Return a complex rotary table of shape ``(max_seq_len, dim // 2)``.

        Entry ``[p, k]`` is the unit-magnitude rotation
        ``exp(i * p * theta ** (-2k / dim))``.

        Args:
            max_seq_len: number of positions.
            dim: rotary dimension; must be even.
            theta: frequency base.
        """
        assert dim % 2 == 0
        freqs = torch.outer(
            torch.arange(max_seq_len),
            1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)),
        )
        freqs = torch.polar(torch.ones_like(freqs), freqs)
        return freqs

    def _uses_vae_condition(self):
        """Return True when latents are blended with ``vae_encoder_out`` (Wan2.2 i2v/s2v)."""
        return self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]

    def prepare(self, seed, latent_shape, image_encoder_output=None):
        """Initialise latents, the sigma schedule and the rotary table for a run.

        Args:
            seed: seed for the noise generator.
            latent_shape: indexable with at least 4 entries; entries 1..3 are
                used as the (frames, height, width) latent grid. Entry 0 is
                presumably the channel count (TODO confirm).
            image_encoder_output: dict holding ``"vae_encoder_out"``; read
                only for Wan2.2 i2v/s2v.
        """
        if self._uses_vae_condition():
            self.vae_encoder_out = image_encoder_output["vae_encoder_out"]

        self.prepare_latents(seed, latent_shape, dtype=torch.float32)

        # Training schedule: sigmas descend from 1 - 1/N down to 0.
        alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
        sigmas = torch.from_numpy(1.0 - alphas).to(dtype=torch.float32)
        sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)

        self.sigmas = sigmas
        self.timesteps = sigmas * self.num_train_timesteps

        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None

        # Only the endpoints of the training schedule are kept; set_timesteps
        # replaces sigmas / timesteps with the inference schedule.
        self.sigmas = self.sigmas.to("cpu")
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

        self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift)

        self.cos_sin = self.prepare_cos_sin((latent_shape[1] // self.patch_size[0], latent_shape[2] // self.patch_size[1], latent_shape[3] // self.patch_size[2]))

    def prepare_cos_sin(self, grid_sizes):
        """Build the rotary table for an ``(f, h, w)`` token grid.

        Returns:
            For ``rope_type == "flashinfer"`` (the default), a real tensor of
            shape ``(f*h*w, 2*c)`` holding the cosines followed by the sines.
            Otherwise, the complex table of shape ``(f*h*w, 1, c)``.
            Here ``c = head_size // 2``.
        """
        c = self.head_size // 2
        freqs = self.freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
        f, h, w = grid_sizes
        seq_len = f * h * w
        cos_sin = torch.cat(
            [
                freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
                freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
                freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
            ],
            dim=-1,
        )
        if self.config.get("rope_type", "flashinfer") == "flashinfer":
            cos_sin = cos_sin.reshape(seq_len, -1)
            # Unit-magnitude complex entries: the real part is cos, the imaginary part is sin.
            cos_half = cos_sin.real.contiguous()
            sin_half = cos_sin.imag.contiguous()
            cos_sin = torch.cat([cos_half, sin_half], dim=-1)
        else:
            cos_sin = cos_sin.reshape(seq_len, 1, -1)
        return cos_sin

    def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
        """Sample fresh Gaussian latents of shape ``latent_shape[:4]`` on ``run_device``.

        For Wan2.2 i2v/s2v, positions where ``self.mask`` is 0 are replaced by
        ``self.vae_encoder_out``.
        """
        self.generator = torch.Generator(device=self.run_device).manual_seed(seed)
        self.latents = torch.randn(
            latent_shape[0],
            latent_shape[1],
            latent_shape[2],
            latent_shape[3],
            dtype=dtype,
            device=self.run_device,
            generator=self.generator,
        )
        if self._uses_vae_condition():
            self.mask = masks_like(self.latents, zero=True)
            self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents

    def set_timesteps(
        self,
        infer_steps: Union[int, None] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    ):
        """Build the inference schedule and reset the multistep solver state.

        Args:
            infer_steps: number of inference steps. Defaults to ``self.infer_steps``.
            device: device for ``self.timesteps``. ``self.sigmas`` always stays on CPU.
            sigmas: optional explicit (unshifted) sigma schedule. When omitted,
                it is spaced linearly from ``sigma_max`` to ``sigma_min``.
            mu: accepted for signature compatibility. It is not used, because
                dynamic shifting is not implemented here.
            shift: flow shift applied to the sigmas. Defaults to ``self.shift``.

        Requires ``prepare`` to have set ``sigma_min`` / ``sigma_max`` when
        ``sigmas`` is omitted. Asserts that the schedule has
        ``self.infer_steps`` entries, since other state is sized by that count.
        """
        if infer_steps is None:
            infer_steps = self.infer_steps
        if sigmas is None:
            # infer_steps + 1 points with the last dropped: the final sigma is appended below.
            sigmas = np.linspace(self.sigma_max, self.sigma_min, infer_steps + 1).copy()[:-1]
        else:
            sigmas = np.asarray(sigmas, dtype=np.float64)

        if shift is None:
            shift = self.shift
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        sigma_last = 0

        timesteps = sigmas * self.num_train_timesteps
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

        assert len(self.timesteps) == self.infer_steps
        self.model_outputs = [None] * self.solver_order
        self.lower_order_nums = 0
        self.last_sample = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")

    def _sigma_to_alpha_sigma_t(self, sigma):
        """Flow-matching parameterisation: ``alpha_t = 1 - sigma``, ``sigma_t = sigma``."""
        return 1 - sigma, sigma

    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """Convert the raw model output into an x0 prediction.

        ``x0 = sample - sigma_t * model_output``, where ``sigma_t`` is the
        sigma at the current ``step_index``.

        Positional form is ``(timestep, sample)``. The timestep is accepted for
        API compatibility but is not needed.

        Raises:
            ValueError: if ``sample`` is supplied neither by keyword nor positionally.
        """
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyward argument")

        sigma_t = self.sigmas[self.step_index]
        return sample - sigma_t * model_output

    def reset(self, seed, latent_shape, step_index=None):
        """Clear the solver history and resample the latents.

        If ``step_index`` is given, the current step is also moved there.
        """
        if step_index is not None:
            self.step_index = step_index
        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None
        self.noise_pred = None
        self.this_order = None
        self.lower_order_nums = 0
        self.prepare_latents(seed, latent_shape, dtype=torch.float32)

    def multistep_uni_p_bh_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        order: int = None,
        **kwargs,
    ) -> torch.Tensor:
        """UniPC predictor: extrapolate ``sample`` from ``sigmas[step_index]`` to ``sigmas[step_index + 1]``.

        Uses the x0 predictions stored in ``self.model_outputs``; the newest is
        last. ``model_output`` itself is not read.

        Positional form is ``(prev_timestep, sample, order)``.

        Raises:
            ValueError: if ``sample`` or ``order`` is missing.
        """
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError(" missing `sample` as a required keyward argument")
        if order is None:
            if len(args) > 2:
                order = args[2]
            else:
                raise ValueError(" missing `order` as a required keyward argument")
        model_output_list = self.model_outputs

        m0 = model_output_list[-1]
        x = sample

        sigma_t, sigma_s0 = (
            self.sigmas[self.step_index + 1],
            self.sigmas[self.step_index],
        )
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        # log(alpha / sigma) at the target and the current sigma.
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = sample.device

        # Divided differences against older history entries.
        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - i
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h
        h_phi_1 = torch.expm1(hh)  # e^{hh} - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        B_h = torch.expm1(hh)

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
        else:
            D1s = None

        x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
        if D1s is not None:
            pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
        else:
            pred_res = 0
        x_t = x_t_ - alpha_t * B_h * pred_res
        x_t = x_t.to(x.dtype)
        return x_t

    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.Tensor,
        *args,
        last_sample: torch.Tensor = None,
        this_sample: torch.Tensor = None,
        order: int = None,
        **kwargs,
    ) -> torch.Tensor:
        """UniPC corrector: refine ``this_sample`` using its fresh x0 prediction.

        Re-integrates from ``last_sample`` at ``sigmas[step_index - 1]`` to
        ``sigmas[step_index]``. It uses the history in ``self.model_outputs``
        plus ``this_model_output``.

        Positional form is ``(this_timestep, last_sample, this_sample, order)``.

        Raises:
            ValueError: if ``last_sample``, ``this_sample`` or ``order`` is missing.
        """
        if last_sample is None:
            if len(args) > 1:
                last_sample = args[1]
            else:
                raise ValueError(" missing`last_sample` as a required keyward argument")
        if this_sample is None:
            if len(args) > 2:
                this_sample = args[2]
            else:
                raise ValueError(" missing`this_sample` as a required keyward argument")
        if order is None:
            if len(args) > 3:
                order = args[3]
            else:
                raise ValueError(" missing`order` as a required keyward argument")

        model_output_list = self.model_outputs

        m0 = model_output_list[-1]
        x = last_sample
        x_t = this_sample
        model_t = this_model_output

        sigma_t, sigma_s0 = (
            self.sigmas[self.step_index],
            self.sigmas[self.step_index - 1],
        )
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = this_sample.device

        rks = []
        D1s = []
        for i in range(1, order):
            # One step further back than the predictor: history is anchored at step_index - 1.
            si = self.step_index - (i + 1)
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h
        h_phi_1 = torch.expm1(hh)  # e^{hh} - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        B_h = torch.expm1(hh)

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)
        else:
            D1s = None

        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)

        x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
        if D1s is not None:
            corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
        else:
            corr_res = 0
        D1_t = model_t - m0
        x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        x_t = x_t.to(x.dtype)
        return x_t

    def step_pre(self, step_index):
        """Prepare ``self.timestep_input`` for the model call at ``step_index``.

        For Wan2.2 i2v/s2v, the timestep becomes per-token. It is multiplied by
        the mask, subsampled to the token grid, so positions whose latents are
        pinned to ``vae_encoder_out`` (mask == 0) get timestep 0.
        """
        super().step_pre(step_index)
        self.timestep_input = torch.stack([self.timesteps[self.step_index]])
        if self._uses_vae_condition():
            # NOTE(review): assumes mask[0] is the (frames, height, width) latent grid — TODO confirm.
            pf, ph, pw = self.patch_size
            self.timestep_input = (self.mask[0][::pf, ::ph, ::pw] * self.timestep_input).flatten()

    def step_post(self):
        """Advance ``self.latents`` one step using ``self.noise_pred``.

        The step runs in three stages:

        1. Optionally correct the current latents (UniPC corrector).
        2. Push the new x0 prediction into the solver history.
        3. Predict the latents at the next sigma (UniPC predictor).

        For Wan2.2 i2v/s2v, the masked blend with ``vae_encoder_out`` is
        re-applied afterwards.
        """
        model_output = self.noise_pred.to(torch.float32)
        timestep = self.timesteps[self.step_index]
        sample = self.latents.to(torch.float32)

        use_corrector = self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None

        model_output_convert = self.convert_model_output(model_output, sample=sample)
        if use_corrector:
            sample = self.multistep_uni_c_bh_update(
                this_model_output=model_output_convert,
                last_sample=self.last_sample,
                this_sample=sample,
                order=self.this_order,
            )

        # Shift the fixed-length history left and append the newest entry.
        for i in range(self.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
            self.timestep_list[i] = self.timestep_list[i + 1]

        self.model_outputs[-1] = model_output_convert
        self.timestep_list[-1] = timestep

        # Lower the order near the end of the schedule, and during warm-up
        # while the history is still filling.
        this_order = min(self.solver_order, len(self.timesteps) - self.step_index)
        self.this_order = min(this_order, self.lower_order_nums + 1)
        assert self.this_order > 0

        self.last_sample = sample
        prev_sample = self.multistep_uni_p_bh_update(
            model_output=model_output,
            sample=sample,
            order=self.this_order,
        )

        if self.lower_order_nums < self.solver_order:
            self.lower_order_nums += 1

        self.latents = prev_sample
        if self._uses_vae_condition():
            self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents