scheduler.py 11 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from diffusers.utils.torch_utils import randn_tensor
from typing import Union, Tuple, List
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from lightx2v.text2v.models.schedulers.scheduler import BaseScheduler


def _to_tuple(x, dim=2):
    if isinstance(x, int):
        return (x,) * dim
    elif len(x) == dim:
        return x
    else:
        raise ValueError(f"Expected length {dim} or int, but got {x}")

Dongz's avatar
Dongz committed
16

helloyongyang's avatar
helloyongyang committed
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    use_real: bool = False,
    theta_rescale_factor: float = 1.0,
    interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Precompute the 1-D rotary position embedding (RoPE) table for the given positions.

    The rotation angle for position ``p`` and frequency index ``k`` is
    ``p * interpolation_factor / theta ** (2k / dim)``. The result is returned either as
    complex exponentials (``cis = cos + i * sin``) or as separate cos / sin tensors.

    Args:
        dim (int): Dimension of the embedding; ``dim // 2`` frequencies are generated.
        pos (int or torch.FloatTensor): Position indices, shape [S]. An int ``n`` is expanded
            to ``torch.arange(n).float()``.
        theta (float, optional): Base of the frequency progression. Defaults to 10000.0.
        use_real (bool, optional): If True, return cos and sin parts separately.
            Otherwise, return complex numbers.
        theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
            Note: a value other than 1.0 together with ``dim == 2`` divides by zero.
        interpolation_factor (float, optional): Multiplier applied to ``pos`` before the
            angles are computed. Defaults to 1.0.

    Returns:
        freqs_cis: complex tensor of unit-magnitude exponentials, [S, D/2] (when ``use_real`` is False).
        freqs_cos, freqs_sin: real tensors, each [S, D], where every frequency is repeated
            twice along the last axis (when ``use_real`` is True).
    """
    if isinstance(pos, int):
        pos = torch.arange(pos).float()

    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor ** (dim / (dim - 2))

    # Build the frequencies on the same device as `pos` so torch.outer never mixes devices.
    exponents = torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim
    freqs = 1.0 / (theta**exponents)  # [D/2]
    freqs = torch.outer(pos * interpolation_factor, freqs)  # [S, D/2]

    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
        return freqs_cos, freqs_sin

    # Unit magnitude, angle = freqs  ->  cos(freqs) + i * sin(freqs)
    return torch.polar(torch.ones_like(freqs), freqs)  # [S, D/2]


def get_meshgrid_nd(start, *args, dim=2):
    """
    Build an n-D coordinate grid from start / stop / num specifications.

    Accepted call forms (each value is an int or a length-``dim`` sequence):
        get_meshgrid_nd(num): start is 0 on every axis, stop is ``num``.
        get_meshgrid_nd(start, stop): ``num`` is ``stop - start`` per axis (step 1).
        get_meshgrid_nd(start, stop, num): all three given explicitly.

    Along each axis the samples are the first ``num`` points of
    ``linspace(start, stop, num + 1)``, i.e. the stop value itself is excluded.

    Args:
        start (int or sequence): See the call forms above.
        *args: Optional ``stop`` and ``num``; see the call forms above.
        dim (int): Number of grid axes. Defaults to 2.

    Returns:
        grid (torch.Tensor): float32 tensor of shape [dim, num[0], ..., num[dim - 1]],
            stacked in the axis order of the inputs ("ij" indexing).

    Raises:
        ValueError: If more than two extra positional arguments are passed.
    """
    extra = len(args)
    if extra == 0:
        # the single argument is the grid size
        num = _to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif extra == 1:
        # (start, stop) with unit step
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = [hi - lo for lo, hi in zip(start, stop)]
    elif extra == 2:
        # (start, stop, num)
        start = _to_tuple(start, dim=dim)  # Left-Top       eg: 12,0
        stop = _to_tuple(args[0], dim=dim)  # Right-Bottom   eg: 20,32
        num = _to_tuple(args[1], dim=dim)  # Target Size    eg: 32,124
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # PyTorch equivalent of np.linspace(lo, hi, count, endpoint=False) per axis
    axis_grid = [
        torch.linspace(lo, hi, count + 1, dtype=torch.float32)[:count]
        for lo, hi, count in zip(start, stop, num)
    ]
    mesh = torch.meshgrid(*axis_grid, indexing="ij")  # dim x [W, H, D]
    return torch.stack(mesh, dim=0)  # [dim, W, H, D]

Dongz's avatar
Dongz committed
109

helloyongyang's avatar
helloyongyang committed
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def _broadcast_rope_factor(value, n_axes, name):
    """Expand a scalar or 1-element list/tuple to one factor per rope axis and check its length."""
    if isinstance(value, (int, float)):
        value = [value] * n_axes
    elif isinstance(value, (list, tuple)) and len(value) == 1:
        value = [value[0]] * n_axes
    assert len(value) == n_axes, f"len({name}) should equal to len(rope_dim_list)"
    return value


def get_nd_rotary_pos_embed(
    rope_dim_list,
    start,
    *args,
    theta=10000.0,
    use_real=False,
    theta_rescale_factor: Union[float, List[float]] = 1.0,
    interpolation_factor: Union[float, List[float]] = 1.0,
):
    """
    This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.

    A meshgrid of positions is built from ``start`` / ``*args``; each grid axis is flattened and
    encoded with its own 1-D rotary embedding, and the per-axis results are concatenated along
    the feature dimension.

    Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
            sum(rope_dim_list) should equal to head_dim of attention layer.
        start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
            args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
        *args: See above.
        theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
            Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
            part and an imaginary part separately.
        theta_rescale_factor (float or list of float): Rescale factor for theta, either one value for all
            axes or one value per axis. Defaults to 1.0.
        interpolation_factor (float or list of float): Position multiplier, either one value for all
            axes or one value per axis. Defaults to 1.0.

    Returns:
        If ``use_real`` is True: ``(cos, sin)``, each a tensor of shape [N, sum(rope_dim_list)],
        where N is the number of grid points.
        Otherwise: a single complex tensor of shape [N, sum(rope_dim_list) / 2].
    """
    n_axes = len(rope_dim_list)
    grid = get_meshgrid_nd(start, *args, dim=n_axes)  # [3, W, H, D] / [2, W, H]

    theta_rescale_factor = _broadcast_rope_factor(theta_rescale_factor, n_axes, "theta_rescale_factor")
    interpolation_factor = _broadcast_rope_factor(interpolation_factor, n_axes, "interpolation_factor")

    # each grid axis gets its own slice of the head dimension
    embs = [
        get_1d_rotary_pos_embed(
            axis_dim,
            axis_grid.reshape(-1),
            theta,
            use_real=use_real,
            theta_rescale_factor=rescale,
            interpolation_factor=interp,
        )  # 2 x [WHD, axis_dim] when use_real, else [WHD, axis_dim / 2]
        for axis_dim, axis_grid, rescale, interp in zip(rope_dim_list, grid, theta_rescale_factor, interpolation_factor)
    ]

    if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)  # (WHD, D)
        sin = torch.cat([emb[1] for emb in embs], dim=1)  # (WHD, D)
        return cos, sin

    return torch.cat(embs, dim=1)  # (WHD, D/2)


Dongz's avatar
Dongz committed
174
def set_timesteps_sigmas(num_inference_steps, shift, device, num_train_timesteps=1000):
    """
    Build the shifted sigma schedule and the matching timesteps.

    Sigmas start as ``num_inference_steps + 1`` evenly spaced values from 1 down to 0 and are
    then warped by ``shift * s / (1 + (shift - 1) * s)``. Timesteps are every sigma except the
    final one, scaled by ``num_train_timesteps``.

    Args:
        num_inference_steps (int): Number of denoising steps.
        shift (float): Shift applied to the schedule; 1 leaves the linear schedule unchanged.
        device: Device the returned timesteps are moved to.
        num_train_timesteps (int): Scale applied to sigmas to obtain timesteps. Defaults to 1000.

    Returns:
        timesteps (torch.Tensor): bfloat16 tensor of shape [num_inference_steps] on ``device``.
        sigmas (torch.Tensor): tensor of shape [num_inference_steps + 1]; it is NOT moved to ``device``.
    """
    sigmas = torch.linspace(1, 0, num_inference_steps + 1)
    sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas)
    # the final sigma (0) has no timestep of its own; it is only the end point of the last step
    timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.bfloat16, device=device)
    return timesteps, sigmas


class HunyuanScheduler(BaseScheduler):
    """
    Scheduler that prepares the sigma schedule, initial latents, guidance tensor and rotary
    position embeddings at construction time, and advances the latents in ``step_post``.

    All tensors are created on the "cuda" device.
    """

    def __init__(self, args):
        """
        Args:
            args: Configuration object. This class reads ``infer_steps``, ``target_shape``,
                ``target_video_length``, ``target_height`` and ``target_width`` from ``self.args``
                (presumably stored by ``BaseScheduler.__init__`` — confirm against the base class).
        """
        super().__init__(args)
        self.infer_steps = self.args.infer_steps
        self.shift = 7.0
        self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device("cuda"))
        assert len(self.timesteps) == self.infer_steps
        self.embedded_guidance_scale = 6.0
        # fixed seed so the initial latents are reproducible
        self.generator = [torch.Generator("cuda").manual_seed(seed) for seed in [42]]
        self.noise_pred = None
        self.prepare_latents(shape=self.args.target_shape, dtype=torch.bfloat16)
        self.prepare_guidance()
        self.prepare_rotary_pos_embedding(video_length=self.args.target_video_length, height=self.args.target_height, width=self.args.target_width)

    def prepare_guidance(self):
        """Store the embedded guidance scale, multiplied by 1000, as a 1-element bfloat16 CUDA tensor."""
        self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=torch.bfloat16, device=torch.device("cuda")) * 1000.0

    def step_post(self):
        """
        Advance the latents by one step: ``latents += noise_pred * (sigma[i + 1] - sigma[i])``.

        The update is computed in float32 and the result is stored in float32.
        ``self.step_index`` and ``self.noise_pred`` are not set by this method; they must be
        provided before it is called.
        """
        sample = self.latents.to(torch.float32)
        dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
        self.latents = sample + self.noise_pred.to(torch.float32) * dt

    def prepare_latents(self, shape, dtype):
        """Draw the initial latents of the given shape and dtype on CUDA using ``self.generator``."""
        self.latents = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)

    def prepare_rotary_pos_embedding(self, video_length, height, width):
        """
        Compute the 3-D rotary position embedding tables and store them as
        ``self.freqs_cos`` / ``self.freqs_sin`` (bfloat16, on CUDA).

        Args:
            video_length (int): Number of video frames.
            height (int): Frame height in pixels.
            width (int): Frame width in pixels.
        """
        # Model configuration is hard-coded here.
        target_ndim = 3
        ndim = 5 - 2  # presumably the trailing (T, H, W) axes of a 5-D latent — TODO confirm
        # 884
        vae = "884-16c-hy"
        patch_size = [1, 2, 2]
        hidden_size = 3072
        heads_num = 24
        rope_theta = 256
        rope_dim_list = [16, 56, 56]

        # Latent grid size after VAE compression of (time, height, width).
        if "884" in vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        elif "888" in vae:
            latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        # A scalar patch size applies to every axis; normalising it gives a single code path.
        if isinstance(patch_size, int):
            patch_size = [patch_size] * len(latents_size)
        assert all(s % p == 0 for s, p in zip(latents_size, patch_size)), f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), but got {latents_size}."
        rope_sizes = [s // p for s, p in zip(latents_size, patch_size)]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis

        head_dim = hidden_size // heads_num
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"

        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list,
            rope_sizes,
            theta=rope_theta,
            use_real=True,
            theta_rescale_factor=1,
        )
        device = torch.device("cuda")
        self.freqs_cos = freqs_cos.to(dtype=torch.bfloat16, device=device)
        self.freqs_sin = freqs_sin.to(dtype=torch.bfloat16, device=device)