# test_patched_rotary.py

import torch
import rotary_emb
from vllm import pos_encoding_ops

def apply_rotary_eager(query, key, cos, sin):
    # Reference implementation: GPT-NeoX-style rotation, applied in place.
    def _apply_rot(x, cos, sin):
        rotary_dim = cos.shape[-1]

        # The Flash Attention rotary_emb kernel casts everything to float,
        # so upcast here as well so the numerics match.
        dtype = x.dtype
        x_upcast = x.to(torch.float32)
        cos = cos.to(torch.float32)
        sin = sin.to(torch.float32)

        x1 = x_upcast[..., :rotary_dim]
        x2 = x_upcast[..., rotary_dim : 2 * rotary_dim]

        x[..., :rotary_dim] = (x1 * cos - x2 * sin).to(dtype)
        x[..., rotary_dim : 2 * rotary_dim] = (x1 * sin + x2 * cos).to(dtype)

    _apply_rot(query, cos, sin)
    _apply_rot(key, cos, sin)

def apply_rotary_flash(query, key, cos, sin):
    # Flash Attention's fused CUDA kernel. Passing x1/x2 as both inputs and
    # outputs makes the rotation in-place; the trailing False leaves the
    # conjugate (inverse) rotation off, i.e. applies the forward rotation.
    def _apply_rot(x, cos, sin):
        rotary_dim = cos.shape[-1]
        x1 = x[..., :rotary_dim]
        x2 = x[..., rotary_dim : 2 * rotary_dim]

        rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)

    _apply_rot(query, cos, sin)
    _apply_rot(key, cos, sin)

def apply_rotary_vllm(query, key, cos, sin):
    head_size = query.shape[-1]

    #print("query", query.dtype)
    #print("key", key.dtype)
    #print("cos", cos.dtype)
    #print("sin", sin.dtype)

    # In-place operation: the patched kernel updates query and key directly.
    pos_encoding_ops.rotary_embedding(
        query,
        key,
        head_size,
        cos,
        sin,
        True
    )

seqlen = 8

# Random stand-ins for the rotary cos/sin tables, shape (seqlen, 1, head_dim // 2);
# random values are enough here because the three kernels are compared elementwise.
cos = torch.rand(seqlen, 1, 64).to("cuda").to(torch.float16)
sin = torch.rand(seqlen, 1, 64).to("cuda").to(torch.float16)
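# For reference, a minimal sketch of how real GPT-NeoX-style cos/sin tables
# would be built (a RoPE base of 10000 is an assumption, not taken from this
# repo); left commented out because random tables suffice for this test.
#
# rot_dim = 128  # full rotary span, equal to head_dim below
# inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2, device="cuda").float() / rot_dim))
# freqs = torch.arange(seqlen, device="cuda").float()[:, None] * inv_freq[None, :]
# cos = freqs.cos()[:, None, :].to(torch.float16)  # (seqlen, 1, rot_dim // 2)
# sin = freqs.sin()[:, None, :].to(torch.float16)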

head_dim = 128
num_heads = 32

query_eager = torch.rand(seqlen, num_heads, head_dim).to(torch.float16).to("cuda")
key_eager = torch.rand(seqlen, num_heads, head_dim).to(torch.float16).to("cuda")

query_vllm = query_eager.clone()
query_flash = query_eager.clone()

key_vllm = key_eager.clone()
key_flash = key_eager.clone()

apply_rotary_eager(query_eager, key_eager, cos.clone(), sin.clone())
apply_rotary_flash(query_flash, key_flash, cos.clone(), sin.clone())
# Note: the patched vLLM op is fed float32 cos/sin, while the other two consume fp16.
apply_rotary_vllm(query_vllm, key_vllm, cos.clone().float(), sin.clone().float())

def check_diff(a, b, a_name, b_name):
    rel_diff = ((a - b).abs() / (a.abs() + 1e-12)).mean()
    print(f"Allclose {a_name}, {b_name}: {torch.allclose(a, b)}; mean rel diff: {rel_diff}")

check_diff(query_eager, query_vllm, "query_eager", "query_vllm")
check_diff(query_eager, query_flash, "query_eager", "query_flash")
check_diff(key_eager, key_vllm, "key_eager", "key_vllm")
check_diff(key_eager, key_flash, "key_eager", "key_flash")
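
# Optional stricter comparison, as a sketch only: the fp16 tolerances below are
# assumptions, not values validated against these kernels.
#
# torch.testing.assert_close(query_vllm, query_eager, rtol=1e-3, atol=1e-3)
# torch.testing.assert_close(query_flash, query_eager, rtol=1e-3, atol=1e-3)
# torch.testing.assert_close(key_vllm, key_eager, rtol=1e-3, atol=1e-3)
# torch.testing.assert_close(key_flash, key_eager, rtol=1e-3, atol=1e-3)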