"sgl-kernel/csrc/common_extension_rocm.cc" did not exist on "af6535e7aaf5c1e9352149f0edfde37d977cd473"
utils.py 7.32 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import torch
helloyongyang's avatar
helloyongyang committed
2
import torch.distributed as dist
3
from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
PengGao's avatar
PengGao committed
4

gushiqiao's avatar
gushiqiao committed
5
from lightx2v.utils.envs import *
helloyongyang's avatar
helloyongyang committed
6
7


8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def apply_wan_rope_with_torch(
    xq: torch.Tensor,
    xk: torch.Tensor,
    cos_sin_cache: torch.Tensor,
):
    """Apply rotary position embedding to query/key tensors with plain PyTorch ops.

    The first ``cos_sin_cache.size(0)`` positions of ``xq``/``xk`` are rotated by
    multiplying adjacent channel pairs (viewed as complex numbers) with
    ``cos_sin_cache``. Any positions beyond that are passed through unrotated,
    the same way ``apply_rotary_emb`` handles them.

    Args:
        xq: Query tensor, indexed as ``(seq, heads, head_dim)``; ``head_dim`` must
            be even because channel pairs are reinterpreted as complex values.
        xk: Key tensor with the same layout as ``xq``.
        cos_sin_cache: Rotation factors broadcastable against
            ``(seq_len, heads, head_dim // 2)``. Presumably the complex
            ``(seq_len, 1, head_dim // 2)`` tensor built by ``compute_freqs``;
            confirm at the call site.

    Returns:
        ``(xq, xk)`` with rotary embedding applied, cast to ``GET_DTYPE()``.
    """
    n = xq.size(1)
    seq_len = cos_sin_cache.size(0)

    # Rotate only the positions covered by the cache, in float32 for accuracy.
    q_rot = torch.view_as_complex(xq[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2))
    k_rot = torch.view_as_complex(xk[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2))
    q_rot = torch.view_as_real(q_rot * cos_sin_cache).flatten(2)
    k_rot = torch.view_as_real(k_rot * cos_sin_cache).flatten(2)

    # Append the un-rotated tail taken from the ORIGINAL inputs. Slicing the
    # rotated tensors here (as before) always yields an empty tensor and
    # silently dropped every position past seq_len.
    q_out = torch.cat([q_rot, xq[seq_len:]])
    k_out = torch.cat([k_rot, xk[seq_len:]])

    return q_out.to(GET_DTYPE()), k_out.to(GET_DTYPE())


def apply_wan_rope_with_chunk(
    xq: torch.Tensor,
    xk: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    chunk_size: int,
    rope_func,
):
    """Apply a rotary-embedding function to query/key tensors in sequence chunks.

    The first ``cos_sin_cache.size(0)`` positions are processed ``chunk_size``
    rows at a time to bound peak memory. Positions beyond the cache are appended
    unrotated, consistent with ``apply_rotary_emb_chunk``.

    Args:
        xq: Query tensor whose dim 0 is the sequence axis.
        xk: Key tensor whose dim 0 is the sequence axis.
        cos_sin_cache: Per-position rotation table; only its dim-0 slices are
            taken here, so its format is whatever ``rope_func`` expects.
        chunk_size: Number of positions handed to ``rope_func`` per call.
        rope_func: Callable ``(xq_chunk, xk_chunk, cos_sin_chunk) -> (xq, xk)``,
            e.g. ``apply_wan_rope_with_torch`` or ``apply_wan_rope_with_flashinfer``.

    Returns:
        ``(xq, xk)`` with rotary embedding applied, cast to ``GET_DTYPE()``.
    """
    seq_len = cos_sin_cache.size(0)

    xq_output_chunks = []
    xk_output_chunks = []
    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        xq_chunk, xk_chunk = rope_func(xq[start:end], xk[start:end], cos_sin_cache[start:end])
        xq_output_chunks.append(xq_chunk)
        xk_output_chunks.append(xk_chunk)
        # Return freed allocator blocks between chunks to keep peak memory low.
        torch.cuda.empty_cache()

    # Keep positions the cache does not cover instead of silently dropping them.
    if xq.size(0) > seq_len:
        xq_output_chunks.append(xq[seq_len:])
    if xk.size(0) > seq_len:
        xk_output_chunks.append(xk[seq_len:])

    # Concatenate q and k one at a time so their chunk lists are released separately.
    x_q = torch.cat(xq_output_chunks, dim=0)
    del xq_output_chunks
    torch.cuda.empty_cache()

    x_k = torch.cat(xk_output_chunks, dim=0)
    del xk_output_chunks
    torch.cuda.empty_cache()

    return x_q.to(GET_DTYPE()), x_k.to(GET_DTYPE())


def apply_wan_rope_with_flashinfer(
    xq: torch.Tensor,
    xk: torch.Tensor,
    cos_sin_cache: torch.Tensor,
):
    """Apply rotary embedding with FlashInfer's fused in-place kernel.

    Position ``i`` of the inputs is rotated using row ``i`` of ``cos_sin_cache``.
    ``is_neox=False`` selects FlashInfer's interleaved (even/odd channel pair)
    layout, which matches the channel pairing used by ``apply_wan_rope_with_torch``.

    Args:
        xq: Query tensor of shape ``(L, H, D)``.
        xk: Key tensor of shape ``(L, H, D)``.
        cos_sin_cache: Table passed straight to
            ``apply_rope_with_cos_sin_cache_inplace``; it must have at least
            ``L`` rows. Its exact layout is FlashInfer's (presumably real-valued
            cos/sin halves per row — confirm against the FlashInfer docs).

    Returns:
        ``(xq, xk)`` rotated, reshaped back to ``(L, H, D)``.

    Note:
        When ``xq``/``xk`` are already contiguous, ``reshape``/``contiguous``
        return views, so the in-place kernel also overwrites the caller's tensors.
    """
    L, H, D = xq.shape

    query = xq.reshape(L, H * D).contiguous()
    key = xk.reshape(L, H * D).contiguous()

    # Build positions directly on the target device; creating them on the CPU
    # and copying with non_blocking from pageable memory gains nothing.
    positions = torch.arange(L, device=xq.device, dtype=torch.long)

    apply_rope_with_cos_sin_cache_inplace(
        positions=positions,
        query=query,
        key=key,
        head_size=D,
        cos_sin_cache=cos_sin_cache,
        is_neox=False,
    )

    xq_out = query.view(L, H, D)
    xk_out = key.view(L, H, D)
    return xq_out, xk_out


helloyongyang's avatar
helloyongyang committed
86
87
def compute_freqs(c, grid_sizes, freqs):
    """Build per-token rotary frequencies for an ``(f, h, w)`` grid.

    ``freqs`` is split along dim 1 into a temporal band of width
    ``c - 2 * (c // 3)`` and two spatial bands of width ``c // 3``. Each token at
    grid coordinate ``(fi, hi, wi)`` receives the concatenation of temporal row
    ``fi``, height row ``hi`` and width row ``wi``. Tokens are laid out in
    row-major ``(f, h, w)`` order.

    Args:
        c: Total width of ``freqs`` along dim 1 (the three split sizes sum to ``c``).
        grid_sizes: ``(f, h, w)`` grid dimensions.
        freqs: 2-D table with at least ``max(f, h, w)`` rows and ``c`` columns
            (e.g. the output of ``rope_params``).

    Returns:
        Tensor of shape ``(f * h * w, 1, c)``.
    """
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes
    seq_len = f * h * w
    freqs_i = torch.cat(
        [
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    ).reshape(seq_len, 1, -1)

    return freqs_i

101

helloyongyang's avatar
helloyongyang committed
102
103
104
105
def compute_freqs_dist(s, c, grid_sizes, freqs, seq_p_group):
    """Build this rank's slice of per-token rotary frequencies for sequence parallelism.

    Frequencies for the whole ``(f, h, w)`` grid are computed as in
    ``compute_freqs``. They are padded with ones (via ``pad_freqs``) to
    ``s * world_size`` positions, and the contiguous block of ``s`` positions
    owned by the current rank is returned. If the frequencies are complex
    rotation factors, the ones padding acts as an identity rotation.

    Args:
        s: Number of sequence positions held by each rank.
        c: Total width of ``freqs`` along dim 1.
        grid_sizes: ``(f, h, w)`` grid dimensions.
        freqs: 2-D frequency table (e.g. from ``rope_params``).
        seq_p_group: Process group used for sequence parallelism.

    Returns:
        Tensor of shape ``(s, 1, c)`` for the calling rank.
    """
    world_size = dist.get_world_size(seq_p_group)
    cur_rank = dist.get_rank(seq_p_group)
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes
    seq_len = f * h * w
    freqs_i = torch.cat(
        [
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    ).reshape(seq_len, 1, -1)

    # Pad to the full sequence-parallel length, then take this rank's block.
    freqs_i = pad_freqs(freqs_i, s * world_size)
    start = cur_rank * s
    return freqs_i[start : start + s, :, :]


Zhuguanyu Wu's avatar
Zhuguanyu Wu committed
123
def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
    """Build per-token rotary frequencies for an ``(f, h, w)`` grid starting at a frame offset.

    This is identical to ``compute_freqs`` except that the temporal rows are
    taken from ``freqs`` starting at ``start_frame``. Frame ``k`` of this grid
    therefore uses temporal frequency row ``start_frame + k``.

    Args:
        c: Total width of ``freqs`` along dim 1.
        grid_sizes: ``(f, h, w)`` grid dimensions.
        freqs: 2-D frequency table (e.g. from ``rope_params``) with at least
            ``start_frame + f`` rows.
        start_frame: Index of the first temporal row to use.

    Returns:
        Tensor of shape ``(f * h * w, 1, c)``.
    """
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes
    seq_len = f * h * w
    freqs_i = torch.cat(
        [
            freqs[0][start_frame : start_frame + f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    ).reshape(seq_len, 1, -1)

    return freqs_i

helloyongyang's avatar
helloyongyang committed
138

Xinchi Huang's avatar
Xinchi Huang committed
139
140
141
def pad_freqs(original_tensor, target_len):
    """Pad a 3-D tensor along dim 0 with ones up to ``target_len`` rows.

    Args:
        original_tensor: Tensor of shape ``(seq_len, s1, s2)``.
        target_len: Desired size of dim 0. Must be ``>= seq_len``; a smaller
            value makes ``torch.ones`` fail on the negative size.

    Returns:
        A new tensor of shape ``(target_len, s1, s2)`` with the same dtype and
        device. The appended rows are filled with ones.
    """
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor


helloyongyang's avatar
helloyongyang committed
147
148
149
150
def apply_rotary_emb(x, freqs_i):
    """Apply rotary position embedding to ``x`` using complex rotation factors.

    The first ``freqs_i.size(0)`` positions are rotated by viewing adjacent
    channel pairs as complex numbers and multiplying them by ``freqs_i``.
    Remaining positions are appended unchanged.

    Args:
        x: Tensor indexed as ``(seq, heads, head_dim)`` with even ``head_dim``.
        freqs_i: Rotation factors broadcastable against
            ``(seq_len, heads, head_dim // 2)``, e.g. the ``(seq_len, 1, -1)``
            output of ``compute_freqs``.

    Returns:
        Tensor with the same leading length as ``x``, cast to ``GET_DTYPE()``.
    """
    n = x.size(1)
    seq_len = freqs_i.size(0)

    # Rotate in float32; pairs of channels form the real/imaginary parts.
    x_i = torch.view_as_complex(x[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2))
    x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
    # Positions beyond the frequency table pass through unrotated.
    x_i = torch.cat([x_i, x[seq_len:]])
    return x_i.to(GET_DTYPE())
helloyongyang's avatar
helloyongyang committed
156
157


gushiqiao's avatar
gushiqiao committed
158
def apply_rotary_emb_chunk(x, freqs_i, chunk_size, remaining_chunk_size=100):
    """Chunked variant of ``apply_rotary_emb`` that bounds peak memory.

    The first ``freqs_i.size(0)`` positions are rotated ``chunk_size`` rows at a
    time. Each chunk is cast to ``GET_DTYPE()`` right away, and the allocator
    cache is emptied between chunks. Positions beyond the frequency table are
    appended unrotated.

    Args:
        x: Tensor indexed as ``(seq, heads, head_dim)`` with even ``head_dim``.
        freqs_i: Rotation factors broadcastable against
            ``(seq_len, heads, head_dim // 2)``.
        chunk_size: Number of positions rotated per step.
        remaining_chunk_size: Kept for backward compatibility. The unrotated
            tail used to be split into views of this size and concatenated
            again, which is equivalent to appending it as one view.

    Returns:
        Tensor with the same leading length as ``x``, cast to ``GET_DTYPE()``.
    """
    n = x.size(1)
    seq_len = freqs_i.size(0)

    output_chunks = []
    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        x_complex = torch.view_as_complex(x[start:end].to(torch.float32).reshape(end - start, n, -1, 2))
        output_chunks.append(torch.view_as_real(x_complex * freqs_i[start:end]).flatten(2).to(GET_DTYPE()))
        # Drop the float32 intermediate before the next chunk is allocated.
        del x_complex
        torch.cuda.empty_cache()

    # Positions beyond the frequency table pass through unrotated.
    if x.size(0) > seq_len:
        output_chunks.append(x[seq_len:])

    x_i = torch.cat(output_chunks, dim=0)
    del output_chunks
    torch.cuda.empty_cache()

    return x_i.to(GET_DTYPE())
gushiqiao's avatar
gushiqiao committed
189
190


helloyongyang's avatar
helloyongyang committed
191
192
193
194
def rope_params(max_seq_len, dim, theta=10000):
    """Precompute complex rotary-embedding factors.

    Entry ``[p, k]`` is ``exp(1j * p * theta ** (-2k / dim))``, i.e. a unit
    complex number whose angle is position ``p`` times the ``k``-th inverse
    frequency.

    Args:
        max_seq_len: Number of positions (rows) to precompute.
        dim: Rotary dimension; must be even.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(max_seq_len, dim // 2)``.
    """
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)),
    )
    # Unit magnitude, angle = position * inverse frequency.
    freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs


def sinusoidal_embedding_1d(dim, position):
    """Encode 1-D positions with sinusoids (cosine half first, then sine half).

    Args:
        dim: Embedding size; must be even.
        position: 1-D tensor of positions (e.g. timesteps); it is cast to float32.

    Returns:
        Tensor of shape ``(len(position), dim)`` holding
        ``[cos(p * f_k), sin(p * f_k)]`` with ``f_k = 10000 ** (-k / (dim // 2))``,
        cast to ``GET_SENSITIVE_DTYPE()``.
    """
    # preprocess
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.float32)

    # calculation: outer product of positions with the geometric frequencies
    sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    x = x.to(GET_SENSITIVE_DTYPE())
    return x
212
213


GoatWu's avatar
GoatWu committed
214
def guidance_scale_embedding(w, embedding_dim=256, cfg_range=(1.0, 6.0), target_range=1000.0, dtype=torch.float32):
    """Embed classifier-free-guidance scales with a sinusoidal encoding.

    Each scale in ``w`` goes through four steps:
      1. It is rounded to the nearest integer.
      2. It is clamped to ``cfg_range``.
      3. It is mapped linearly onto ``[0, target_range]``.
      4. It is encoded as ``[sin(w * f_k), cos(w * f_k)]`` over
         ``half_dim = embedding_dim // 2`` frequencies
         ``f_k = 10000 ** (-k / (half_dim - 1))``.

    Args:
        w: 1-D tensor of guidance scales, one per batch element.
        embedding_dim: Size of each output vector. An odd size gets one extra
            zero column.
        cfg_range: ``(min, max)`` range the rounded scales are clamped to.
        target_range: Upper end of the range the normalized scales are stretched to.
        dtype: Dtype of the frequency table and of the returned embedding.

    Returns:
        Tensor of shape ``(len(w), embedding_dim)``.
    """
    assert len(w.shape) == 1
    cfg_min, cfg_max = cfg_range
    w = torch.round(w)
    w = torch.clamp(w, min=cfg_min, max=cfg_max)
    w = (w - cfg_min) / (cfg_max - cfg_min)  # [0, 1]
    w = w * target_range
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype).to(w.device) * -emb).to(w.device)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        # The pad spec is a plain tuple; it has no `.to`, and the padding is
        # created on emb's device automatically.
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb