import math
from typing import Union

import torch

from lightx2v.models.schedulers.wan.scheduler import WanScheduler
9
class WanStepDistillScheduler(WanScheduler):
    """Scheduler for step-distilled Wan video models.

    Instead of sweeping a full diffusion schedule, it denoises only at the
    few timesteps listed in ``config.denoising_step_list``.  Each step
    recovers the predicted clean sample from the flow prediction and, if
    further steps remain, re-noises it to the next (smaller) sigma level.
    """

    def __init__(self, config):
        super().__init__(config)
        # Distilled models run a fixed, short list of denoising timesteps.
        self.denoising_step_list = config.denoising_step_list
        self.infer_steps = len(self.denoising_step_list)
        self.sample_shift = self.config.sample_shift

        # Flow-matching sigma range over the training timestep grid.
        self.num_train_timesteps = 1000
        self.sigma_max = 1.0
        self.sigma_min = 0.0

    def prepare(self, image_encoder_output):
        """Seed the RNG, draw initial latents and compute the sequence length.

        ``image_encoder_output`` is accepted for interface compatibility with
        other schedulers; it is not used here.
        """
        self.generator = torch.Generator(device=self.device)
        self.generator.manual_seed(self.config.seed)

        self.prepare_latents(self.config.target_shape, dtype=torch.float32)

        # Token sequence length = patches per frame * number of latent frames.
        if self.config.task in ["t2v"]:
            self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
        elif self.config.task in ["i2v"]:
            self.seq_len = self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1]

        self.set_denoising_timesteps(device=self.device)

    def set_denoising_timesteps(self, device: Union[str, torch.device] = None):
        """Build the shifted sigma/timestep tables and pick the distilled steps."""
        # Dense descending grid from sigma_max to sigma_min; the trailing
        # sigma_min entry is dropped so every retained sigma stays above it.
        # (The original `sigma_min + (sigma_max - sigma_min)` is exactly sigma_max.)
        self.sigmas = torch.linspace(self.sigma_max, self.sigma_min, self.num_train_timesteps + 1)[:-1]
        # Time-shift transform used by the Wan flow-matching schedule.
        self.sigmas = self.sample_shift * self.sigmas / (1 + (self.sample_shift - 1) * self.sigmas)
        self.timesteps = self.sigmas * self.num_train_timesteps

        # denoising_step_list counts timesteps from the noisy end, so convert
        # each entry into an index into the descending 1000-entry grid.
        self.denoising_step_index = [self.num_train_timesteps - x for x in self.denoising_step_list]
        self.timesteps = self.timesteps[self.denoising_step_index].to(device)
        # Sigmas are read back one scalar at a time in step_post; keep on CPU.
        self.sigmas = self.sigmas[self.denoising_step_index].to("cpu")

    def reset(self):
        """Re-draw the initial noise latents for a fresh generation run."""
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)

    def add_noise(self, original_samples, noise, sigma):
        """Linearly interpolate clean samples toward noise at level ``sigma``."""
        sample = (1 - sigma) * original_samples + sigma * noise
        return sample.type_as(noise)

    def step_post(self):
        """Consume ``self.noise_pred``: recover x0, then re-noise for the next step."""
        flow_pred = self.noise_pred.to(torch.float32)
        sigma = self.sigmas[self.step_index].item()
        # Flow prediction points from data toward noise: x0 = x_t - sigma * v.
        x0_pred = self.latents.to(torch.float32) - sigma * flow_pred
        if self.step_index < self.infer_steps - 1:
            # Re-noise the predicted clean sample to the next sigma level,
            # using the seeded generator for reproducibility.
            next_sigma = self.sigmas[self.step_index + 1].item()
            noise = torch.randn(x0_pred.shape, dtype=torch.float32, device=self.device, generator=self.generator)
            x0_pred = self.add_noise(x0_pred, noise=noise, sigma=next_sigma)
        self.latents = x0_pred.to(self.latents.dtype)


class Wan22StepDistillScheduler(WanStepDistillScheduler):
    """Step-distill scheduler for Wan 2.2 two-expert (boundary) models."""

    def __init__(self, config):
        super().__init__(config)
        # Step index separating the high-noise expert from the low-noise one.
        self.boundary_step_index = config.boundary_step_index

    def set_denoising_timesteps(self, device: Union[str, torch.device] = None):
        """Build the distilled schedule, then cache the sigma at the expert boundary."""
        super().set_denoising_timesteps(device)
        self.sigma_boundary = self.sigmas[self.boundary_step_index].item()

    def step_post(self):
        """Consume ``self.noise_pred``: recover x0 (rescaled before the boundary),
        then re-noise for the next step."""
        flow_pred = self.noise_pred.to(torch.float32)
        sigma = self.sigmas[self.step_index].item()
        x0_pred = self.latents.to(torch.float32) - sigma * flow_pred
        if self.step_index < self.boundary_step_index:
            # NOTE(review): high-noise-expert steps rescale the prediction by
            # 1/sigma_boundary — presumably matching the Wan2.2 distillation
            # recipe; confirm against the training code.
            x0_pred = x0_pred / self.sigma_boundary
        if self.step_index < self.infer_steps - 1:
            next_sigma = self.sigmas[self.step_index + 1].item()
            # Use the seeded generator (as the parent class does) so results
            # stay reproducible for a fixed config.seed.
            noise = torch.randn(x0_pred.shape, dtype=torch.float32, device=self.device, generator=self.generator)
            x0_pred = self.add_noise(x0_pred, noise, next_sigma)
        self.latents = x0_pred.to(self.latents.dtype)