import torch
import torch.distributed as dist

from lightx2v.utils.envs import *


def masks_like(tensor, zero=False, generator=None, p=0.2, prev_length=1):
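    """Build two lists of masks shaped like the tensors in `tensor`.

    Both outputs start as all-ones. If `zero` is set, the first `prev_length`
    frames (dim 1) are overwritten: without a generator they are zeroed in
    both outputs; with a generator, out2 is zeroed and out1 is filled with a
    random log-normal value with probability `p`.
    """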
    assert isinstance(tensor, list)
    out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
    out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]

    if prev_length == 0:
        return out1, out2

    if zero:
        if generator is not None:
            for u, v in zip(out1, out2):
                random_num = torch.rand(1, generator=generator, device=generator.device).item()
                if random_num < p:
                    u[:, :prev_length] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, :prev_length]).exp()
                    v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])
                # otherwise leave the prefix frames of u and v unchanged
        else:
            for u, v in zip(out1, out2):
                u[:, :prev_length] = torch.zeros_like(u[:, :prev_length])
                v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])

    return out1, out2


def compute_freqs(c, grid_sizes, freqs):
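    """Assemble 3D RoPE frequencies for an (f, h, w) token grid.

    The complex frequency table is split channel-wise into temporal, height,
    and width parts; each part is broadcast over the full grid, and the
    result is flattened to shape (f * h * w, 1, -1).
    """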
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0]
    seq_len = f * h * w
    freqs_i = torch.cat(
        [
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    ).reshape(seq_len, 1, -1)

    return freqs_i


def compute_freqs_audio(c, grid_sizes, freqs):
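    """Variant of compute_freqs for r2v: reserves one extra frame of
    positions for the reference tokens and zeroes their leading (temporal)
    frequency channels.
    """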
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0]
    valid_token_length = f * h * w
    f = f + 1  # for r2v: add one extra frame slot for the reference tokens
    seq_len = f * h * w
    freqs_i = torch.cat(
        [
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),  # temporal (frame) encoding
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),  # spatial (height) encoding
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),  # spatial (width) encoding
        ],
        dim=-1,
    ).reshape(seq_len, 1, -1)

    freqs_i[valid_token_length:, :, :f] = 0  # for r2v: zero the temporal component for the reference embeddings

    return freqs_i


def compute_freqs_dist(s, c, grid_sizes, freqs, seq_p_group):
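    """Sequence-parallel version of compute_freqs: builds the full frequency
    table, pads it to s * world_size rows, and returns this rank's s-row
    slice.
    """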
    world_size = dist.get_world_size(seq_p_group)
    cur_rank = dist.get_rank(seq_p_group)
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0]
    seq_len = f * h * w
    freqs_i = torch.cat(
        [
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    ).reshape(seq_len, 1, -1)

    freqs_i = pad_freqs(freqs_i, s * world_size)
    s_per_rank = s
    freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
    return freqs_i_rank


def compute_freqs_audio_dist(s, c, grid_sizes, freqs, seq_p_group):
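    """Sequence-parallel version of compute_freqs_audio: pads the full table
    to s * world_size rows and returns this rank's s-row slice.
    """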
    world_size = dist.get_world_size(seq_p_group)
    cur_rank = dist.get_rank(seq_p_group)
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0]
    valid_token_length = f * h * w
    f = f + 1
    seq_len = f * h * w
    freqs_i = torch.cat(
        [
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    ).reshape(seq_len, 1, -1)

    freqs_i[valid_token_length:, :, :f] = 0

    freqs_i = pad_freqs(freqs_i, s * world_size)
    s_per_rank = s
    freqs_i_rank = freqs_i[(cur_rank * s_per_rank) : ((cur_rank + 1) * s_per_rank), :, :]
    return freqs_i_rank


def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
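    """compute_freqs with a temporal offset: frame positions start at
    `start_frame`, as needed when generating a video chunk by chunk
    (CausVid-style autoregression).
    """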
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0]
    seq_len = f * h * w
    freqs_i = torch.cat(
        [
            freqs[0][start_frame : start_frame + f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    ).reshape(seq_len, 1, -1)

    return freqs_i


def pad_freqs(original_tensor, target_len):
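    """Pad `original_tensor` along dim 0 with ones up to `target_len` rows."""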
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor


def apply_rotary_emb(x, freqs_i):
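    """Apply complex rotary embeddings to the first freqs_i.size(0) tokens of
    `x`; any trailing tokens (e.g. padding) are passed through unchanged.
    """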
    n = x.size(1)
    seq_len = freqs_i.size(0)

    x_i = torch.view_as_complex(x[:seq_len].to(torch.float64).reshape(seq_len, n, -1, 2))
    # Apply rotary embedding
    x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
    x_i = torch.cat([x_i, x[seq_len:]])
    return x_i.to(GET_DTYPE())


def apply_rotary_emb_chunk(x, freqs_i, chunk_size, remaining_chunk_size=100):
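    """Memory-lean variant of apply_rotary_emb: rotates the first
    freqs_i.size(0) tokens in chunks of `chunk_size` (in float32 rather than
    float64), then appends the remaining non-rotated tokens in chunks of
    `remaining_chunk_size`, freeing CUDA cache along the way.
    """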
    n = x.size(1)
    seq_len = freqs_i.size(0)

    output_chunks = []
    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        x_chunk = x[start:end]
        freqs_chunk = freqs_i[start:end]

        x_chunk_complex = torch.view_as_complex(x_chunk.to(torch.float32).reshape(end - start, n, -1, 2))
        x_chunk_embedded = torch.view_as_real(x_chunk_complex * freqs_chunk).flatten(2).to(GET_DTYPE())
        output_chunks.append(x_chunk_embedded)
        del x_chunk_complex, x_chunk_embedded
        torch.cuda.empty_cache()

    result = output_chunks  # no need to copy the chunks; reuse the list directly

    for start in range(seq_len, x.size(0), remaining_chunk_size):
        end = min(start + remaining_chunk_size, x.size(0))
        result.append(x[start:end])

    x_i = torch.cat(result, dim=0)
    del result
    torch.cuda.empty_cache()

    return x_i.to(GET_DTYPE())


def rope_params(max_seq_len, dim, theta=10000):
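    """Return the complex RoPE frequency table exp(i * pos * theta^(-2k/dim))
    with shape (max_seq_len, dim // 2).
    """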
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
    )
    freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs


def sinusoidal_embedding_1d(dim, position):
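    """Classic 1D sinusoidal position embedding: concat(cos, sin) over
    `dim // 2` log-spaced frequencies, computed in float64 and cast via
    GET_SENSITIVE_DTYPE().
    """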
    # preprocess
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.float64)

    # calculation
    sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    x = x.to(GET_SENSITIVE_DTYPE())
    return x


def guidance_scale_embedding(w, embedding_dim=256, cfg_range=(1.0, 6.0), target_range=1000.0, dtype=torch.float32):
    """
    Args:
    timesteps: torch.Tensor: generate embedding vectors at these timesteps
    embedding_dim: int: dimension of the embeddings to generate
    dtype: data type of the generated embeddings

    Returns:
    embedding vectors with shape `(len(timesteps), embedding_dim)`
    """
    assert len(w.shape) == 1
    cfg_min, cfg_max = cfg_range
    w = torch.round(w)
    w = torch.clamp(w, min=cfg_min, max=cfg_max)
    w = (w - cfg_min) / (cfg_max - cfg_min)  # [0, 1]
    w = w * target_range
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype).to(w.device) * -emb).to(w.device)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))  # (0, 1) is a plain tuple; it has no .to()
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
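

# Minimal self-check sketch (illustrative only, not part of the library API):
# builds a small complex RoPE table with rope_params and expands it over a
# 4x3x2 token grid with compute_freqs. The sizes below are arbitrary.
if __name__ == "__main__":
    c = 24  # rotary channel count; the table must have exactly c complex columns
    table = rope_params(max_seq_len=32, dim=2 * c)  # -> complex tensor of shape (32, 24)
    grid_sizes = [(4, 3, 2)]  # (frames, height, width) in tokens
    freqs_i = compute_freqs(c, grid_sizes, table)
    assert freqs_i.shape == (4 * 3 * 2, 1, c)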