scheduler.py

import gc

import numpy as np
import torch
from loguru import logger

from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.utils.envs import *


class ConsistencyModelScheduler(WanScheduler):
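    """Consistency-model sampling scheduler for Wan with an audio adapter.

    Builds a shifted sigma schedule, then at each step denoises to an x0
    estimate and re-noises it to the next sigma level with fresh noise.
    """
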
    def __init__(self, config):
        super().__init__(config)

    def set_audio_adapter(self, audio_adapter):
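        """Attach the audio adapter whose time embedding is computed each step."""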
        self.audio_adapter = audio_adapter

    def step_pre(self, step_index):
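        """Run the base per-step setup and embed the timestep for the audio adapter."""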
        super().step_pre(step_index)
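        # With CPU offload enabled, keep the time-embedding module on the GPU
        # only for the duration of this forward pass.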
        if self.audio_adapter.cpu_offload:
            self.audio_adapter.time_embedding.to("cuda")
        self.audio_adapter_t_emb = self.audio_adapter.time_embedding(self.timestep_input).unflatten(1, (3, -1))
        if self.audio_adapter.cpu_offload:
            self.audio_adapter.time_embedding.to("cpu")

    def prepare(self, image_encoder_output=None):
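        """Initialise latents and build the shifted timestep/sigma schedule."""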
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
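        # infer_steps + 1 boundaries, linearly spaced from num_train_timesteps down to 0.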
        timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)

        self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=self.device)
        self.timesteps_ori = self.timesteps.clone()

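        # Apply the timestep shift: sigma' = shift * sigma / (1 + (shift - 1) * sigma).
        # For sample_shift > 1 this biases the schedule toward higher noise levels.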
        self.sigmas = self.timesteps_ori / self.num_train_timesteps
        self.sigmas = self.sample_shift * self.sigmas / (1 + (self.sample_shift - 1) * self.sigmas)

        self.timesteps = self.sigmas * self.num_train_timesteps

    def step_post(self):
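        """Advance the latents one step: denoise to x0, then re-noise to sigma_next."""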
        model_output = self.noise_pred.to(torch.float32)
        sample = self.latents.to(torch.float32)
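        # Broadcast the scalar sigmas to the sample's ndim for elementwise math.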
        sigma = self.unsqueeze_to_ndim(self.sigmas[self.step_index], sample.ndim).to(sample.device, sample.dtype)
        sigma_next = self.unsqueeze_to_ndim(self.sigmas[self.step_index + 1], sample.ndim).to(sample.device, sample.dtype)
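        # Estimate x0 from the velocity-style prediction, then jump to the next
        # noise level with fresh Gaussian noise (multistep consistency sampling
        # rather than a deterministic ODE step).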
        x0 = sample - model_output * sigma
        x_t_next = x0 * (1 - sigma_next) + sigma_next * torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, generator=self.generator)
        self.latents = x_t_next

    def reset(self):
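        """Re-initialise latents and release cached GPU memory."""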
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        gc.collect()
        torch.cuda.empty_cache()

    def unsqueeze_to_ndim(self, in_tensor, tgt_n_dim):
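        """Append singleton dims to `in_tensor` until it has `tgt_n_dim` dims."""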
        if in_tensor.ndim > tgt_n_dim:
            logger.warning(f"tensor of shape {in_tensor.shape} already has more than {tgt_n_dim} dims; returning it unchanged")
            return in_tensor
        if in_tensor.ndim < tgt_n_dim:
            in_tensor = in_tensor[(...,) + (None,) * (tgt_n_dim - in_tensor.ndim)]
        return in_tensor
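

# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of this file): a minimal driver loop.
# `config`, `audio_adapter`, and `model.infer` are hypothetical stand-ins for
# the real lightx2v runner objects.
#
#   scheduler = ConsistencyModelScheduler(config)
#   scheduler.set_audio_adapter(audio_adapter)
#   scheduler.prepare()
#   for i in range(config.infer_steps):
#       scheduler.step_pre(i)                    # timestep + audio time embedding
#       scheduler.noise_pred = model.infer(...)  # denoiser forward pass (hypothetical)
#       scheduler.step_post()                    # x0 estimate + re-noise to sigma_next
# ---------------------------------------------------------------------------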