import gc
import math
import os
import warnings
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers import (
    FlowMatchEulerDiscreteScheduler as FlowMatchEulerDiscreteSchedulerBase,  # pyright: ignore
)
from diffusers.configuration_utils import register_to_config
from loguru import logger
from torch import Tensor

from lightx2v.models.schedulers.scheduler import BaseScheduler
from lightx2v.utils.envs import *


def unsqueeze_to_ndim(in_tensor: Tensor, tgt_n_dim: int):
    """Right-pad `in_tensor` with singleton dimensions until it has `tgt_n_dim` dims."""
    if in_tensor.ndim > tgt_n_dim:
        warnings.warn(f"Tensor of shape {in_tensor.shape} already has more than {tgt_n_dim} dims; returning it unchanged")
        return in_tensor
    if in_tensor.ndim < tgt_n_dim:
        in_tensor = in_tensor[(...,) + (None,) * (tgt_n_dim - in_tensor.ndim)]
    return in_tensor


class EulerSchedulerTimestepFix(BaseScheduler):
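    """Flow-matching Euler scheduler with a shift-warped sigma schedule.

    Builds an (infer_steps + 1)-point timestep grid from `num_train_timesteps`
    down to 0 and warps the derived sigmas by `sample_shift` (see `prepare`).
    """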
    def __init__(self, config, **kwargs):
        # super().__init__(**kwargs)
        self.init_noise_sigma = 1.0
        self.config = config
        self.latents = None
        self.device = torch.device("cuda")
        self.infer_steps = self.config.infer_steps
        self.target_video_length = self.config.target_video_length
        self.sample_shift = self.config.sample_shift
        self.num_train_timesteps = 1000
        self.step_index = None

    def step_pre(self, step_index):
        self.step_index = step_index
        if GET_DTYPE() == "BF16":
            self.latents = self.latents.to(dtype=torch.bfloat16)

    def prepare(self, image_encoder_output=None):
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)

        if self.config.task in ["t2v"]:
            self.seq_len = math.ceil(
                (self.config.target_shape[2] * self.config.target_shape[3])
                / (self.config.patch_size[1] * self.config.patch_size[2])
                * self.config.target_shape[1]
            )
        elif self.config.task in ["i2v"]:
            self.seq_len = (
                ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1)
                * self.config.lat_h
                * self.config.lat_w
                // (self.config.patch_size[1] * self.config.patch_size[2])
            )

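        # inclusive grid of infer_steps + 1 timesteps from num_train_timesteps down to 0,
        # so consecutive pairs (t_i, t_{i+1}) bracket each denoising step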
        timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)

        self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=self.device)
        self.timesteps_ori = self.timesteps.clone()

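        # timestep shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma);
        # for shift > 1 this pushes the schedule toward higher noise levels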
        self.sigmas = self.timesteps_ori / self.num_train_timesteps
        self.sigmas = self.sample_shift * self.sigmas / (1 + (self.sample_shift - 1) * self.sigmas)

        self.timesteps = self.sigmas * self.num_train_timesteps

    def prepare_latents(self, target_shape, dtype=torch.float32):
        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
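        # initial latents: standard Gaussian noise scaled by init_noise_sigma (1.0 for flow matching)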
        self.latents = (
            torch.randn(
                target_shape[0],
                target_shape[1],
                target_shape[2],
                target_shape[3],
                dtype=dtype,
                device=self.device,
                generator=self.generator,
            )
            * self.init_noise_sigma
        )

    def step_post(self):
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)

        # Euler update for flow matching: x_{t+1} = x_t + (sigma_{t+1} - sigma_t) * v_theta
        sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
        sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
        x_t_next = sample + (sigma_next - sigma) * model_output

        self.latents = x_t_next

    def reset(self):
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        gc.collect()
        torch.cuda.empty_cache()


class ConsistencyModelScheduler(EulerSchedulerTimestepFix):
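    """Consistency-style variant: predict the clean sample x0 from the velocity, then re-noise it to the next sigma."""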
    def step_post(self):
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
        sigma = unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
        sigma_next = unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
        # predict the clean sample x0 from the velocity, then re-noise it to the next sigma level
        x0 = sample - model_output * sigma
        x_t_next = x0 * (1 - sigma_next) + sigma_next * torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, generator=self.generator)
        self.latents = x_t_next
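

# --- Usage sketch (illustrative; not part of the upstream module) ---
# A minimal driving loop under assumed config fields: the attribute names below
# (infer_steps, target_video_length, sample_shift, seed, task, target_shape,
# patch_size) mirror what this module reads, but the values are arbitrary and a
# real pipeline would write the model's velocity prediction into `noise_pred`.
# Requires a CUDA device, since the scheduler hard-codes torch.device("cuda").
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        infer_steps=30,
        target_video_length=81,
        sample_shift=5.0,
        seed=42,
        task="t2v",
        target_shape=(16, 21, 60, 104),  # (C, T, H, W) latent shape, illustrative
        patch_size=(1, 2, 2),
    )
    scheduler = EulerSchedulerTimestepFix(config)
    scheduler.prepare()
    for i in range(scheduler.infer_steps):
        scheduler.step_pre(i)
        # placeholder for the model call; zeros keep the sketch self-contained
        scheduler.noise_pred = torch.zeros_like(scheduler.latents)
        scheduler.step_post()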