from typing import List, Optional, Union

import numpy as np
import torch

from lightx2v.models.schedulers.scheduler import BaseScheduler
from lightx2v.utils.utils import masks_like


class WanScheduler(BaseScheduler):
    """Multistep predictor/corrector sampler plus RoPE tables for Wan models.

    The predictor (``multistep_uni_p_bh_update``) and corrector
    (``multistep_uni_c_bh_update``) use ``B(h) = expm1(-h)``. This appears to
    mirror the "bh2" UniPC variant in diffusers' ``UniPCMultistepScheduler``
    (NOTE(review): verify against the reference implementation).

    Per-step protocol, as used by the methods below:
    ``prepare`` -> for each step: ``step_pre(i)`` -> caller sets
    ``self.noise_pred`` -> ``step_post()``.
    """

    def __init__(self, config):
        """Read sampler settings from ``config`` and precompute RoPE frequencies.

        Args:
            config: mapping providing at least ``infer_steps``,
                ``target_video_length``, ``sample_shift``,
                ``sample_guide_scale``, ``dim`` and ``num_heads``, and
                optionally ``run_device`` (default ``"cuda"``).
                ``self.config`` is presumably assigned by
                ``BaseScheduler.__init__`` (TODO confirm).
        """
        super().__init__(config)
        self.device = torch.device(self.config.get("run_device", "cuda"))
        self.infer_steps = self.config["infer_steps"]
        self.target_video_length = self.config["target_video_length"]
        self.sample_shift = self.config["sample_shift"]
        self.run_device = self.config.get("run_device", "cuda")
        self.patch_size = (1, 2, 2)
        self.shift = 1
        self.num_train_timesteps = 1000
        # Step indices (minus one) for which the corrector is skipped.
        self.disable_corrector = []
        self.solver_order = 2
        self.noise_pred = None
        self.sample_guide_scale = self.config["sample_guide_scale"]
        self.caching_records_2 = [True] * self.config["infer_steps"]
        self.head_size = self.config["dim"] // self.config["num_heads"]
        # Complex RoPE table: one frequency band for frames, two equal bands for
        # height and width. Columns are split back apart in prepare_cos_sin.
        self.freqs = torch.cat(
            [
                self.rope_params(1024, self.head_size - 4 * (self.head_size // 6)),
                self.rope_params(1024, 2 * (self.head_size // 6)),
                self.rope_params(1024, 2 * (self.head_size // 6)),
            ],
            dim=1,
        ).to(torch.device(self.run_device))

    def _is_wan22_image_task(self):
        """Return True when latents are blended with ``vae_encoder_out`` (wan2.2 i2v/s2v)."""
        return self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]

    def rope_params(self, max_seq_len, dim, theta=10000):
        """Build a complex RoPE table of shape ``(max_seq_len, dim // 2)``.

        Entry ``[p, k]`` is ``exp(1j * p / theta ** (2k / dim))`` (unit
        magnitude, via ``torch.polar``).

        Args:
            max_seq_len: number of positions.
            dim: rotary dimension; must be even.
            theta: base of the frequency schedule.
        """
        assert dim % 2 == 0
        freqs = torch.outer(
            torch.arange(max_seq_len),
            1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)),
        )
        freqs = torch.polar(torch.ones_like(freqs), freqs)
        return freqs

    def prepare(self, seed, latent_shape, image_encoder_output=None):
        """Initialise latents, the sigma/timestep schedule and the RoPE tables.

        Args:
            seed: RNG seed for the initial noise.
            latent_shape: 4-element shape ``(C, F, H, W)``. Indices 1..3 are
                divided by ``patch_size``; the meaning of index 0 is assumed
                (TODO confirm).
            image_encoder_output: dict carrying ``"vae_encoder_out"``. It is
                only read for wan2.2 i2v/s2v.
        """
        if self._is_wan22_image_task():
            # Must be set before prepare_latents, which blends it into the noise.
            self.vae_encoder_out = image_encoder_output["vae_encoder_out"]
        self.prepare_latents(seed, latent_shape, dtype=torch.float32)

        # Training-time sigma grid. sigma_min / sigma_max feed set_timesteps.
        alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
        sigmas = 1.0 - alphas
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
        sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
        self.sigmas = sigmas
        self.timesteps = sigmas * self.num_train_timesteps

        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None

        self.sigmas = self.sigmas.to("cpu")
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

        # Replaces self.sigmas / self.timesteps with the inference schedule.
        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)

        grid_sizes = (
            latent_shape[1] // self.patch_size[0],
            latent_shape[2] // self.patch_size[1],
            latent_shape[3] // self.patch_size[2],
        )
        self.cos_sin = self.prepare_cos_sin(grid_sizes)

    def prepare_cos_sin(self, grid_sizes):
        """Expand the RoPE table over an ``(f, h, w)`` token grid.

        Returns:
            With ``rope_type == "flashinfer"`` (the default): a real tensor of
            shape ``(f*h*w, 2*K)``, holding the cosines then the sines.
            Otherwise: a complex tensor of shape ``(f*h*w, 1, K)``.
            ``K`` is the total number of complex frequency columns.
        """
        c = self.head_size // 2
        freqs = self.freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
        f, h, w = grid_sizes
        seq_len = f * h * w
        cos_sin = torch.cat(
            [
                freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
                freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
                freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
            ],
            dim=-1,
        )
        if self.config.get("rope_type", "flashinfer") == "flashinfer":
            cos_sin = cos_sin.reshape(seq_len, -1)
            # Extract cos and sin parts separately and concatenate.
            cos_half = cos_sin.real.contiguous()
            sin_half = cos_sin.imag.contiguous()
            cos_sin = torch.cat([cos_half, sin_half], dim=-1)
        else:
            cos_sin = cos_sin.reshape(seq_len, 1, -1)
        return cos_sin

    def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
        """Draw seeded Gaussian latents on ``self.device``.

        For wan2.2 i2v/s2v, the noise is mixed with ``self.vae_encoder_out``
        through ``self.mask``: positions where the mask is 0 take the encoder
        output. The exact mask layout comes from ``masks_like``
        (TODO confirm semantics).
        """
        self.generator = torch.Generator(device=self.device).manual_seed(seed)
        self.latents = torch.randn(
            latent_shape[0],
            latent_shape[1],
            latent_shape[2],
            latent_shape[3],
            dtype=dtype,
            device=self.device,
            generator=self.generator,
        )
        if self._is_wan22_image_task():
            self.mask = masks_like(self.latents, zero=True)
            self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents

    def set_timesteps(
        self,
        infer_steps: Union[int, None] = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[Union[float, None]] = None,
        shift: Optional[Union[float, None]] = None,
    ):
        """Build the ``infer_steps``-long inference schedule and reset solver state.

        Args:
            infer_steps: number of denoising steps. Defaults to
                ``self.infer_steps``. It must equal ``self.infer_steps``
                (asserted below).
            device: device for ``self.timesteps``. ``self.sigmas`` stays on CPU.
            sigmas: accepted for interface compatibility but ignored.
            mu: accepted for interface compatibility but ignored.
            shift: time-shift factor. Defaults to ``self.shift``.

        Requires ``self.sigma_min`` / ``self.sigma_max`` (set in ``prepare``).
        Leaves ``len(self.sigmas) == infer_steps + 1``, with a trailing 0.
        """
        if infer_steps is None:
            # Previously `None + 1` raised TypeError despite the None default.
            infer_steps = self.infer_steps
        sigmas = np.linspace(self.sigma_max, self.sigma_min, infer_steps + 1).copy()[:-1]

        if shift is None:
            shift = self.shift
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        sigma_last = 0

        timesteps = sigmas * self.num_train_timesteps
        sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

        assert len(self.timesteps) == self.infer_steps

        self.model_outputs = [
            None,
        ] * self.solver_order
        self.lower_order_nums = 0
        self.last_sample = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")

    def _sigma_to_alpha_sigma_t(self, sigma):
        """Map a sigma to ``(alpha_t, sigma_t) = (1 - sigma, sigma)``."""
        return 1 - sigma, sigma

    def convert_model_output(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """Convert the raw network output into a clean-sample estimate.

        Returns ``sample - sigmas[step_index] * model_output``.

        Args:
            model_output: network prediction for the current step.
            sample: current latents. May also be given as the second extra
                positional argument, for legacy call sites.

        Raises:
            ValueError: if ``sample`` is not provided.
        """
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyward argument")
        sigma_t = self.sigmas[self.step_index]
        x0_pred = sample - sigma_t * model_output
        return x0_pred

    def reset(self, seed, latent_shape, step_index=None):
        """Clear multistep solver history and redraw latents with ``seed``."""
        if step_index is not None:
            self.step_index = step_index
        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None
        self.noise_pred = None
        self.this_order = None
        self.lower_order_nums = 0
        self.prepare_latents(seed, latent_shape, dtype=torch.float32)

    def _unipc_bh_system(self, h, rks, order, device):
        """Build the shared UniPC B(h) quantities for the predictor and corrector.

        Args:
            h: step size in lambda = log(alpha) - log(sigma) space.
            rks: 1-D tensor of relative step ratios, ending with 1.0.
            order: number of rows to build.
            device: device for the returned ``b`` tensor.

        Returns:
            ``(h_phi_1, B_h, R, b)``. ``R`` is the ``(order, len(rks))``
            Vandermonde-style matrix of powers of ``rks``; ``b`` is the
            right-hand side of the linear system.
        """
        hh = -h
        h_phi_1 = torch.expm1(hh)  # h*phi_1(h) = e^h - 1
        B_h = torch.expm1(hh)
        h_phi_k = h_phi_1 / hh - 1
        factorial_i = 1
        R = []
        b = []
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i
        return h_phi_1, B_h, torch.stack(R), torch.tensor(b, device=device)

    def multistep_uni_p_bh_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        order: int = None,
        **kwargs,
    ) -> torch.Tensor:
        """Predictor step: advance ``sample`` from sigma[i] to sigma[i+1].

        Uses the history in ``self.model_outputs``. The ``model_output``
        argument itself is not read.

        Args:
            sample: current latents (keyword, or the 2nd extra positional).
            order: solver order (keyword, or the 3rd extra positional).

        Raises:
            ValueError: if ``sample`` or ``order`` is missing.
        """
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError(" missing `sample` as a required keyward argument")
        if order is None:
            if len(args) > 2:
                order = args[2]
            else:
                raise ValueError(" missing `order` as a required keyward argument")

        model_output_list = self.model_outputs
        m0 = model_output_list[-1]
        x = sample

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(self.sigmas[self.step_index + 1])
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(self.sigmas[self.step_index])
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        h = lambda_t - lambda_s0
        device = sample.device

        # Finite differences against older model outputs (empty when order == 1).
        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - i
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)
        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        h_phi_1, B_h, R, b = self._unipc_bh_system(h, rks, order, device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
        else:
            D1s = None

        x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
        if D1s is not None:
            pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
        else:
            pred_res = 0
        x_t = x_t_ - alpha_t * B_h * pred_res
        x_t = x_t.to(x.dtype)
        return x_t

    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.Tensor,
        *args,
        last_sample: torch.Tensor = None,
        this_sample: torch.Tensor = None,
        order: int = None,
        **kwargs,
    ) -> torch.Tensor:
        """Corrector step: refine the previous prediction using the newest model output.

        Re-integrates from ``last_sample`` at sigma[i-1] to sigma[i]. It
        combines ``this_model_output`` with the history in
        ``self.model_outputs``, whose last entry is still the previous step's
        output when this is called.

        Args:
            this_model_output: converted model output at the current step.
            last_sample: latents before the previous predictor step.
            this_sample: current latents; only its device is used.
            order: solver order.

        Raises:
            ValueError: if ``last_sample``, ``this_sample`` or ``order`` is missing.
        """
        if last_sample is None:
            if len(args) > 1:
                last_sample = args[1]
            else:
                raise ValueError(" missing`last_sample` as a required keyward argument")
        if this_sample is None:
            if len(args) > 2:
                this_sample = args[2]
            else:
                raise ValueError(" missing`this_sample` as a required keyward argument")
        if order is None:
            if len(args) > 3:
                order = args[3]
            else:
                raise ValueError(" missing`order` as a required keyward argument")

        model_output_list = self.model_outputs
        m0 = model_output_list[-1]
        x = last_sample
        model_t = this_model_output

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(self.sigmas[self.step_index])
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(self.sigmas[self.step_index - 1])
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        h = lambda_t - lambda_s0
        device = this_sample.device

        rks = []
        D1s = []
        for i in range(1, order):
            si = self.step_index - (i + 1)
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)
        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        h_phi_1, B_h, R, b = self._unipc_bh_system(h, rks, order, device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)
        else:
            D1s = None

        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)

        x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
        if D1s is not None:
            corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
        else:
            corr_res = 0
        D1_t = model_t - m0
        x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        x_t = x_t.to(x.dtype)
        return x_t

    def step_pre(self, step_index):
        """Set the current step and build the timestep tensor fed to the model.

        For wan2.2 i2v/s2v, the timestep is multiplied by a spatially
        subsampled mask (every 2nd H/W position of ``mask[0]``) and then
        flattened.
        """
        super().step_pre(step_index)
        self.timestep_input = torch.stack([self.timesteps[self.step_index]])
        if self._is_wan22_image_task():
            self.timestep_input = (self.mask[0][:, ::2, ::2] * self.timestep_input).flatten()

    def step_post(self):
        """Consume ``self.noise_pred`` and advance ``self.latents`` by one step.

        Order of operations:
        1. Optionally run the corrector, using the history before it is updated.
        2. Shift the history.
        3. Pick the warm-up-limited order.
        4. Run the predictor.
        5. Re-apply the i2v/s2v conditioning blend.
        """
        model_output = self.noise_pred.to(torch.float32)
        timestep = self.timesteps[self.step_index]
        sample = self.latents.to(torch.float32)

        use_corrector = self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
        model_output_convert = self.convert_model_output(model_output, sample=sample)

        if use_corrector:
            sample = self.multistep_uni_c_bh_update(
                this_model_output=model_output_convert,
                last_sample=self.last_sample,
                this_sample=sample,
                order=self.this_order,
            )

        # Shift history left in place and append the newest entry.
        for i in range(self.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
            self.timestep_list[i] = self.timestep_list[i + 1]

        self.model_outputs[-1] = model_output_convert
        self.timestep_list[-1] = timestep

        this_order = min(self.solver_order, len(self.timesteps) - self.step_index)
        self.this_order = min(this_order, self.lower_order_nums + 1)  # warmup for multistep
        assert self.this_order > 0

        self.last_sample = sample
        # NOTE(review): the raw model_output is passed here; the predictor
        # ignores this argument and reads self.model_outputs instead.
        prev_sample = self.multistep_uni_p_bh_update(
            model_output=model_output,
            sample=sample,
            order=self.this_order,
        )

        if self.lower_order_nums < self.solver_order:
            self.lower_order_nums += 1

        self.latents = prev_sample
        if self._is_wan22_image_task():
            self.latents = (1.0 - self.mask) * self.vae_encoder_out + self.mask * self.latents