scheduler.py 16.7 KB
Newer Older
PengGao's avatar
PengGao committed
1
from typing import List, Optional, Tuple, Union
PengGao's avatar
PengGao committed
2

helloyongyang's avatar
helloyongyang committed
3
import numpy as np
PengGao's avatar
PengGao committed
4
import torch
helloyongyang's avatar
helloyongyang committed
5
from diffusers.utils.torch_utils import randn_tensor
PengGao's avatar
PengGao committed
6

7
from lightx2v.models.schedulers.scheduler import BaseScheduler
8
from lightx2v.utils.envs import *
helloyongyang's avatar
helloyongyang committed
9
10
11
12
13
14
15
16
17
18


def _to_tuple(x, dim=2):
    if isinstance(x, int):
        return (x,) * dim
    elif len(x) == dim:
        return x
    else:
        raise ValueError(f"Expected length {dim} or int, but got {x}")

Dongz's avatar
Dongz committed
19

helloyongyang's avatar
helloyongyang committed
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    use_real: bool = False,
    theta_rescale_factor: float = 1.0,
    interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Precompute 1-D rotary position-embedding frequencies.

    (`cis` means `cos + i * sin`, where i is the imaginary unit.)

    Args:
        dim (int): Dimension of the frequency tensor; uses ``dim // 2`` bands.
        pos (int or torch.FloatTensor): Position indices [S], or an int ``S``
            meaning ``arange(S)``.
        theta (float, optional): Base of the geometric frequency progression.
            Defaults to 10000.0.
        use_real (bool, optional): If True, return real cos and sin parts
            separately; otherwise return complex numbers.
        theta_rescale_factor (float, optional): NTK-style rescale of theta to
            handle longer sequences without fine-tuning. Defaults to 1.0.
        interpolation_factor (float, optional): Multiplier applied to positions
            before the outer product (position interpolation). Defaults to 1.0.

    Returns:
        freqs_cis: complex64 frequency tensor [S, D/2], or
        (freqs_cos, freqs_sin): real/imaginary parts, each [S, D], when
        ``use_real`` is True.
    """
    if isinstance(pos, int):
        pos = torch.arange(pos).float()

    # NTK-aware theta rescale (proposed by reddit user bloc97) to extend the
    # usable sequence length without fine-tuning; no-op at the default of 1.0.
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor ** (dim / (dim - 2))

    band_exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    freqs = 1.0 / (theta**band_exponents)  # [D/2]
    angles = torch.outer(pos * interpolation_factor, freqs)  # [S, D/2]

    if not use_real:
        # Unit-magnitude phasors, one per (position, band).
        return torch.polar(torch.ones_like(angles), angles)  # complex64 [S, D/2]

    # Real form: each band is duplicated so cos/sin align with interleaved dims.
    freqs_cos = angles.cos().repeat_interleave(2, dim=1)  # [S, D]
    freqs_sin = angles.sin().repeat_interleave(2, dim=1)  # [S, D]
    return freqs_cos, freqs_sin


def get_meshgrid_nd(start, *args, dim=2):
    """Build an n-D meshgrid from range-like per-axis arguments.

    Argument forms (each of start/stop/num may be an int, broadcast to every
    axis, or an n-tuple):
        * ``get_meshgrid_nd(num)``              -> start = 0, stop = num
        * ``get_meshgrid_nd(start, stop)``      -> num = stop - start (step 1)
        * ``get_meshgrid_nd(start, stop, num)`` -> explicit target size

    Args:
        start (int or tuple): See forms above.
        *args: See forms above.
        dim (int): Dimension of the meshgrid. Defaults to 2.

    Returns:
        torch.Tensor: stacked grid of shape [dim, ...].
    """
    if len(args) == 0:
        # Only a size was given: axes run from 0 to num.
        num = _to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif len(args) == 1:
        # (start, stop) with implicit step 1.
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = [stop[i] - start[i] for i in range(dim)]
    elif len(args) == 2:
        # (start, stop, num): e.g. Left-Top 12,0 / Right-Bottom 20,32 / Size 32,124.
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = _to_tuple(args[1], dim=dim)
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # Torch equivalent of np.linspace(a, b, n, endpoint=False) for each axis.
    axis_grid = [
        torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
        for a, b, n in zip(start, stop, num)
    ]
    mesh = torch.meshgrid(*axis_grid, indexing="ij")  # dim x [W, H, D]
    return torch.stack(mesh, dim=0)  # [dim, W, H, D]

Dongz's avatar
Dongz committed
112

helloyongyang's avatar
helloyongyang committed
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def get_nd_rotary_pos_embed(
    rope_dim_list,
    start,
    *args,
    theta=10000.0,
    use_real=False,
    theta_rescale_factor: Union[float, List[float]] = 1.0,
    interpolation_factor: Union[float, List[float]] = 1.0,
):
    """n-D RoPE: precompute rotary embeddings for tokens with n-d structure.

    Each axis of the grid is encoded with its own slice of the head dimension,
    and the per-axis embeddings are concatenated along the feature dimension.

    Args:
        rope_dim_list (list of int): Feature dims per axis; ``len`` is n and
            ``sum`` should equal the attention head_dim.
        start (int | tuple | list): Forwarded to :func:`get_meshgrid_nd`
            together with ``*args`` (num / start,stop / start,stop,num forms).
        *args: See above.
        theta (float): Base for frequency computation. Defaults to 10000.0.
        use_real (bool): If True return (cos, sin) separately — useful for
            backends (e.g. TensorRT) without complex64 support.
        theta_rescale_factor (float | list): Per-axis theta rescale. A scalar
            or single-element list is broadcast to all axes.
        interpolation_factor (float | list): Per-axis position interpolation,
            broadcast the same way.

    Returns:
        pos_embed (torch.Tensor): [HW, D/2] complex tensor, or a (cos, sin)
        pair when ``use_real`` is True.
    """
    n_axes = len(rope_dim_list)
    grid = get_meshgrid_nd(start, *args, dim=n_axes)  # [3, W, H, D] / [2, W, H]

    # Broadcast scalar / single-element factors to one value per axis.
    if isinstance(theta_rescale_factor, (int, float)):
        theta_rescale_factor = [theta_rescale_factor] * n_axes
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * n_axes
    assert len(theta_rescale_factor) == n_axes, "len(theta_rescale_factor) should equal to len(rope_dim_list)"

    if isinstance(interpolation_factor, (int, float)):
        interpolation_factor = [interpolation_factor] * n_axes
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * n_axes
    assert len(interpolation_factor) == n_axes, "len(interpolation_factor) should equal to len(rope_dim_list)"

    # Use 1/ndim of the feature dimensions to encode each grid axis.
    embs = [
        get_1d_rotary_pos_embed(
            rope_dim_list[axis],
            grid[axis].reshape(-1),
            theta,
            use_real=use_real,
            theta_rescale_factor=theta_rescale_factor[axis],
            interpolation_factor=interpolation_factor[axis],
        )
        for axis in range(n_axes)
    ]  # each: 2 x [WHD, rope_dim_list[i]] (real) or [WHD, rope_dim_list[i]/2]

    if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)  # (WHD, D/2)
        sin = torch.cat([emb[1] for emb in embs], dim=1)  # (WHD, D/2)
        return cos, sin
    return torch.cat(embs, dim=1)  # (WHD, D/2)


Dongz's avatar
Dongz committed
177
def set_timesteps_sigmas(num_inference_steps, shift, device, num_train_timesteps=1000):
    """Build the shifted flow-matching sigma schedule and its timesteps.

    Sigmas run linearly from 1 to 0 over ``num_inference_steps + 1`` points,
    then are warped by the timestep-shift transform
    ``sigma' = shift * sigma / (1 + (shift - 1) * sigma)``.

    Args:
        num_inference_steps (int): Number of denoising steps.
        shift (float): Timestep-shift factor (1.0 leaves the schedule linear).
        device: Target device for the returned timesteps.
        num_train_timesteps (int): Scale from sigma space to timestep units.

    Returns:
        (timesteps, sigmas): ``timesteps`` has ``num_inference_steps`` float32
        entries on ``device``; ``sigmas`` keeps the extra trailing 0 and stays
        on the default device.
    """
    linear = torch.linspace(1, 0, num_inference_steps + 1)
    sigmas = (shift * linear) / (1 + (shift - 1) * linear)
    timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32, device=device)
    return timesteps, sigmas


helloyongyang's avatar
helloyongyang committed
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def get_1d_rotary_pos_embed_riflex(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    k: Optional[int] = None,
    L_test: Optional[int] = None,
):
    """RIFLEx variant of the 1-D rotary frequency precomputation.

    Identical to the standard RoPE table except that, when ``k`` is given,
    the intrinsic frequency at (1-based) index ``k`` is lowered so the
    extrapolated sequence stays within a single period (RIFLEx Eq. (8)).

    Args:
        dim (`int`): Dimension of the frequency tensor; must be even.
        pos (`np.ndarray` or `int`): Position indices [S] or a scalar ``S``.
        theta (`float`, *optional*, defaults to 10000.0):
            Base for frequency computation.
        use_real (`bool`, *optional*):
            If True, return cos and sin parts separately instead of complex.
        k (`int`, *optional*): index of the intrinsic frequency in RoPE to damp;
            None disables the RIFLEx modification.
        L_test (`int`, *optional*): number of frames for inference; required
            when ``k`` is set.

    Returns:
        `torch.Tensor`: complex64 frequency tensor [S, D/2], or a
        (cos, sin) pair of float tensors [S, D] when ``use_real`` is True.
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # [S]

    exponents = torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim
    freqs = 1.0 / (theta**exponents)  # [D/2]

    # === Riflex modification start ===
    # Reduce the intrinsic frequency so the extrapolated length stays within a
    # single period (Eq. (8)). The 0.9 factor is a conservative margin: a few
    # videos otherwise exhibit repetition in the tail frames, so the
    # extrapolated length is kept below 90% of one period.
    if k is not None:
        freqs[k - 1] = 0.9 * 2 * torch.pi / L_test
    # === Riflex modification end ===

    angles = torch.outer(pos, freqs)  # [S, D/2]

    if use_real:
        freqs_cos = angles.cos().repeat_interleave(2, dim=1).float()  # [S, D]
        freqs_sin = angles.sin().repeat_interleave(2, dim=1).float()  # [S, D]
        return freqs_cos, freqs_sin
    # lumina-style complex output
    return torch.polar(torch.ones_like(angles), angles)  # complex64 [S, D/2]


helloyongyang's avatar
helloyongyang committed
239
class HunyuanScheduler(BaseScheduler):
    """Flow-matching Euler scheduler for the HunyuanVideo model.

    Owns the shifted sigma/timestep schedule, the initial latents, the
    embedded guidance tensor, and the rotary position embeddings consumed by
    the transformer. Supports text-to-video ("t2v"); for any other task the
    first latent frame is pinned to the encoded conditioning image latents.
    All tensors are created on CUDA.
    """

    def __init__(self, config):
        super().__init__(config)
        # Timestep-shift factor applied to the linear sigma schedule.
        self.shift = 7.0
        self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device("cuda"))
        assert len(self.timesteps) == self.infer_steps
        # Embedded guidance scale; scaled by 1000 in prepare_guidance().
        self.embedded_guidance_scale = 6.0
        # One seeded CUDA generator per seed (a single-seed list here).
        self.generator = [torch.Generator("cuda").manual_seed(seed) for seed in [self.config.seed]]
        # Model velocity prediction; set by the inference loop before step_post().
        self.noise_pred = None

    def prepare(self, image_encoder_output):
        """Set up latents, guidance and RoPE tables for one generation run.

        Args:
            image_encoder_output: encoder output dict; non-t2v tasks read its
                "img_latents" entry (conditioning image latents).
        """
        self.image_encoder_output = image_encoder_output
        self.prepare_latents(shape=self.config.target_shape, dtype=torch.float32, image_encoder_output=image_encoder_output)
        self.prepare_guidance()
        self.prepare_rotary_pos_embedding(video_length=self.config.target_video_length, height=self.config.target_height, width=self.config.target_width)

    def prepare_guidance(self):
        """Build the embedded-guidance tensor (scale x 1000, model's units)."""
        self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=GET_DTYPE(), device=torch.device("cuda")) * 1000.0

    def step_post(self):
        """Advance latents by one Euler step: x <- x + v * (sigma_next - sigma).

        Reads self.noise_pred and self.step_index (presumably maintained by
        BaseScheduler / the inference loop — confirm against caller). For
        non-t2v tasks only frames 1..T are integrated and frame 0 is re-pinned
        to the conditioning image latents.
        """
        if self.config.task == "t2v":
            sample = self.latents.to(torch.float32)
            dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
            self.latents = sample + self.noise_pred.to(torch.float32) * dt
        else:
            # Frame 0 is the clean image latent; integrate the rest only.
            sample = self.latents[:, :, 1:, :, :].to(torch.float32)
            dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
            latents = sample + self.noise_pred[:, :, 1:, :, :].to(torch.float32) * dt
            self.latents = torch.concat([self.image_encoder_output["img_latents"], latents], dim=2)

    def prepare_latents(self, shape, dtype, image_encoder_output):
        """Initialize self.latents.

        t2v: pure Gaussian noise of `shape`. Otherwise: noise blended with the
        time-repeated image latents at t=0.999, then frame 0 replaced by the
        clean image latents.
        """
        if self.config.task == "t2v":
            self.latents = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
        else:
            # Repeat the image latent across the latent time axis
            # ((len - 1) // 4 + 1 frames, matching the 4x temporal VAE stride).
            x1 = image_encoder_output["img_latents"].repeat(1, 1, (self.config.target_video_length - 1) // 4 + 1, 1, 1)
            x0 = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
            t = torch.tensor([0.999]).to(device=torch.device("cuda"))
            # Linear blend: almost pure noise (x0) with a trace of the image (x1).
            self.latents = x0 * t + x1 * (1 - t)
            self.latents = self.latents.to(dtype=dtype)
            # Pin frame 0 to the clean image latents.
            self.latents = torch.concat([image_encoder_output["img_latents"], self.latents[:, :, 1:, :, :]], dim=2)

    def prepare_rotary_pos_embedding(self, video_length, height, width):
        """Precompute RoPE tables into self.freqs_cos / self.freqs_sin.

        t2v uses the standard n-D RoPE. Other tasks additionally apply RIFLEx
        to the temporal axis for videos longer than 192 frames, to avoid
        repetition when extrapolating past the training length.

        Args:
            video_length: number of output video frames.
            height: output height in pixels.
            width: output width in pixels.
        """
        target_ndim = 3
        ndim = 5 - 2
        # Hard-coded HunyuanVideo transformer/VAE hyperparameters.
        # 884
        vae = "884-16c-hy"
        patch_size = [1, 2, 2]
        hidden_size = 3072
        heads_num = 24
        rope_theta = 256
        rope_dim_list = [16, 56, 56]
        # "884" VAE: 4x temporal, 8x spatial compression of pixel sizes.
        if "884" in vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        elif "888" in vae:
            latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        # RoPE grid size = latent size divided by the transformer patch size.
        if isinstance(patch_size, int):
            assert all(s % patch_size == 0 for s in latents_size), f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), but got {latents_size}."
            rope_sizes = [s // patch_size for s in latents_size]
        elif isinstance(patch_size, list):
            assert all(s % patch_size[idx] == 0 for idx, s in enumerate(latents_size)), f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), but got {latents_size}."
            rope_sizes = [s // patch_size[idx] for idx, s in enumerate(latents_size)]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis

        if self.config.task == "t2v":
            head_dim = hidden_size // heads_num
            rope_dim_list = rope_dim_list  # no-op; kept as-is from upstream code
            if rope_dim_list is None:
                rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
            assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
            self.freqs_cos, self.freqs_sin = get_nd_rotary_pos_embed(
                rope_dim_list,
                rope_sizes,
                theta=rope_theta,
                use_real=True,
                theta_rescale_factor=1,
            )
            self.freqs_cos = self.freqs_cos.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
            self.freqs_sin = self.freqs_sin.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
        else:
            L_test = rope_sizes[0]  # Latent frames
            L_train = 25  # Training length from HunyuanVideo
            actual_num_frames = video_length  # Use input video_length directly

            head_dim = hidden_size // heads_num
            rope_dim_list = rope_dim_list or [head_dim // target_ndim for _ in range(target_ndim)]
            assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) must equal head_dim"

            if actual_num_frames > 192:
                # RIFLEx intrinsic-frequency index, clamped to [4, 8].
                k = 2 + ((actual_num_frames + 3) // (4 * L_train))
                k = max(4, min(8, k))

                # Compute positional grids for RIFLEx
                axes_grids = [torch.arange(size, device=torch.device("cuda"), dtype=torch.float32) for size in rope_sizes]
                grid = torch.meshgrid(*axes_grids, indexing="ij")
                grid = torch.stack(grid, dim=0)  # [3, t, h, w]
                pos = grid.reshape(3, -1).t()  # [t * h * w, 3]

                # Apply RIFLEx to temporal dimension
                freqs = []
                for i in range(3):
                    if i == 0:  # Temporal with RIFLEx
                        freqs_cos, freqs_sin = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], pos[:, i], theta=rope_theta, use_real=True, k=k, L_test=L_test)
                    else:  # Spatial with default RoPE
                        freqs_cos, freqs_sin = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], pos[:, i], theta=rope_theta, use_real=True, k=None, L_test=None)
                    freqs.append((freqs_cos, freqs_sin))

                freqs_cos = torch.cat([f[0] for f in freqs], dim=1)
                freqs_sin = torch.cat([f[1] for f in freqs], dim=1)
            else:
                # 20250316 pftq: Original code for <= 192 frames
                freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
                    rope_dim_list,
                    rope_sizes,
                    theta=rope_theta,
                    use_real=True,
                    theta_rescale_factor=1,
                )

            self.freqs_cos = freqs_cos.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
            self.freqs_sin = freqs_sin.to(dtype=GET_DTYPE(), device=torch.device("cuda"))