import torch from diffusers.utils.torch_utils import randn_tensor from typing import Union, Tuple, List from typing import Any, Callable, Dict, List, Optional, Union, Tuple from lightx2v.text2v.models.schedulers.scheduler import BaseScheduler def _to_tuple(x, dim=2): if isinstance(x, int): return (x,) * dim elif len(x) == dim: return x else: raise ValueError(f"Expected length {dim} or int, but got {x}") def get_1d_rotary_pos_embed( dim: int, pos: Union[torch.FloatTensor, int], theta: float = 10000.0, use_real: bool = False, theta_rescale_factor: float = 1.0, interpolation_factor: float = 1.0, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Precompute the frequency tensor for complex exponential (cis) with given dimensions. (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) This function calculates a frequency tensor with complex exponential using the given dimension 'dim' and the end index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 data type. Args: dim (int): Dimension of the frequency tensor. pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. use_real (bool, optional): If True, return real part and imaginary part separately. Otherwise, return complex numbers. theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. Returns: freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2] freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D] """ if isinstance(pos, int): pos = torch.arange(pos).float() # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning # has some connection to NTK literature if theta_rescale_factor != 1.0: theta *= theta_rescale_factor ** (dim / (dim - 2)) freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}" freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] if use_real: freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] return freqs_cos, freqs_sin else: freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] return freqs_cis def get_meshgrid_nd(start, *args, dim=2): """ Get n-D meshgrid with start, stop and num. Args: start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in n-tuples. *args: See above. dim (int): Dimension of the meshgrid. Defaults to 2. Returns: grid (np.ndarray): [dim, ...] """ if len(args) == 0: # start is grid_size num = _to_tuple(start, dim=dim) start = (0,) * dim stop = num elif len(args) == 1: # start is start, args[0] is stop, step is 1 start = _to_tuple(start, dim=dim) stop = _to_tuple(args[0], dim=dim) num = [stop[i] - start[i] for i in range(dim)] elif len(args) == 2: # start is start, args[0] is stop, args[1] is num start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 else: raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) axis_grid = [] for i in range(dim): a, b, n = start[i], stop[i], num[i] g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] axis_grid.append(g) grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] grid = torch.stack(grid, dim=0) # [dim, W, H, D] return grid def get_nd_rotary_pos_embed( rope_dim_list, start, *args, theta=10000.0, use_real=False, theta_rescale_factor: Union[float, List[float]] = 1.0, interpolation_factor: Union[float, List[float]] = 1.0, ): """ This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. Args: rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. sum(rope_dim_list) should equal to head_dim of attention layer. start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. *args: See above. theta (float): Scaling factor for frequency computation. Defaults to 10000.0. use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers. Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real part and an imaginary part separately. theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. Returns: pos_embed (torch.Tensor): [HW, D/2] """ grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float): theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list) elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list) assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)" if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float): interpolation_factor = [interpolation_factor] * len(rope_dim_list) elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list) assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)" # use 1/ndim of dimensions to encode grid_axis embs = [] for i in range(len(rope_dim_list)): emb = get_1d_rotary_pos_embed( rope_dim_list[i], grid[i].reshape(-1), theta, use_real=use_real, theta_rescale_factor=theta_rescale_factor[i], interpolation_factor=interpolation_factor[i], ) # 2 x [WHD, rope_dim_list[i]] embs.append(emb) if use_real: cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) return cos, sin else: emb = torch.cat(embs, dim=1) # (WHD, D/2) return emb def set_timesteps_sigmas(num_inference_steps, shift, device, num_train_timesteps=1000): sigmas = torch.linspace(1, 0, num_inference_steps + 1) sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas) timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.bfloat16, device=device) return timesteps, sigmas class HunyuanScheduler(BaseScheduler): def __init__(self, args): super().__init__(args) self.infer_steps = self.args.infer_steps self.shift = 7.0 self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device("cuda")) assert len(self.timesteps) == self.infer_steps self.embedded_guidance_scale = 6.0 self.generator = [torch.Generator("cuda").manual_seed(seed) for seed in [42]] self.noise_pred = None self.prepare_latents(shape=self.args.target_shape, dtype=torch.bfloat16) self.prepare_guidance() self.prepare_rotary_pos_embedding(video_length=self.args.target_video_length, height=self.args.target_height, width=self.args.target_width) def prepare_guidance(self): self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=torch.bfloat16, device=torch.device("cuda")) * 1000.0 def step_post(self): sample = self.latents.to(torch.float32) dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index] prev_sample = sample + self.noise_pred.to(torch.float32) * dt self.latents = prev_sample def prepare_latents(self, shape, dtype): self.latents = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype) def prepare_rotary_pos_embedding(self, video_length, height, width): target_ndim = 3 ndim = 5 - 2 # 884 vae = "884-16c-hy" patch_size = [1, 2, 2] hidden_size = 3072 heads_num = 24 rope_theta = 256 rope_dim_list = [16, 56, 56] if "884" in vae: latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8] elif "888" in vae: latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8] else: latents_size = [video_length, height // 8, width // 8] if isinstance(patch_size, int): assert all(s % patch_size == 0 for s in latents_size), f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), but got {latents_size}." rope_sizes = [s // patch_size for s in latents_size] elif isinstance(patch_size, list): assert all(s % patch_size[idx] == 0 for idx, s in enumerate(latents_size)), f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), but got {latents_size}." rope_sizes = [s // patch_size[idx] for idx, s in enumerate(latents_size)] if len(rope_sizes) != target_ndim: rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis head_dim = hidden_size // heads_num rope_dim_list = rope_dim_list if rope_dim_list is None: rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer" self.freqs_cos, self.freqs_sin = get_nd_rotary_pos_embed( rope_dim_list, rope_sizes, theta=rope_theta, use_real=True, theta_rescale_factor=1, ) self.freqs_cos = self.freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda")) self.freqs_sin = self.freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))