from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers.utils.torch_utils import randn_tensor

from lightx2v.models.schedulers.scheduler import BaseScheduler


def _to_tuple(x, dim=2):
    if isinstance(x, int):
        return (x,) * dim
    elif len(x) == dim:
        return x
    else:
        raise ValueError(f"Expected length {dim} or int, but got {x}")

def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    use_real: bool = False,
    theta_rescale_factor: float = 1.0,
    interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Precompute 1-D rotary position-embedding frequencies (``cis``).

    Args:
        dim: Embedding dimension (per-axis head-dim share).
        pos: Position indices ``[S]``, or an int ``S`` meaning ``arange(S)``.
        theta: Base for the geometric frequency progression.
        use_real: When True return ``(cos, sin)`` tensors of shape ``[S, D]``;
            otherwise return a complex64 tensor of shape ``[S, D/2]``.
        theta_rescale_factor: NTK-style theta rescale for longer contexts.
        interpolation_factor: Multiplies positions before the outer product.

    Returns:
        ``(freqs_cos, freqs_sin)`` if ``use_real`` else complex ``freqs_cis``.
    """
    positions = torch.arange(pos).float() if isinstance(pos, int) else pos

    # NTK-aware rescale (reddit user bloc97): stretch theta so rotary embeddings
    # cover longer sequences without fine-tuning.
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor ** (dim / (dim - 2))

    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
    angles = torch.outer(positions * interpolation_factor, inv_freq)  # [S, D/2]

    if use_real:
        # Each frequency column is duplicated so cos/sin align with interleaved pairs.
        return angles.cos().repeat_interleave(2, dim=1), angles.sin().repeat_interleave(2, dim=1)
    # complex64 [S, D/2]
    return torch.polar(torch.ones_like(angles), angles)


def get_meshgrid_nd(start, *args, dim=2):
    """Build an n-D meshgrid as one stacked tensor of shape ``[dim, ...]``.

    Call patterns (each value may be an int or a ``dim``-tuple):
      * ``get_meshgrid_nd(num)`` — grid over ``[0, num)`` with step 1.
      * ``get_meshgrid_nd(start, stop)`` — grid over ``[start, stop)`` with step 1.
      * ``get_meshgrid_nd(start, stop, num)`` — ``num`` samples per axis over
        ``[start, stop)`` (endpoint excluded, like ``np.linspace(..., endpoint=False)``).
    """

    # Local copy of _to_tuple so the grid spec handling is self-contained.
    def as_tuple(v):
        if isinstance(v, int):
            return (v,) * dim
        if len(v) == dim:
            return v
        raise ValueError(f"Expected length {dim} or int, but got {v}")

    if not args:
        # Only a size was given: axes run 0..size-1.
        counts = as_tuple(start)
        lows = (0,) * dim
        highs = counts
    elif len(args) == 1:
        # (start, stop) with implicit step 1.
        lows = as_tuple(start)
        highs = as_tuple(args[0])
        counts = [highs[i] - lows[i] for i in range(dim)]
    elif len(args) == 2:
        # (start, stop, num): explicit sample count per axis.
        lows = as_tuple(start)
        highs = as_tuple(args[0])
        counts = as_tuple(args[1])
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # linspace over n+1 points then drop the endpoint == endpoint=False sampling.
    axes = [torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] for a, b, n in zip(lows, highs, counts)]
    return torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=0)

def get_nd_rotary_pos_embed(
    rope_dim_list,
    start,
    *args,
    theta=10000.0,
    use_real=False,
    theta_rescale_factor: Union[float, List[float]] = 1.0,
    interpolation_factor: Union[float, List[float]] = 1.0,
):
    """n-D RoPE: per-axis 1-D rotary embeddings concatenated along the feature dim.

    Args:
        rope_dim_list: Per-axis embedding dims; their sum should equal the
            attention head dim. ``len(rope_dim_list)`` fixes ``n``.
        start, *args: Grid specification, forwarded to :func:`get_meshgrid_nd`
            (num / start+stop / start+stop+num forms).
        theta: Frequency base shared by all axes.
        use_real: Return ``(cos, sin)`` instead of complex values (useful for
            backends such as TensorRT without complex64 support).
        theta_rescale_factor: Scalar or per-axis list of NTK rescale factors.
        interpolation_factor: Scalar or per-axis list of position scalings.

    Returns:
        ``(cos, sin)`` of shape ``[prod(grid), sum(dims)]`` if ``use_real``,
        else a complex tensor ``[prod(grid), sum(dims)/2]``.
    """

    n_axes = len(rope_dim_list)
    grid = get_meshgrid_nd(start, *args, dim=n_axes)  # [n, W, H, D] / [n, W, H]

    # Broadcast a scalar (or singleton list) factor to one value per axis.
    def per_axis(factor, msg):
        if isinstance(factor, (int, float)):
            factor = [factor] * n_axes
        elif isinstance(factor, list) and len(factor) == 1:
            factor = [factor[0]] * n_axes
        assert len(factor) == n_axes, msg
        return factor

    theta_rescale_factor = per_axis(theta_rescale_factor, "len(theta_rescale_factor) should equal to len(rope_dim_list)")
    interpolation_factor = per_axis(interpolation_factor, "len(interpolation_factor) should equal to len(rope_dim_list)")

    # Each axis encodes its own slice of the head dim with flattened grid positions.
    embs = [
        get_1d_rotary_pos_embed(
            axis_dim,
            grid[i].reshape(-1),
            theta,
            use_real=use_real,
            theta_rescale_factor=theta_rescale_factor[i],
            interpolation_factor=interpolation_factor[i],
        )
        for i, axis_dim in enumerate(rope_dim_list)
    ]

    if use_real:
        return (
            torch.cat([e[0] for e in embs], dim=1),  # (WHD, D)
            torch.cat([e[1] for e in embs], dim=1),  # (WHD, D)
        )
    return torch.cat(embs, dim=1)  # (WHD, D/2)


def set_timesteps_sigmas(num_inference_steps, shift, device, num_train_timesteps=1000):
    """Build the shifted flow-matching sigma schedule and matching timesteps.

    Sigmas descend linearly from 1 to 0 over ``num_inference_steps + 1`` points,
    then are warped by ``shift`` via ``s' = shift*s / (1 + (shift-1)*s)``.
    Timesteps are the first ``num_inference_steps`` sigmas scaled to the
    training range and moved to ``device``; sigmas stay on CPU.
    """
    base = torch.linspace(1, 0, num_inference_steps + 1)
    sigmas = shift * base / (1 + (shift - 1) * base)
    timesteps = (num_train_timesteps * sigmas[:-1]).to(dtype=torch.float32, device=device)
    return timesteps, sigmas


def get_1d_rotary_pos_embed_riflex(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    k: Optional[int] = None,
    L_test: Optional[int] = None,
):
    """RIFLEx variant of the 1-D rotary embedding precomputation.

    Identical to plain RoPE except that, when ``k`` is given, the ``k``-th
    intrinsic frequency is clamped so the extrapolated sequence stays within
    a single period (RIFLEx Eq. (8)).

    Args:
        dim: Embedding dimension; must be even.
        pos: Position indices ``[S]`` (ndarray or tensor), or an int ``S``.
        theta: Base for the geometric frequency progression.
        use_real: Return ``(cos, sin)`` of shape ``[S, D]`` instead of complex.
        k: 1-based index of the intrinsic RoPE frequency to clamp, or None.
        L_test: Number of inference frames used for the clamp, or None.

    Returns:
        ``(cos, sin)`` if ``use_real``, else a complex64 tensor ``[S, D/2]``.
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # [S]

    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim))  # [D/2]

    # === RIFLEx modification ===
    # Shrink the intrinsic frequency so extrapolation stays inside one period;
    # the 0.9 factor conservatively keeps it below 90% of a full cycle.
    if k is not None:
        inv_freq[k - 1] = 0.9 * 2 * torch.pi / L_test

    angles = torch.outer(pos, inv_freq)  # [S, D/2]

    if use_real:
        return (
            angles.cos().repeat_interleave(2, dim=1).float(),  # [S, D]
            angles.sin().repeat_interleave(2, dim=1).float(),  # [S, D]
        )
    # complex64 [S, D/2] (lumina-style)
    return torch.polar(torch.ones_like(angles), angles)


class HunyuanScheduler(BaseScheduler):
    """Flow-matching (Euler) scheduler for HunyuanVideo t2v / i2v inference.

    Responsibilities:
      * build the shifted sigma / timestep schedule,
      * draw initial latents (blended with the image latent for i2v),
      * precompute the embedded-guidance tensor and rotary position embeddings,
      * advance latents by one Euler step in ``step_post``.

    All tensors are created directly on CUDA.
    """

    def __init__(self, config):
        super().__init__(config)
        device = torch.device("cuda")
        # Flow shift applied to the sigma schedule (HunyuanVideo default).
        self.shift = 7.0
        # NOTE(review): self.infer_steps and self.step_index are presumably set
        # by BaseScheduler.__init__ from config — confirm.
        self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=device)
        assert len(self.timesteps) == self.infer_steps
        # Distilled-guidance scale; multiplied by 1000 in prepare_guidance.
        self.embedded_guidance_scale = 6.0
        self.generator = [torch.Generator("cuda").manual_seed(seed) for seed in [self.config.seed]]
        self.noise_pred = None

    def prepare(self, image_encoder_output):
        """Set up latents, guidance and RoPE caches for one inference run."""
        self.image_encoder_output = image_encoder_output
        # NOTE(review): latents are always drawn in fp16 regardless of model dtype — confirm intended.
        self.prepare_latents(shape=self.config.target_shape, dtype=torch.float16, image_encoder_output=image_encoder_output)
        self.prepare_guidance()
        self.prepare_rotary_pos_embedding(video_length=self.config.target_video_length, height=self.config.target_height, width=self.config.target_width)

    def prepare_guidance(self):
        """Embedded guidance conditioning; the model expects scale * 1000."""
        self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=torch.bfloat16, device=torch.device("cuda")) * 1000.0

    def step_post(self):
        """One Euler update: x <- x + v * (sigma_{i+1} - sigma_i).

        For i2v, frame 0 is the clean image latent: only frames 1..T are
        integrated and the image latent is re-injected afterwards.
        """
        # dt is identical for both branches; compute it once.
        dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
        if self.config.task == "t2v":
            sample = self.latents.to(torch.float32)
            self.latents = sample + self.noise_pred.to(torch.float32) * dt
        else:
            sample = self.latents[:, :, 1:, :, :].to(torch.float32)
            latents = sample + self.noise_pred[:, :, 1:, :, :].to(torch.float32) * dt
            self.latents = torch.concat([self.image_encoder_output["img_latents"], latents], dim=2)

    def prepare_latents(self, shape, dtype, image_encoder_output):
        """Draw initial latents.

        t2v: pure Gaussian noise of ``shape``.
        i2v: noise blended with the repeated image latent at t=0.999, then the
        first latent frame is replaced by the clean image latent.
        """
        device = torch.device("cuda")
        if self.config.task == "t2v":
            self.latents = randn_tensor(shape, generator=self.generator, device=device, dtype=dtype)
        else:
            # Repeat the single image latent across all latent frames
            # ((T - 1) // 4 + 1 frames for the 884 VAE).
            x1 = image_encoder_output["img_latents"].repeat(1, 1, (self.config.target_video_length - 1) // 4 + 1, 1, 1)
            x0 = randn_tensor(shape, generator=self.generator, device=device, dtype=dtype)
            t = torch.tensor([0.999]).to(device=device)
            self.latents = (x0 * t + x1 * (1 - t)).to(dtype=dtype)
            self.latents = torch.concat([image_encoder_output["img_latents"], self.latents[:, :, 1:, :, :]], dim=2)

    def prepare_rotary_pos_embedding(self, video_length, height, width):
        """Precompute ``self.freqs_cos`` / ``self.freqs_sin`` for the token grid.

        Uses RIFLEx temporal extrapolation for i2v runs longer than 192 frames;
        otherwise plain n-D RoPE.
        """
        target_ndim = 3
        ndim = 5 - 2  # latent tensor is 5-D; the last 3 dims carry (t, h, w)
        # Fixed HunyuanVideo hyper-parameters (884 VAE, 3072 hidden, 24 heads).
        vae = "884-16c-hy"
        patch_size = [1, 2, 2]
        hidden_size = 3072
        heads_num = 24
        rope_theta = 256
        rope_dim_list = [16, 56, 56]

        # Latent grid size: 884/888 VAEs compress time by 4x/8x, space by 8x.
        if "884" in vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        elif "888" in vae:
            latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        if isinstance(patch_size, int):
            assert all(s % patch_size == 0 for s in latents_size), f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), but got {latents_size}."
            rope_sizes = [s // patch_size for s in latents_size]
        elif isinstance(patch_size, list):
            assert all(s % patch_size[idx] == 0 for idx, s in enumerate(latents_size)), f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), but got {latents_size}."
            rope_sizes = [s // patch_size[idx] for idx, s in enumerate(latents_size)]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # pad time axis

        # Shared between both branches (previously duplicated, with a dead
        # `rope_dim_list = rope_dim_list` self-assignment in the t2v path).
        head_dim = hidden_size // heads_num
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"

        device = torch.device("cuda")
        if self.config.task != "t2v" and video_length > 192:
            # RIFLEx: extend the temporal axis beyond the trained length.
            L_test = rope_sizes[0]  # latent frames at inference
            L_train = 25  # training latent-frame count for HunyuanVideo
            k = 2 + ((video_length + 3) // (4 * L_train))
            k = max(4, min(8, k))  # clamp the intrinsic-frequency index to a safe band

            # Flattened (t, h, w) position grid: [t*h*w, 3].
            axes_grids = [torch.arange(size, device=device, dtype=torch.float32) for size in rope_sizes]
            grid = torch.stack(torch.meshgrid(*axes_grids, indexing="ij"), dim=0)  # [3, t, h, w]
            pos = grid.reshape(3, -1).t()

            # Axis 0 (time) gets the RIFLEx frequency clamp; spatial axes use plain RoPE.
            freqs = []
            for i in range(3):
                riflex_k = k if i == 0 else None
                riflex_L = L_test if i == 0 else None
                freqs.append(get_1d_rotary_pos_embed_riflex(rope_dim_list[i], pos[:, i], theta=rope_theta, use_real=True, k=riflex_k, L_test=riflex_L))
            freqs_cos = torch.cat([f[0] for f in freqs], dim=1)
            freqs_sin = torch.cat([f[1] for f in freqs], dim=1)
        else:
            # t2v always, and short i2v (<= 192 frames): standard n-D RoPE.
            freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
                rope_dim_list,
                rope_sizes,
                theta=rope_theta,
                use_real=True,
                theta_rescale_factor=1,
            )

        self.freqs_cos = freqs_cos.to(dtype=torch.bfloat16, device=device)
        self.freqs_sin = freqs_sin.to(dtype=torch.bfloat16, device=device)