scheduler.py 3.7 KB
Newer Older
1
import torch
PengGao's avatar
PengGao committed
2

3
4
5
from lightx2v.models.schedulers.wan.scheduler import WanScheduler


6
7
8
9
10
11
12
13
14
15
16
class WanScheduler4ChangingResolutionInterface:
    """Factory that attaches changing-resolution behaviour to a scheduler class.

    Calling ``WanScheduler4ChangingResolutionInterface(father_scheduler, config)``
    does not return an instance of this class. ``__new__`` builds a class that
    inherits from ``WanScheduler4ChangingResolution`` first and
    ``father_scheduler`` second, and returns an instance of that class.
    """

    def __new__(cls, father_scheduler, config):
        """Build the combined scheduler class and return an instance of it.

        Args:
            father_scheduler: Scheduler class used as the second base class.
            config: Configuration object passed to both initialisers.

        Returns:
            An instance of the dynamically created ``NewClass``.
        """

        class NewClass(WanScheduler4ChangingResolution, father_scheduler):
            def __init__(self, config):
                # Both initialisers are invoked explicitly, base scheduler
                # first, then the changing-resolution setup.
                initialisers = (
                    father_scheduler.__init__,
                    WanScheduler4ChangingResolution.__init__,
                )
                for initialise in initialisers:
                    initialise(self, config)

        combined_scheduler = NewClass(config)
        return combined_scheduler


class WanScheduler4ChangingResolution:
    """Scheduler mixin that denoises at reduced resolution before upsampling.

    The class is meant to be combined with another scheduler class. It reads
    ``self.config``, ``self.device``, ``self.generator``, ``self.sigmas``,
    ``self.timesteps``, ``self.step_index``, ``self.noise_pred``,
    ``self.infer_steps`` and ``self.sample_shift``, and calls
    ``self.set_timesteps``, ``self._sigma_to_alpha_sigma_t`` and
    ``super().step_post()``; none of these are defined here.
    """

    def __init__(self, config):
        """Fill in default changing-resolution settings and validate them.

        Args:
            config: Mapping-like configuration. ``"resolution_rate"`` defaults
                to ``[0.75]`` and ``"changing_resolution_steps"`` defaults to
                ``[config.infer_steps // 2]`` when missing. The defaults are
                written back into ``config``.

        Raises:
            ValueError: If the two lists do not have the same length.
        """
        if "resolution_rate" not in config:
            config["resolution_rate"] = [0.75]
        if "changing_resolution_steps" not in config:
            config["changing_resolution_steps"] = [config.infer_steps // 2]

        # Explicit raise instead of `assert`, so the check survives `python -O`.
        num_rates = len(config["resolution_rate"])
        num_steps = len(config["changing_resolution_steps"])
        if num_rates != num_steps:
            raise ValueError(
                "resolution_rate and changing_resolution_steps must have the same "
                f"length, got {num_rates} and {num_steps}"
            )

    def prepare_latents(self, target_shape, dtype=torch.float32):
        """Sample the initial noise for every resolution stage.

        One noise tensor is drawn per entry of ``config["resolution_rate"]``
        (in order), followed by one at the original resolution. The results
        are stored in ``self.latents_list``; ``self.latents`` is set to the
        first entry and ``self.changing_resolution_index`` to 0.

        Args:
            target_shape: Indexable with at least four sizes. Entries 2 and 3
                are the ones scaled by the resolution rates.
            dtype: dtype of the sampled noise.
        """
        full_height = target_shape[2]
        full_width = target_shape[3]

        def sample_noise(height, width):
            # Draw order matters: all tensors share self.generator.
            return torch.randn(
                target_shape[0],
                target_shape[1],
                height,
                width,
                dtype=dtype,
                device=self.device,
                generator=self.generator,
            )

        # `// 2 * 2` rounds each scaled size down to an even number.
        self.latents_list = [
            sample_noise(
                int(full_height * rate) // 2 * 2,
                int(full_width * rate) // 2 * 2,
            )
            for rate in self.config["resolution_rate"]
        ]

        # add original resolution latents
        self.latents_list.append(sample_noise(full_height, full_width))

        # set initial latents
        self.latents = self.latents_list[0]
        self.changing_resolution_index = 0

    def step_post(self):
        """Finish a step, switching to the next resolution when scheduled.

        When ``self.step_index + 1`` is listed in
        ``config["changing_resolution_steps"]`` the latents are upsampled and
        ``self.changing_resolution_index`` is advanced; in that case the
        parent ``step_post`` is not called. Otherwise the call is delegated to
        ``super().step_post()``.
        """
        if self.step_index + 1 in self.config["changing_resolution_steps"]:
            self.step_post_upsample()
            self.changing_resolution_index += 1
        else:
            super().step_post()

    def step_post_upsample(self):
        """Move the current latents to the next resolution stage.

        Replaces ``self.latents`` with a re-noised, upsampled estimate of the
        clean sample and then recomputes the timesteps with a larger shift.
        """
        next_stage = self.changing_resolution_index + 1
        next_noise = self.latents_list[next_stage]

        # 1. denoised sample to clean noise
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sigma_t = self.sigmas[self.step_index]
        x0_pred = sample - sigma_t * model_output
        denoised_sample = x0_pred.to(sample.dtype)

        # 2. upsample clean noise to target shape
        denoised_sample_5d = denoised_sample.unsqueeze(0)  # (C,T,H,W) -> (1,C,T,H,W)
        shape_to_upsampled = next_noise.shape[1:]
        clean_noise = torch.nn.functional.interpolate(denoised_sample_5d, size=shape_to_upsampled, mode="trilinear")
        clean_noise = clean_noise.squeeze(0)  # (1,C,T,H,W) -> (C,T,H,W)

        # 3. add noise to clean noise
        noisy_sample = self.add_noise(clean_noise, next_noise, self.timesteps[self.step_index + 1])

        # 4. update latents
        self.latents = noisy_sample

        # self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37] # maybe not needed

        # 5. update timesteps using shift + self.changing_resolution_index + 1
        #    (more aggressive denoising)
        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift + next_stage)

    def add_noise(self, original_samples, noise, timesteps):
        """Blend ``original_samples`` with ``noise`` at the current sigma.

        Args:
            original_samples: Tensor to be re-noised.
            noise: Noise tensor combined with ``original_samples``.
            timesteps: Accepted but not used by this implementation.

        Returns:
            ``alpha_t * original_samples + sigma_t * noise`` where
            ``alpha_t, sigma_t`` come from
            ``self._sigma_to_alpha_sigma_t(self.sigmas[self.step_index])``.
        """
        # NOTE(review): the noise level comes from self.step_index, not from
        # `timesteps` — confirm this is the intended level for the next step.
        sigma = self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples