utils_fp32.py 1.02 KB
Newer Older
PengGao's avatar
PengGao committed
1
from typing import Tuple, Union
PengGao's avatar
PengGao committed
2

helloyongyang's avatar
helloyongyang committed
3
4
import torch

5
6
from lightx2v.utils.envs import *

helloyongyang's avatar
helloyongyang committed
7
8
9
10

def rms_norm(x, weight, eps):
    """Root-mean-square normalize ``x`` over its last axis, then scale it.

    The normalization itself is carried out in float32; the normalized
    tensor is cast to the dtype returned by ``GET_DTYPE()`` *before* it is
    multiplied by ``weight``.

    Args:
        x: Input tensor; statistics are taken over the last dimension.
        weight: Scale multiplied onto the normalized tensor (must broadcast
            against it).
        eps: Constant added to the mean of squares before the reciprocal
            square root.

    Returns:
        The normalized tensor multiplied by ``weight``.
    """
    x_fp32 = x.float()
    mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    # Leave float32 before applying the weight, as the scale is applied in
    # the configured runtime dtype.
    normalized = (x_fp32 * inv_rms).to(GET_DTYPE())
    return normalized * weight

Dongz's avatar
Dongz committed
15

helloyongyang's avatar
helloyongyang committed
16
17
18
19
def rotate_half(x, shape_0, shape_1):
    """Map each adjacent pair ``(a, b)`` on the last axis to ``(-b, a)``.

    ``x`` is converted to float32 and viewed as
    ``(shape_0, shape_1, n_pairs, 2)``; the two members of every pair are
    swapped with the second one negated, and the pair axis is folded back
    so the result has shape ``(shape_0, shape_1, 2 * n_pairs)``.

    Args:
        x: Tensor whose element count is divisible by
            ``shape_0 * shape_1 * 2``.
        shape_0: Size of the first output dimension.
        shape_1: Size of the second output dimension.

    Returns:
        A float32 tensor holding the pairwise-rotated values.
    """
    pairs = x.float().reshape(shape_0, shape_1, -1, 2)
    first = pairs[..., 0]
    second = pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    return rotated.flatten(start_dim=2)

Dongz's avatar
Dongz committed
20

helloyongyang's avatar
helloyongyang committed
21
def rotary_emb(x, shape_0, shape_1, cos, sin):
    """Combine ``x`` with its pairwise rotation using ``cos`` and ``sin``.

    Computes ``x * cos + rotate_half(x, shape_0, shape_1) * sin`` and casts
    the result to the dtype returned by ``GET_DTYPE()``.

    Args:
        x: Input tensor.
        shape_0: First dimension forwarded to ``rotate_half``.
        shape_1: Second dimension forwarded to ``rotate_half``.
        cos: Cosine factors; must broadcast against ``x``.
        sin: Sine factors; must broadcast against the rotated tensor.

    Returns:
        The combined tensor in the configured runtime dtype.
    """
    rotated = rotate_half(x, shape_0, shape_1)
    cos_term = x * cos
    sin_term = rotated * sin
    combined = cos_term + sin_term
    return combined.to(GET_DTYPE())
helloyongyang's avatar
helloyongyang committed
24

Dongz's avatar
Dongz committed
25

helloyongyang's avatar
helloyongyang committed
26
27
28
29
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to a query tensor and a key tensor.

    Both inputs are converted to float32 and passed through ``rotary_emb``
    with the same cosine/sine tables.

    Args:
        xq: Query tensor; must be 3-D. NOTE(review): the axes look like
            (tokens, heads, head_dim) -- confirm against the caller.
        xk: Key tensor whose first and last dimensions match ``xq``. Its
            second dimension may differ from ``xq``'s.
        freqs_cis: Indexable holding the cosine table at ``[0]`` and the
            sine table at ``[1]``; each is viewed as
            ``(xq.shape[0], 1, xq.shape[2])`` so it broadcasts over the
            second dimension.

    Returns:
        ``(xq_out, xk_out)``, each with the shape of its input, as returned
        by ``rotary_emb``.
    """
    shape_0, q_shape_1, shape_2 = xq.shape
    # Reshape the key with its own second dimension rather than the query's:
    # reusing the query's value breaks as soon as the two differ (for
    # example grouped-query attention), while cos/sin broadcast over that
    # axis either way.
    k_shape_1 = xk.shape[1]
    cos = freqs_cis[0].view(shape_0, 1, shape_2)
    sin = freqs_cis[1].view(shape_0, 1, shape_2)
    xq_out = rotary_emb(xq.float(), shape_0, q_shape_1, cos, sin)
    xk_out = rotary_emb(xk.float(), shape_0, k_shape_1, cos, sin)
    return xq_out, xk_out