utils_bf16.py 951 Bytes
Newer Older
PengGao's avatar
PengGao committed
1
2
from typing import Any, Dict, List, Optional, Tuple, Union

helloyongyang's avatar
helloyongyang committed
3
4
5
6
7
8
9
10
import torch


def rms_norm(x, weight, eps):
    """Apply RMS normalization over the last dimension of ``x``.

    The mean of ``x**2`` is taken along dim ``-1`` (kept for broadcasting),
    ``eps`` is added, and ``x`` is multiplied by the reciprocal square root
    of that value and then by ``weight``.

    No dtype casting is performed here; every operation runs in whatever
    dtype the inputs (and PyTorch's type promotion) dictate.

    Args:
        x: Input tensor; normalized along its last dimension.
        weight: Scale factor; must broadcast against ``x``.
        eps: Value added to the mean square before ``rsqrt``.

    Returns:
        The normalized and scaled tensor.
    """
    mean_square = x.pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)
    # Keep the original evaluation order: normalize first, then scale.
    return x * inv_rms * weight

Dongz's avatar
Dongz committed
11

helloyongyang's avatar
helloyongyang committed
12
13
14
15
def rotate_half(x, shape_0, shape_1):
    """Rotate adjacent element pairs of ``x`` by 90 degrees.

    ``x`` is reshaped to ``(shape_0, shape_1, -1, 2)`` so the trailing axis
    holds (real, imag) pairs. Each pair ``(a, b)`` becomes ``(-b, a)``, and
    the result is flattened from dim 2 onward, giving
    ``(shape_0, shape_1, -1)``.

    Args:
        x: Input tensor whose element count is divisible by
            ``shape_0 * shape_1 * 2``.
        shape_0: Size of the first dimension used for the reshape.
        shape_1: Size of the second dimension used for the reshape.

    Returns:
        Tensor of shape ``(shape_0, shape_1, -1)`` with every pair rotated.
    """
    pairs = x.reshape(shape_0, shape_1, -1, 2)
    real_part = pairs[..., 0]
    imag_part = pairs[..., 1]
    rotated = torch.stack((-imag_part, real_part), dim=-1)
    return rotated.flatten(2)

Dongz's avatar
Dongz committed
16

helloyongyang's avatar
helloyongyang committed
17
def rotary_emb(x, shape_0, shape_1, cos, sin):
    """Apply a rotary embedding to ``x``.

    Computes ``x * cos + rotate_half(x, shape_0, shape_1) * sin``.

    Args:
        x: Input tensor; reshaped inside ``rotate_half`` using
            ``shape_0`` and ``shape_1`` as its two leading dimensions.
        shape_0: First dimension passed through to ``rotate_half``.
        shape_1: Second dimension passed through to ``rotate_half``.
        cos: Cosine factors; must broadcast against ``x``.
        sin: Sine factors; must broadcast against the rotated ``x``.

    Returns:
        The rotated tensor.
    """
    cos_term = x * cos
    sin_term = rotate_half(x, shape_0, shape_1) * sin
    return cos_term + sin_term

Dongz's avatar
Dongz committed
21

helloyongyang's avatar
helloyongyang committed
22
23
24
25
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings to a query tensor and a key tensor.

    ``freqs_cis[0]`` and ``freqs_cis[1]`` are used as the cosine and sine
    factors. Each is viewed as ``(shape_0, 1, shape_2)``, where
    ``shape_0`` and ``shape_2`` are the first and last dimensions of
    ``xq``, so the factors broadcast over dim 1.

    Args:
        xq: 3-D query tensor of shape ``(shape_0, shape_1, shape_2)``.
        xk: 3-D key tensor. Its own first two dimensions drive the pair
            reshape, so dim 1 may differ from ``xq``'s (e.g. fewer key
            heads); it must still broadcast against the
            ``(shape_0, 1, shape_2)`` cos/sin factors.
        freqs_cis: Indexable pair (tuple, or tensor indexed on dim 0) whose
            entries 0 and 1 each hold ``shape_0 * shape_2`` elements.

    Returns:
        ``(xq_out, xk_out)``: the rotated query and key tensors.
    """
    shape_0, shape_1, shape_2 = xq.shape
    cos = freqs_cis[0].view(shape_0, 1, shape_2)
    sin = freqs_cis[1].view(shape_0, 1, shape_2)
    xq_out = rotary_emb(xq, shape_0, shape_1, cos, sin)
    # Use xk's own leading dims: reshaping keys with the query's head count
    # breaks when the two differ. Identical to before when shapes match.
    xk_out = rotary_emb(xk, xk.shape[0], xk.shape[1], cos, sin)
    return xq_out, xk_out