test_rope.py 5.68 KB
Newer Older
1
2
3
4
5
6
import unittest

import sgl_kernel
import torch
from utils import precision

7
8
9
10
from sglang.srt.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding,
    RotaryEmbedding,
)
11
12
from sglang.test.test_utils import CustomTestCase

13
14
# Seed torch's global RNG at import time so randomly generated test data is
# reproducible from run to run.
torch.manual_seed(0)

15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69

class TestROPE(CustomTestCase):
    """Check the ``sgl_kernel`` CPU rotary-embedding op against ``forward_native``.

    Each test feeds identical inputs to the ``forward_native`` method of a
    rotary-embedding module (the reference) and to
    ``torch.ops.sgl_kernel.rotary_embedding_cpu`` (the fused kernel), then
    asserts that the rotated query/key tensors are close.
    """

    def test_deepseek_v2_rope(self) -> None:
        """Compare both paths on the rotary slices of query/key tensors.

        The rotary embedding is applied only to the trailing
        ``qk_rope_head_dim`` features of the query and key, using a synthetic
        ``cos_sin_cache`` that is shared by the reference module and the
        fused kernel.
        """
        num_head = 16
        seq_len = 1024
        q_head_dim = 192
        qk_nope_head_dim = 128
        qk_rope_head_dim = 64
        max_pos = 256
        k_dim = 576
        rotary_dim = 64
        is_neox_style = False

        # Create cos_sin_cache: random frequencies, cos/sin scaled by 0.7 and
        # concatenated as [cos | sin] along the last dimension.
        freqs = torch.rand(max_pos, qk_rope_head_dim // 2)
        cos = freqs.cos() * 0.7
        sin = freqs.sin() * 0.7
        cos_sin_cache = torch.cat((cos, sin), dim=-1).to(torch.bfloat16)
        positions = torch.randint(0, max_pos, (seq_len,))

        rope = DeepseekScalingRotaryEmbedding(
            qk_rope_head_dim,
            rotary_dim,
            max_pos,
            16,  # not used since cos_sin_cache is provided
            is_neox_style,
            1.0,
            torch.bfloat16,
            device="cpu",
        )
        # Replace the module's cache with the synthetic one so that the
        # reference path and the fused kernel read exactly the same table.
        rope.register_buffer("cos_sin_cache", cos_sin_cache)

        for dtype in [torch.bfloat16]:
            enable_autocast = True

            with torch.no_grad(), torch.amp.autocast("cpu", enabled=enable_autocast):
                q = torch.randn(seq_len, num_head, q_head_dim, dtype=dtype)
                q_clone = q.clone()
                k = torch.randn(seq_len, 1, k_dim, dtype=dtype)
                k_clone = k.clone()

                # Only the trailing qk_rope_head_dim features are rotated; the
                # leading qk_nope_head_dim features of q are discarded here.
                _, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
                _, q_pe_clone = q_clone.split(
                    [qk_nope_head_dim, qk_rope_head_dim], dim=-1
                )
                k_pe = k[:, :, k_dim - qk_rope_head_dim :]
                k_pe_clone = k_clone[:, :, k_dim - qk_rope_head_dim :]

                # ref kernel
                q_pe, k_pe = rope.forward_native(
                    query=q_pe,
                    key=k_pe,
                    positions=positions,
                )

                # fused rope kernel; receives the same is_neox_style flag the
                # reference module was constructed with.
                q_pe_clone, k_pe_clone = torch.ops.sgl_kernel.rotary_embedding_cpu(
                    positions,
                    q_pe_clone,
                    k_pe_clone,
                    rope.head_size,
                    cos_sin_cache,
                    is_neox_style,
                )

                # Tolerance is looked up per dtype in the shared table.
                atol = rtol = precision[q_pe.dtype]
                torch.testing.assert_close(q_pe, q_pe_clone, atol=atol, rtol=rtol)
                torch.testing.assert_close(k_pe, k_pe_clone, atol=atol, rtol=rtol)
                # NOTE(review): additional key-only check using torch's default
                # tolerances; kept as in the original -- confirm it is intended.
                torch.testing.assert_close(k_pe, k_pe_clone)

    def test_origin_rope(self) -> None:
        """Compare both paths for ``RotaryEmbedding`` over several configurations."""

        def single_test(
            head_size: int,
            rotary_dim: int,
            max_position_embeddings: int,
            base: int,
            is_neox_style: bool,
            dtype: torch.dtype,
            device: str,
            batch_size: int,
            seq_len: int,
            num_q_heads: int,
            num_kv_heads: int,
        ) -> None:
            """Run one configuration through both paths and assert closeness."""
            # Re-seed so every configuration sees the same random stream
            # regardless of which configurations ran before it.
            torch.manual_seed(100)
            rope_ref = RotaryEmbedding(
                head_size,
                rotary_dim,
                max_position_embeddings,
                base,
                is_neox_style,
                dtype,
            ).to(device)
            # Positions 0..seq_len-1, repeated once per batch element.
            pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
            query = torch.randn(
                batch_size * seq_len,
                num_q_heads * head_size,
                dtype=dtype,
                device=device,
            )
            key = torch.randn(
                batch_size * seq_len,
                num_kv_heads * head_size,
                dtype=dtype,
                device=device,
            )

            # Separate copies per path so neither call sees the other's input.
            query_ref, key_ref = query.clone(), key.clone()
            query_cpu, key_cpu = query.clone(), key.clone()

            query_ref_out, key_ref_out = rope_ref.forward_native(
                pos_ids, query_ref, key_ref
            )
            query_cpu_out, key_cpu_out = torch.ops.sgl_kernel.rotary_embedding_cpu(
                pos_ids,
                query_cpu,
                key_cpu,
                rope_ref.head_size,
                rope_ref.cos_sin_cache.to(query.dtype),
                rope_ref.is_neox_style,
            )
            torch.testing.assert_close(
                query_ref_out, query_cpu_out, atol=1e-2, rtol=1e-2
            )
            torch.testing.assert_close(key_ref_out, key_cpu_out, atol=1e-2, rtol=1e-2)

        # Each tuple follows single_test's positional parameter order:
        # (head_size, rotary_dim, max_position_embeddings, base, is_neox_style,
        #  dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads)
        test_config = [
            (64, 64, 32, 8000, True, torch.bfloat16, "cpu", 32, 32, 1, 1),
            (256, 128, 4096, 10000, True, torch.bfloat16, "cpu", 2, 512, 32, 8),
            (512, 128, 311, 10000, True, torch.bfloat16, "cpu", 3, 39, 4, 2),
            (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8),
            (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 4),
            (512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2),
        ]

        for config in test_config:
            # subTest reports the failing configuration and lets the remaining
            # configurations run instead of stopping at the first failure.
            with self.subTest(config=config):
                single_test(*config)

176
177
178

# Run the tests through unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()