utils.py 5.39 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import torch
Xinchi Huang's avatar
Xinchi Huang committed
2
import torch.distributed as dist
gushiqiao's avatar
gushiqiao committed
3
from loguru import logger
gushiqiao's avatar
gushiqiao committed
4
from lightx2v.utils.envs import *
helloyongyang's avatar
helloyongyang committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21


def compute_freqs(c, grid_sizes, freqs):
    """Build the per-token 3-D RoPE table for a (frames, height, width) grid.

    ``freqs`` is split along dim 1 into a temporal part of width
    ``c - 2 * (c // 3)`` and height/width parts of width ``c // 3`` each.
    Only the first entry of ``grid_sizes`` is used.

    Returns:
        Tensor of shape ``(f * h * w, 1, c)``. Token ``(fi, hi, wi)`` is stored at
        row ``fi * h * w + hi * w + wi`` and holds the concatenation of the
        temporal, height and width rows for those indices.
    """
    t_part, h_part, w_part = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0].tolist()
    grid_shape = (f, h, w, -1)

    # Each axis table is laid out along its own grid axis, then broadcast over the others.
    pieces = [
        t_part[:f].view(f, 1, 1, -1).expand(grid_shape),
        h_part[:h].view(1, h, 1, -1).expand(grid_shape),
        w_part[:w].view(1, 1, w, -1).expand(grid_shape),
    ]
    return torch.cat(pieces, dim=-1).reshape(f * h * w, 1, -1)

22

gushiqiao's avatar
gushiqiao committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def compute_freqs_audio(c, grid_sizes, freqs):
    """Like ``compute_freqs`` but with one extra temporal position.

    The frame count taken from ``grid_sizes[0]`` is incremented by one before
    building the table (original note: "for r2v add 1 channel").

    Returns:
        Tensor of shape ``((f + 1) * h * w, 1, c)``.
    """
    t_part, h_part, w_part = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    frames, h, w = grid_sizes[0].tolist()
    # One extra temporal slot on top of the grid's frame count (r2v).
    f = frames + 1
    grid_shape = (f, h, w, -1)

    pieces = [
        t_part[:f].view(f, 1, 1, -1).expand(grid_shape),
        h_part[:h].view(1, h, 1, -1).expand(grid_shape),
        w_part[:w].view(1, 1, w, -1).expand(grid_shape),
    ]
    return torch.cat(pieces, dim=-1).reshape(f * h * w, 1, -1)


def compute_freqs_audio_dist(s, c, grid_sizes, freqs):
    """Distributed variant of ``compute_freqs_audio``: return this rank's slice.

    The full table (with one extra temporal position, as in
    ``compute_freqs_audio``) is padded via ``pad_freqs`` to ``s * world_size``
    rows. Rows ``[rank * s, (rank + 1) * s)`` are then returned, so ``s`` is the
    per-rank sequence length.

    Requires an initialized ``torch.distributed`` process group.
    """
    world_size = dist.get_world_size()
    cur_rank = dist.get_rank()
    t_part, h_part, w_part = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    frames, h, w = grid_sizes[0].tolist()
    f = frames + 1  # extra temporal slot (r2v)
    grid_shape = (f, h, w, -1)

    full_table = torch.cat(
        [
            t_part[:f].view(f, 1, 1, -1).expand(grid_shape),
            h_part[:h].view(1, h, 1, -1).expand(grid_shape),
            w_part[:w].view(1, 1, w, -1).expand(grid_shape),
        ],
        dim=-1,
    ).reshape(f * h * w, 1, -1)

    padded = pad_freqs(full_table, s * world_size)
    begin = cur_rank * s
    return padded[begin : begin + s, :, :]


Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
62
def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
    """Like ``compute_freqs`` but with the temporal rows offset by ``start_frame``.

    The temporal part uses rows ``[start_frame, start_frame + f)`` of the
    temporal split. Height and width always start at row 0.

    Returns:
        Tensor of shape ``(f * h * w, 1, c)``.
    """
    t_part, h_part, w_part = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0].tolist()
    grid_shape = (f, h, w, -1)
    frame_rows = t_part[start_frame : start_frame + f]

    pieces = [
        frame_rows.view(f, 1, 1, -1).expand(grid_shape),
        h_part[:h].view(1, h, 1, -1).expand(grid_shape),
        w_part[:w].view(1, 1, w, -1).expand(grid_shape),
    ]
    return torch.cat(pieces, dim=-1).reshape(f * h * w, 1, -1)

helloyongyang's avatar
helloyongyang committed
77

Xinchi Huang's avatar
Xinchi Huang committed
78
79
80
def pad_freqs(original_tensor, target_len):
    """Append rows of ones along dim 0 until the tensor has ``target_len`` rows.

    The padding is ``torch.ones`` with the input's dtype and device. For complex
    RoPE tables this presumably makes padded positions identity rotations
    (1 + 0j) — confirm against callers.

    Args:
        original_tensor: tensor of shape ``(seq_len, ...)``. Any rank >= 1 is
            accepted (previously only 3-D); only dim 0 is padded.
        target_len: desired size of dim 0; must be ``>= seq_len``.

    Returns:
        A new tensor of shape ``(target_len, *original_tensor.shape[1:])``.

    Raises:
        ValueError: if ``target_len`` is smaller than ``seq_len``.
    """
    seq_len = original_tensor.shape[0]
    pad_size = target_len - seq_len
    if pad_size < 0:
        # Without this check torch.ones fails with an opaque negative-dimension error.
        raise ValueError(f"pad_freqs: target_len ({target_len}) is smaller than the tensor length ({seq_len})")
    padding_tensor = torch.ones(
        (pad_size, *original_tensor.shape[1:]),
        dtype=original_tensor.dtype,
        device=original_tensor.device,
    )
    return torch.cat([original_tensor, padding_tensor], dim=0)


def compute_freqs_dist(s, c, grid_sizes, freqs):
    """Distributed variant of ``compute_freqs``: return this rank's slice.

    The full ``(f * h * w, 1, c)`` table is padded via ``pad_freqs`` to
    ``s * world_size`` rows. Rows ``[rank * s, (rank + 1) * s)`` are then
    returned, so ``s`` is the per-rank sequence length.

    Requires an initialized ``torch.distributed`` process group.
    """
    world_size = dist.get_world_size()
    cur_rank = dist.get_rank()
    t_part, h_part, w_part = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0].tolist()
    grid_shape = (f, h, w, -1)

    full_table = torch.cat(
        [
            t_part[:f].view(f, 1, 1, -1).expand(grid_shape),
            h_part[:h].view(1, h, 1, -1).expand(grid_shape),
            w_part[:w].view(1, 1, w, -1).expand(grid_shape),
        ],
        dim=-1,
    ).reshape(f * h * w, 1, -1)

    padded = pad_freqs(full_table, s * world_size)
    begin = cur_rank * s
    return padded[begin : begin + s, :, :]


helloyongyang's avatar
helloyongyang committed
107
108
109
110
def apply_rotary_emb(x, freqs_i):
    """Rotate the first ``freqs_i.size(0)`` rows of ``x`` by complex factors.

    Consecutive pairs along the last dim of ``x`` are treated as (real, imag).
    They are multiplied by ``freqs_i`` in float64/complex precision. Rows past
    ``freqs_i.size(0)`` are passed through unchanged. The result is always cast
    to bfloat16.
    """
    rot_len = freqs_i.size(0)
    heads = x.size(1)

    head_rows = x[:rot_len].to(torch.float64).reshape(rot_len, heads, -1, 2)
    rotated = torch.view_as_real(torch.view_as_complex(head_rows) * freqs_i).flatten(2)

    tail_rows = x[rot_len:]
    return torch.cat([rotated, tail_rows]).to(torch.bfloat16)
helloyongyang's avatar
helloyongyang committed
116
117


gushiqiao's avatar
gushiqiao committed
118
def apply_rotary_emb_chunk(x, freqs_i, chunk_size, remaining_chunk_size=100):
    """Chunked version of rotary embedding application, bounding temporary size.

    The first ``freqs_i.size(0)`` rows of ``x`` are rotated ``chunk_size`` rows
    at a time. Each chunk's last-dim pairs are viewed as complex numbers in
    float32, multiplied by the matching rows of ``freqs_i``, and cast to
    bfloat16. Rows past ``freqs_i.size(0)`` are passed through unchanged.

    Args:
        x: tensor whose dim 1 is the head count ``n`` and whose trailing
            elements per head form (real, imag) pairs; presumably
            (tokens, heads, head_dim) — TODO confirm.
        freqs_i: complex rotation factors, one row per rotated position,
            broadcastable against ``(rows, n, head_dim // 2)``.
        chunk_size: positive number of rows rotated per step.
        remaining_chunk_size: kept for backward compatibility only. The
            un-rotated tail is now appended as a single slice, because chunking
            it and re-concatenating produced the same tensor.

    Returns:
        bfloat16 tensor with the same leading size as ``x``. An empty ``x`` with
        an empty ``freqs_i`` now yields an empty tensor instead of raising
        inside ``torch.cat``.
    """
    n = x.size(1)
    seq_len = freqs_i.size(0)

    output_chunks = []
    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        x_chunk_complex = torch.view_as_complex(x[start:end].to(torch.float32).reshape(end - start, n, -1, 2))
        output_chunks.append(torch.view_as_real(x_chunk_complex * freqs_i[start:end]).flatten(2).to(torch.bfloat16))
        # Drop the float32 complex temporary before the next chunk allocates its own.
        del x_chunk_complex

    # Un-rotated tail. Appending it unconditionally also keeps the cat input
    # non-empty when x has no rows at all.
    output_chunks.append(x[seq_len:])

    x_i = torch.cat(output_chunks, dim=0)
    del output_chunks
    # Release cached blocks once, after all chunk temporaries are gone.
    # Calling this per chunk only added driver round-trips: the freed block is
    # re-requested immediately by the next chunk.
    torch.cuda.empty_cache()

    return x_i.to(torch.bfloat16)


helloyongyang's avatar
helloyongyang committed
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
def rope_params(max_seq_len, dim, theta=10000):
    """Return the complex RoPE table of shape ``(max_seq_len, dim // 2)``.

    Entry ``[p, k]`` is ``exp(i * p / theta ** (2k / dim))``, i.e. unit
    magnitude with angle ``p * inv_freq[k]``. It is computed in float64, so the
    result is complex128.
    """
    assert dim % 2 == 0
    inv_freq = 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim))
    positions = torch.arange(max_seq_len)
    angles = torch.outer(positions, inv_freq)
    return torch.polar(torch.ones_like(angles), angles)


def sinusoidal_embedding_1d(dim, position):
    """Return the ``[cos | sin]`` sinusoidal embedding of ``position``.

    The angle for position ``p`` and index ``k`` is
    ``p * 10000 ** (-k / (dim // 2))``, computed in float64. Cosines fill the
    first ``dim // 2`` columns and sines the rest. The result is cast to
    bfloat16 when ``GET_DTYPE()`` reports ``"BF16"``; otherwise it stays float64.
    """
    assert dim % 2 == 0
    half_dim = dim // 2
    pos = position.type(torch.float64)

    exponents = -torch.arange(half_dim).to(pos).div(half_dim)
    angles = torch.outer(pos, torch.pow(10000, exponents))
    emb = torch.cat([angles.cos(), angles.sin()], dim=1)
    return emb.to(torch.bfloat16) if GET_DTYPE() == "BF16" else emb