from typing import Any, Dict, List, Optional, Tuple, Union

import pytest
import torch

from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
from sgl_kernel.testing.rotary_embedding import (
    FlashInferRotaryEmbedding,
    MHATokenToKVPool,
    RotaryEmbedding,
    create_inputs,
)


@pytest.mark.parametrize(
    "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads, save_kv_cache",
    [
        # GPT-OSS cases
        *[
            (
                64,
                64,
                4096,
                8000,
                True,
                torch.bfloat16,
                "cuda",
                batch_size,
                seq_len,
                64,
                8,
                save_kv_cache,
            )
            for batch_size, seq_len in (
                (1, 1),
                (32, 1),
                (128, 1),
                (512, 1),
                (2, 512),
                (4, 4096),
            )
            for save_kv_cache in (False, True)
        ],
        # Other cases
        (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1, False),
        (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2, False),
        (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2, False),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8, False),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4, False),
        (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2, False),
    ],
)
def test_correctness(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    dtype: torch.dtype,
    device: str,
    batch_size: int,
    seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
    save_kv_cache: bool,
):
    """Check FlashInfer RoPE (``forward_cuda``) against the native reference.

    Both ``RotaryEmbedding.forward_native`` and
    ``FlashInferRotaryEmbedding.forward_cuda`` are run on identical, separately
    cloned query/key inputs. The rotated outputs must agree within
    ``atol=rtol=1e-2``.

    When ``save_kv_cache`` is true, the reference path writes key/value into a
    KV pool via ``set_kv_buffer``. The FlashInfer path writes into a second pool
    through ``FusedSetKVBufferArg``. The two pools' key and value buffers must
    then agree both numerically and in which entries are non-zero.
    """
    config = dict(
        head_size=head_size,
        rotary_dim=rotary_dim,
        max_position_embeddings=max_position_embeddings,
        base=base,
        is_neox_style=is_neox_style,
        dtype=dtype,
    )

    rope_ref = RotaryEmbedding(**config).to(device)
    rope_flashinfer = FlashInferRotaryEmbedding(**config).to(device)

    inputs = create_inputs(
        head_size=head_size,
        batch_size=batch_size,
        seq_len=seq_len,
        device=device,
        dtype=dtype,
        num_q_heads=num_q_heads,
        num_kv_heads=num_kv_heads,
    )

    # Separate pools so each path writes into its own buffers. They are bound
    # unconditionally so the names are always defined.
    pool_ref = pool_flashinfer = None
    if save_kv_cache:
        pool_ref = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)
        pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)

    # Clone the inputs per path so one path's writes cannot leak into the other.
    query_ref, key_ref = inputs["query"].clone(), inputs["key"].clone()
    query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone()

    # Reference path: native RoPE, then an explicit KV-cache write.
    query_ref_out, key_ref_out = rope_ref.forward_native(
        inputs["pos_ids"], query_ref, key_ref
    )
    if save_kv_cache:
        pool_ref.set_kv_buffer(
            loc=inputs["out_cache_loc"],
            cache_k=key_ref_out.view(-1, num_kv_heads, head_size),
            cache_v=inputs["value"].view(-1, num_kv_heads, head_size),
        )

    # FlashInfer path: the KV-cache write is fused into forward_cuda.
    fused_set_kv_buffer_arg = None
    if save_kv_cache:
        # NOTE(review): index [0] presumably selects the first layer of the pool;
        # confirm against MHATokenToKVPool. Heads are flattened into one row.
        fused_set_kv_buffer_arg = FusedSetKVBufferArg(
            value=inputs["value"],
            k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size),
            v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size),
            k_scale=None,
            v_scale=None,
            cache_loc=inputs["out_cache_loc"],
        )
    query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda(
        inputs["pos_ids"],
        query_flashinfer,
        key_flashinfer,
        fused_set_kv_buffer_arg=fused_set_kv_buffer_arg,
    )

    torch.testing.assert_close(
        query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
    )
    torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)

    if save_kv_cache:
        for field in ("k_buffer", "v_buffer"):
            x_ref = getattr(pool_ref, field)[0]
            x_flashinfer = getattr(pool_flashinfer, field)[0]
            torch.testing.assert_close(x_ref, x_flashinfer, atol=1e-2, rtol=1e-2)
            # assert_close tolerates small values next to exact zeros. Comparing
            # the non-zero masks checks that both paths wrote the same slots.
            # (The mask must come from x_flashinfer, not x_ref, or this check
            # compares a buffer with itself and can never fail.)
            nonzero_ref = x_ref != 0
            nonzero_flashinfer = x_flashinfer != 0
            assert torch.all(nonzero_ref == nonzero_flashinfer)


if __name__ == "__main__":
    # Run this module's tests directly. pytest.main() returns an exit code;
    # propagate it so failures yield a non-zero process status.
    raise SystemExit(pytest.main([__file__]))