scheduler.py 2.65 KB
Newer Older
1
2
3
4
5
6
7
import torch
from lightx2v.models.schedulers.wan.scheduler import WanScheduler


class WanScheduler4ChangingResolution(WanScheduler):
    def __init__(self, config):
        super().__init__(config)
helloyongyang's avatar
helloyongyang committed
8
9
        self.resolution_rate = config.get("resolution_rate", 0.75)
        self.changing_resolution_steps = config.get("changing_resolution_steps", config.infer_steps // 2)
10
11
12
13
14
15
16
17
18
19
20

    def prepare_latents(self, target_shape, dtype=torch.float32):
        self.latents = torch.randn(
            target_shape[0],
            target_shape[1],
            int(target_shape[2] * self.resolution_rate) // 2 * 2,
            int(target_shape[3] * self.resolution_rate) // 2 * 2,
            dtype=dtype,
            device=self.device,
            generator=self.generator,
        )
helloyongyang's avatar
fix ci  
helloyongyang committed
21

22
23
24
25
26
27
28
29
30
31
32
        self.noise_original_resolution = torch.randn(
            target_shape[0],
            target_shape[1],
            target_shape[2],
            target_shape[3],
            dtype=dtype,
            device=self.device,
            generator=self.generator,
        )

    def step_post(self):
33
        if self.step_index == self.changing_resolution_steps - 1:
34
35
36
37
38
39
40
41
42
43
44
45
46
47
            self.step_post_upsample()
        else:
            super().step_post()

    def step_post_upsample(self):
        # 1. denoised sample to clean noise
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sigma_t = self.sigmas[self.step_index]
        x0_pred = sample - sigma_t * model_output
        denoised_sample = x0_pred.to(sample.dtype)

        # 2. upsample clean noise to target shape
        denoised_sample_5d = denoised_sample.unsqueeze(0)  # (C,T,H,W) -> (1,C,T,H,W)
helloyongyang's avatar
fix ci  
helloyongyang committed
48
49
50
        clean_noise = torch.nn.functional.interpolate(denoised_sample_5d, size=(self.config.target_shape[1], self.config.target_shape[2], self.config.target_shape[3]), mode="trilinear")
        clean_noise = clean_noise.squeeze(0)  # (1,C,T,H,W) -> (C,T,H,W)

51
52
        # 3. add noise to clean noise
        noisy_sample = self.add_noise(clean_noise, self.noise_original_resolution, self.timesteps[self.step_index + 1])
helloyongyang's avatar
fix ci  
helloyongyang committed
53

54
55
        # 4. update latents
        self.latents = noisy_sample
helloyongyang's avatar
fix ci  
helloyongyang committed
56

57
        # self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37] # maybe not needed
helloyongyang's avatar
fix ci  
helloyongyang committed
58

59
60
61
62
63
64
65
66
        # 5. update timesteps using shift + 2 更激进的去噪
        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift + 2)

    def add_noise(self, original_samples, noise, timesteps):
        sigma = self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples