# scheduler.py
import torch
from lightx2v.models.schedulers.wan.scheduler import WanScheduler


class WanScheduler4ChangingResolution(WanScheduler):
    """Wan scheduler that starts denoising at reduced resolution and upsamples mid-run.

    ``config["resolution_rate"]`` lists the factors by which the last two
    latent dims are scaled for each low-resolution stage, and
    ``config["changing_resolution_steps"]`` lists the values of
    ``step_index + 1`` at which ``step_post`` switches to the next stage.
    The last stage in ``latents_list`` is always the full ``target_shape``.
    """

    def __init__(self, config):
        """Initialise the base scheduler and fill in resolution-schedule defaults.

        Args:
            config: Scheduler config. It is mutated in place: a missing
                ``resolution_rate`` defaults to ``[0.75]`` and a missing
                ``changing_resolution_steps`` defaults to
                ``[config.infer_steps // 2]``.

        Raises:
            ValueError: If ``resolution_rate`` and ``changing_resolution_steps``
                have different lengths (each low-resolution stage needs
                exactly one switch step).
        """
        super().__init__(config)
        if "resolution_rate" not in config:
            config["resolution_rate"] = [0.75]
        if "changing_resolution_steps" not in config:
            config["changing_resolution_steps"] = [config.infer_steps // 2]
        # Raise rather than assert so the check is not stripped under ``python -O``.
        num_rates = len(config["resolution_rate"])
        num_steps = len(config["changing_resolution_steps"])
        if num_rates != num_steps:
            raise ValueError(
                f"resolution_rate has {num_rates} entries but changing_resolution_steps has {num_steps}; they must have the same length"
            )

    def prepare_latents(self, target_shape, dtype=torch.float32):
        """Draw one noise tensor per resolution stage and start from the first one.

        Args:
            target_shape: Full-resolution latent shape. Indices 0-3 are used;
                indices 2 and 3 are scaled by each ``resolution_rate`` entry
                (presumably (C, T, H, W), matching the layout noted in
                ``step_post_upsample`` — TODO confirm).
            dtype: dtype of the sampled noise.

        Sets ``self.latents_list`` (one tensor per ``resolution_rate`` entry,
        followed by one at the full ``target_shape``), ``self.latents`` (the
        first entry) and ``self.changing_resolution_index`` (0).
        """
        channels, frames, height, width = target_shape[0], target_shape[1], target_shape[2], target_shape[3]

        def _sample_noise(h, w):
            # All draws go through self.generator in stage order (low-res first, full-res last).
            return torch.randn(channels, frames, h, w, dtype=dtype, device=self.device, generator=self.generator)

        # Scaled dims are truncated and then rounded down to an even number.
        self.latents_list = [
            _sample_noise(int(height * rate) // 2 * 2, int(width * rate) // 2 * 2)
            for rate in self.config["resolution_rate"]
        ]
        # Final stage: the original, full resolution.
        self.latents_list.append(_sample_noise(height, width))

        # Denoising begins on the first (lowest-index) stage.
        self.latents = self.latents_list[0]
        self.changing_resolution_index = 0

    def step_post(self):
        """Finish the current step, switching resolution instead at configured steps."""
        if self.step_index + 1 in self.config["changing_resolution_steps"]:
            # The upsample path replaces the base scheduler's update for this step.
            self.step_post_upsample()
            self.changing_resolution_index += 1
        else:
            super().step_post()

    def step_post_upsample(self):
        """Move the current latents to the next resolution stage.

        Estimates the clean sample at the current step, resizes it to the
        next stage's shape, re-noises it with that stage's pre-drawn noise,
        and rebuilds the timestep schedule with a larger shift.
        """
        next_index = self.changing_resolution_index + 1
        next_noise = self.latents_list[next_index]

        # 1. Estimate the clean sample (x0 = x_t - sigma_t * noise_pred), computed in float32.
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sigma_t = self.sigmas[self.step_index]
        denoised_sample = sample - sigma_t * model_output

        # 2. Resize to the next stage's shape. Trilinear interpolation needs a
        #    5-D (N, C, D, H, W) input, so a batch dim is added and then removed.
        clean_sample = torch.nn.functional.interpolate(
            denoised_sample.unsqueeze(0),  # (C,T,H,W) -> (1,C,T,H,W)
            size=next_noise.shape[1:],
            mode="trilinear",
        ).squeeze(0)  # (1,C,T,H,W) -> (C,T,H,W)

        # 3. Re-noise the resized clean sample with the next stage's noise.
        self.latents = self.add_noise(clean_sample, next_noise, self.timesteps[self.step_index + 1])

        # 4. Rebuild the schedule with shift increased by (changing_resolution_index + 1),
        #    for more aggressive denoising at the higher resolution.
        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift + next_index)

    def add_noise(self, original_samples, noise, timesteps):
        """Mix ``original_samples`` with ``noise`` at the current step's noise level.

        Args:
            original_samples: Clean sample to be noised.
            noise: Noise tensor, expected to match ``original_samples`` in shape.
            timesteps: Accepted for interface compatibility but NOT used; the
                noise level is taken from ``self.sigmas[self.step_index]``.

        Returns:
            ``alpha_t * original_samples + sigma_t * noise``, with
            ``alpha_t, sigma_t`` from ``self._sigma_to_alpha_sigma_t``.
        """
        # NOTE(review): step_post_upsample passes self.timesteps[self.step_index + 1],
        # which suggests the next step's noise level may have been intended; this
        # uses the current step's sigma. Confirm against the base scheduler before changing.
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(self.sigmas[self.step_index])
        return alpha_t * original_samples + sigma_t * noise