# scheduler.py — block-wise scheduler for Wan models configured via ``sf_config``.
import torch

from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.utils.envs import *
from lightx2v_platform.base.global_var import AI_DEVICE


class WanSFScheduler(WanScheduler):
    """Block-wise scheduler driven by the ``sf_config`` section of the config.

    Latents are processed ``num_frame_per_block`` frames at a time. For block
    ``seg_index`` and step ``step_index`` (one step per entry of
    ``denoising_step_list``), ``step_pre`` prepares the model inputs and
    ``step_post`` converts the model's flow prediction (read from
    ``self.noise_pred``) into a clean estimate. Before the block's last step,
    that estimate is re-noised to the next timestep. On the last step it is
    written back into ``self.latents`` and exposed as ``self.stream_output``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.dtype = torch.bfloat16

        sf_config = self.config["sf_config"]
        self.num_frame_per_block = sf_config["num_frame_per_block"]
        self.num_output_frames = sf_config["num_output_frames"]
        # Floor division: trailing frames that do not fill a whole block are
        # never scheduled.
        self.num_blocks = self.num_output_frames // self.num_frame_per_block

        # Each block is denoised once per entry of this list.
        self.denoising_step_list = sf_config["denoising_step_list"]
        self.infer_steps = len(self.denoising_step_list)

        self.all_num_frames = [self.num_frame_per_block] * self.num_blocks
        self.num_input_frames = 0
        self.denoising_strength = 1.0
        self.sigma_max = 1.0
        self.sigma_min = 0

        # Shift factor applied to the sigma table in prepare().
        self.sf_shift = sf_config["shift"]

        # Flags controlling how prepare() builds the sigma table.
        self.inverse_timesteps = False
        self.extra_one_step = True
        self.reverse_sigmas = False

        # Number of points in the sigma/timestep lookup table. This is not the
        # number of denoising steps (that is len(denoising_step_list)).
        self.num_inference_steps = sf_config["num_inference_steps"]

        # Timestep value fed to the model on a "rerun" pass (see step_pre).
        self.context_noise = 0

    def prepare(self, seed, latent_shape, image_encoder_output=None):
        """Allocate latents and build the per-block timestep and sigma tables.

        Args:
            seed: Seed for the initial latent noise and for the re-noising done
                in ``step_post``. ``None`` falls back to torch's global RNG.
            latent_shape: Shape of the full latent tensor. ``step_pre`` and
                ``step_post`` index dim 1 as the frame axis.
            image_encoder_output: Accepted for interface compatibility; unused.
        """
        # Honour ``seed`` explicitly. Previously it was ignored, so reproducibility
        # depended on whatever global RNG state the caller had set up.
        self.generator = None if seed is None else torch.Generator(device=AI_DEVICE).manual_seed(seed)
        self.latents = torch.randn(latent_shape, generator=self.generator, device=AI_DEVICE, dtype=self.dtype)

        # timesteps[block][step] has shape (num_frame_per_block,) and is filled
        # with denoising_step_list[step]. Every block uses the same schedule.
        self.timesteps = [
            [torch.ones([self.num_frame_per_block], device=AI_DEVICE, dtype=torch.int64) * step_value for step_value in self.denoising_step_list]
            for _ in self.all_num_frames
        ]

        self.noise_pred = torch.zeros(latent_shape, device=AI_DEVICE, dtype=self.dtype)

        sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * self.denoising_strength
        if self.extra_one_step:
            # Build one extra point, then drop the trailing sigma_min entry.
            sigmas = torch.linspace(sigma_start, self.sigma_min, self.num_inference_steps + 1)[:-1]
        else:
            sigmas = torch.linspace(sigma_start, self.sigma_min, self.num_inference_steps)
        if self.inverse_timesteps:
            sigmas = torch.flip(sigmas, dims=[0])
        # Shifted schedule: sigma' = s * sigma / (1 + (s - 1) * sigma).
        sigmas = self.sf_shift * sigmas / (1 + (self.sf_shift - 1) * sigmas)
        if self.reverse_sigmas:
            sigmas = 1 - sigmas
        self.sigmas_sf = sigmas.to(AI_DEVICE)

        # Timestep value for each sigma entry. step_post matches against it by
        # nearest value. The product stays on sigmas_sf's device.
        self.timesteps_sf = self.sigmas_sf * self.num_train_timesteps

        self.stream_output = None

    def _block_bounds(self, seg_index):
        """Return ``(start, end)`` frame indices of block ``seg_index``, with ``end`` clamped to ``num_output_frames``."""
        start = seg_index * self.num_frame_per_block
        return start, min(start + self.num_frame_per_block, self.num_output_frames)

    def step_pre(self, seg_index, step_index, is_rerun=False):
        """Prepare ``timestep_input`` and ``latents_input`` for one model call.

        Args:
            seg_index: Index of the frame block being processed.
            step_index: Index into ``denoising_step_list`` for this block.
            is_rerun: If True, feed ``context_noise`` as the timestep instead
                of the scheduled value. The original comment describes this as
                a KV-cache refresh from clean context.
        """
        self.step_index = step_index
        self.seg_index = seg_index

        if GET_DTYPE() != GET_SENSITIVE_DTYPE():
            self.latents = self.latents.to(GET_DTYPE())

        # Shape (1, num_frame_per_block).
        timestep = torch.stack([self.timesteps[seg_index][step_index]])
        if is_rerun:
            # Rerun with timestep `context_noise` (zero) to update the KV cache
            # using the clean context.
            timestep = torch.ones_like(timestep) * self.context_noise
        self.timestep_input = timestep

        start, end = self._block_bounds(seg_index)
        self.latents_input = self.latents[:, start:end]

    def step_post(self):
        """Turn the model's flow prediction into the next latent state for the current block.

        Reads ``self.noise_pred`` for the block and computes the clean estimate
        ``x0 = x_t - sigma_t * v`` in float64. ``sigma_t`` is the table entry
        whose timestep is nearest to ``timestep_input``.

        Before the block's last step, ``x0`` is re-noised to the next timestep's
        sigma and stored in ``self.latents``. On the last step, ``x0`` is stored
        in ``self.latents`` and also exposed as ``self.stream_output``.
        """
        # Clamp exactly as step_pre does, so flow_pred and latents_input always
        # cover the same frames.
        start, end = self._block_bounds(self.seg_index)

        # Frame-major layout: (frames, ...).
        flow_pred = self.noise_pred[:, start:end].transpose(0, 1)
        xt = self.latents_input.transpose(0, 1)
        timestep = self.timestep_input.squeeze(0)
        original_dtype = flow_pred.dtype

        device = flow_pred.device
        flow_pred, xt, sigmas, timesteps = (t.double().to(device) for t in (flow_pred, xt, self.sigmas_sf, self.timesteps_sf))
        timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
        sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1)
        x0_pred = (xt - sigma_t * flow_pred).to(original_dtype)

        if self.step_index < self.infer_steps - 1:
            # Re-noise to the next scheduled timestep:
            # x_next = (1 - sigma_next) * x0 + sigma_next * noise.
            # The table entry already has shape (num_frame_per_block,).
            timestep_next = self.timesteps[self.seg_index][self.step_index + 1]
            timestep_id_next = torch.argmin((self.timesteps_sf.unsqueeze(0) - timestep_next.unsqueeze(1)).abs(), dim=1)
            sigma_next = self.sigmas_sf[timestep_id_next].reshape(-1, 1, 1, 1)
            # Draw from the seeded generator created in prepare(), if any.
            noise_next = torch.randn(
                x0_pred.shape,
                generator=getattr(self, "generator", None),
                device=x0_pred.device,
                dtype=x0_pred.dtype,
            )
            sample_next = ((1 - sigma_next) * x0_pred + sigma_next * noise_next).type_as(noise_next)
            self.latents[:, start:end] = sample_next.transpose(0, 1)
        else:
            x0_out = x0_pred.transpose(0, 1)
            self.latents[:, start:end] = x0_out
            self.stream_output = x0_out