utils_fp32.py 1.02 KB
Newer Older
PengGao's avatar
PengGao committed
1
2
from typing import Any, Dict, List, Optional, Tuple, Union

helloyongyang's avatar
helloyongyang committed
3
4
5
6
7
8
9
10
11
12
import torch


def rms_norm(x, weight, eps):
    """Apply RMS normalization over the last dimension, then scale by ``weight``.

    The normalization statistics are computed in float32. The normalized
    tensor is cast to bfloat16 *before* it is multiplied by ``weight``, so
    the dtype of the result follows PyTorch's type promotion between
    bfloat16 and ``weight``'s dtype.

    Args:
        x: Input tensor; normalized along its last dimension.
        weight: Scale factor multiplied onto the normalized tensor (must
            broadcast against it).
        eps: Value added to the mean of squares before the reciprocal
            square root.

    Returns:
        The normalized and scaled tensor.
    """
    x_fp32 = x.float()
    mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    normalized = (x_fp32 * inv_rms).to(torch.bfloat16)
    return normalized * weight

Dongz's avatar
Dongz committed
13

helloyongyang's avatar
helloyongyang committed
14
15
16
17
def rotate_half(x, shape_0, shape_1):
    """Map each adjacent pair ``(a, b)`` along the last axis to ``(-b, a)``.

    ``x`` is cast to float32 and reshaped to ``(shape_0, shape_1, -1, 2)`` so
    that consecutive elements of the last axis form pairs; the rotated pairs
    are then flattened back into a single trailing axis.

    Args:
        x: Tensor whose element count is compatible with the reshape above.
        shape_0: Size of the first dimension used for the reshape.
        shape_1: Size of the second dimension used for the reshape.

    Returns:
        A float32 tensor of shape ``(shape_0, shape_1, D)`` holding the
        rotated pairs, where ``D`` is twice the number of pairs.
    """
    pairs = x.float().reshape(shape_0, shape_1, -1, 2)
    first = pairs[..., 0]
    second = pairs[..., 1]
    rotated = torch.stack((-second, first), dim=-1)
    return rotated.flatten(2)

Dongz's avatar
Dongz committed
18

helloyongyang's avatar
helloyongyang committed
19
def rotary_emb(x, shape_0, shape_1, cos, sin):
    """Apply a rotary embedding to ``x`` and return the result in bfloat16.

    Computes ``x * cos + rotate_half(x) * sin``, where ``rotate_half`` maps
    each adjacent pair ``(a, b)`` of the last axis to ``(-b, a)``.

    Args:
        x: Input tensor. ``rotate_half`` reshapes it to
            ``(shape_0, shape_1, -1, 2)``, so its element count must be
            compatible with that shape.
        shape_0: Size of the first dimension of ``x``.
        shape_1: Size of the second dimension of ``x``.
        cos: Cosine factors; must broadcast against ``x``.
        sin: Sine factors; must broadcast against the rotated ``x``.

    Returns:
        The rotated tensor, cast to ``torch.bfloat16``.
    """
    x_out = x * cos + rotate_half(x, shape_0, shape_1) * sin
    return x_out.to(torch.bfloat16)

Dongz's avatar
Dongz committed
23

helloyongyang's avatar
helloyongyang committed
24
25
26
27
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings to the query and key tensors.

    Both inputs are upcast to float32 before the rotation is computed;
    ``rotary_emb`` casts its result to bfloat16.

    Args:
        xq: 3-D query tensor. Presumably laid out as
            (tokens, heads, head_dim) -- TODO confirm against the caller.
        xk: Key tensor with the same first and last dimension sizes as
            ``xq``. Its second dimension may differ from ``xq``'s.
        freqs_cis: Indexable pair whose element 0 holds the cosine factors
            and element 1 the sine factors; each must be viewable as
            ``(xq.shape[0], 1, xq.shape[2])``.

    Returns:
        ``(xq_out, xk_out)``: the rotated query and key tensors.
    """
    shape_0, shape_1, shape_2 = xq.shape
    # The singleton middle axis lets cos/sin broadcast over the second
    # dimension of xq and xk.
    cos = freqs_cis[0].view(shape_0, 1, shape_2)
    sin = freqs_cis[1].view(shape_0, 1, shape_2)
    xq_out = rotary_emb(xq.float(), shape_0, shape_1, cos, sin)
    # Use xk's own second-dimension size: reshaping xk with xq's size fails
    # when the two differ (e.g. fewer key heads than query heads).
    xk_out = rotary_emb(xk.float(), shape_0, xk.shape[1], cos, sin)
    return xq_out, xk_out