test_rotary_embedding.py 3.14 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from typing import Optional, Tuple

import torch
from vllm.model_executor.layers.rotary_embedding import (
    RotaryEmbedding as VLLMRotaryEmbedding,
)


class SGLRotaryEmbedding(VLLMRotaryEmbedding):
    """vLLM ``RotaryEmbedding`` whose CUDA path calls the ``sgl_kernel`` op.

    Only ``forward_cuda`` is overridden; construction and the reference
    ``forward_native`` path are inherited from the vLLM base class.
    """

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary position embedding to ``query`` and ``key``.

        Args:
            positions: Token position ids passed to the kernel to select rows
                of ``cos_sin_cache``.
            query: Query tensor. The kernel's return value is unused and this
                same object is returned, so the kernel is expected to write
                its result into it in place.
            key: Key tensor, handled the same way as ``query``.
            offsets: Optional position offsets. When given they are added to
                ``positions`` before calling the kernel, mirroring vLLM's
                ``forward_native``. Presumably must broadcast against
                ``positions`` — confirm against callers.

        Returns:
            The ``(query, key)`` tensor objects that were passed in.
        """
        # Function-scope import: the module can be loaded without sgl_kernel
        # installed; the dependency is only needed when this path runs.
        from sgl_kernel import rotary_embedding

        if offsets is not None:
            # Without this, a caller-supplied offset would be dropped silently
            # and tokens would be rotated for the wrong positions.
            positions = positions + offsets

        # Move/cast the cached cos/sin table to match the inputs before the
        # kernel call (and keep the converted copy for subsequent calls).
        self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)

        rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
        return query, key


# Compare the output of SGLRotaryEmbedding's forward_cuda with VLLMRotaryEmbedding's forward_native


def test_rotary_embedding():
    """Check ``SGLRotaryEmbedding.forward_cuda`` against vLLM's ``forward_native``.

    For each configuration, identical random query/key tensors are rotated by
    both implementations on CUDA and the outputs are compared with
    ``rtol=atol=1e-3``. Requires a CUDA device and the ``sgl_kernel`` and
    ``vllm`` packages.
    """

    def run_test(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        dtype,
        batch_size,
        seq_len,
        num_heads,
        test_name,
    ):
        """Run one SGL-vs-vLLM comparison for the given configuration."""
        print(f"\nRunning {test_name}...")
        # Seed per case so a tolerance failure reproduces with identical
        # inputs and each case is independent of the order cases run in.
        torch.manual_seed(0)

        # Initialize both implementations with the same configuration.
        sgl_rope = SGLRotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        ).to("cuda")
        vllm_rope = VLLMRotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        ).to("cuda")

        # Flattened batch: positions run 0..seq_len-1 and restart for each
        # of the batch_size sequences.
        positions = torch.arange(seq_len, device="cuda").repeat(batch_size)
        query = torch.randn(
            batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
        )
        key = torch.randn(
            batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
        )

        # Separate copies: the SGL path writes into its inputs in place, so
        # the two implementations must not share storage.
        query_sgl = query.clone()
        key_sgl = key.clone()
        query_vllm = query.clone()
        key_vllm = key.clone()

        query_sgl_out, key_sgl_out = sgl_rope.forward_cuda(
            positions, query_sgl, key_sgl
        )
        query_vllm_out, key_vllm_out = vllm_rope.forward_native(
            positions, query_vllm, key_vllm
        )

        torch.testing.assert_close(query_sgl_out, query_vllm_out, rtol=1e-3, atol=1e-3)
        torch.testing.assert_close(key_sgl_out, key_vllm_out, rtol=1e-3, atol=1e-3)

        print(f"{test_name} passed!")

    # Test Case 1: FP32 with larger dimensions
    run_test(
        head_size=128,
        rotary_dim=64,
        max_position=4096,
        base=10000,
        is_neox_style=True,
        dtype=torch.float32,
        batch_size=4,
        seq_len=32,
        num_heads=8,
        test_name="FP32 Test",
    )

    # Test Case 2: BF16 with smaller dimensions
    run_test(
        head_size=64,
        rotary_dim=32,
        max_position=2048,
        base=8000,
        is_neox_style=True,
        dtype=torch.bfloat16,
        batch_size=2,
        seq_len=16,
        num_heads=4,
        test_name="BF16 Test",
    )


if __name__ == "__main__":
    # Allow running the comparison directly as a script, outside pytest.
    test_rotary_embedding()