test_concat_mla_q.py 4.93 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm import _custom_ops as ops

NUM_TOKENS = [1, 4, 16, 64, 128]  # token counts, from a single token upward
NUM_HEADS = [128]  # attention heads per token
NOPE_DIM = [512]  # width of the (absorbed) nope part of each head
ROPE_DIM = [64]  # width of the rope part appended after the nope part
DTYPES = [torch.bfloat16, torch.float16]  # 16-bit dtypes exercised by the op


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("nope_dim", NOPE_DIM)
@pytest.mark.parametrize("rope_dim", ROPE_DIM)
@pytest.mark.parametrize("dtype", DTYPES)
def test_concat_mla_q_contiguous(num_tokens, num_heads, nope_dim, rope_dim, dtype):
    """Concatenate densely packed nope and rope inputs (standard layout).

    The op only moves data, so its output must equal ``torch.cat`` along
    the last dimension bit-for-bit (zero tolerance).
    """
    torch.manual_seed(42)
    opts = {"dtype": dtype, "device": "cuda"}

    nope = torch.randn(num_tokens, num_heads, nope_dim, **opts)
    rope = torch.randn(num_tokens, num_heads, rope_dim, **opts)

    expected = torch.cat([nope, rope], dim=-1)

    out = torch.empty((num_tokens, num_heads, nope_dim + rope_dim), **opts)
    ops.concat_mla_q(nope, rope, out)

    torch.testing.assert_close(out, expected, rtol=0, atol=0)


@pytest.mark.parametrize("num_tokens", [n for n in NUM_TOKENS if n > 1])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("nope_dim", NOPE_DIM)
@pytest.mark.parametrize("rope_dim", ROPE_DIM)
@pytest.mark.parametrize("dtype", DTYPES)
def test_concat_mla_q_transposed_nope(
    num_tokens, num_heads, nope_dim, rope_dim, dtype
):
    """Concatenate a nope input that is a transposed, non-contiguous view.

    This mirrors how ``mqa_ql_nope`` is produced in the real code path::

        torch.bmm(q_nope, W_UK_T)  # [N, B, L]
        .transpose(0, 1)            # [B, N, L] — non-contiguous!

    A single token is excluded from the parametrization: with B == 1 the
    transposed view is still reported as contiguous by PyTorch, which would
    trip the layout assertion below.
    """
    torch.manual_seed(42)
    opts = {"dtype": dtype, "device": "cuda"}

    # Allocate heads-first, then swap the leading axes to get [B, N, L].
    heads_first = torch.randn(num_heads, num_tokens, nope_dim, **opts)
    nope = heads_first.transpose(0, 1)
    assert not nope.is_contiguous()

    rope = torch.randn(num_tokens, num_heads, rope_dim, **opts)

    expected = torch.cat([nope, rope], dim=-1)

    out = torch.empty((num_tokens, num_heads, nope_dim + rope_dim), **opts)
    ops.concat_mla_q(nope, rope, out)

    torch.testing.assert_close(out, expected, rtol=0, atol=0)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_concat_mla_q_split_rope(num_tokens, num_heads, dtype):
    """Test with rope from a split (simulates the actual code path).

    In the real code path, q_pe comes from:
        mqa_q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
    which creates a non-contiguous view with stride(1) != rope_dim.
    """
    torch.manual_seed(42)
    nope_dim = 512  # absorbed nope width written into the output
    rope_dim = 64
    qk_nope_head_dim = 128  # nope width of the original q before absorption
    # Original q before absorption: [B, N, qk_nope_head_dim + rope_dim].
    orig_dim = qk_nope_head_dim + rope_dim

    # Simulate the split from the original q tensor. The split sizes come
    # from the same constants as orig_dim and the output width, so changing
    # rope_dim cannot desynchronize them. The nope half is not needed.
    q_orig = torch.randn(num_tokens, num_heads, orig_dim, dtype=dtype, device="cuda")
    _, q_pe = q_orig.split([qk_nope_head_dim, rope_dim], dim=-1)

    # q_pe is a strided view: the head stride is orig_dim, not rope_dim,
    # while the innermost dimension stays dense.
    assert not q_pe.is_contiguous()
    assert q_pe.stride(1) == orig_dim
    assert q_pe.stride(2) == 1

    # Simulate absorbed nope (contiguous, different size)
    ql_nope = torch.randn(num_tokens, num_heads, nope_dim, dtype=dtype, device="cuda")

    ref = torch.cat((ql_nope, q_pe), dim=-1)

    q_out = torch.empty(
        num_tokens, num_heads, nope_dim + rope_dim, dtype=dtype, device="cuda"
    )
    ops.concat_mla_q(ql_nope, q_pe, q_out)

    torch.testing.assert_close(q_out, ref, atol=0, rtol=0)


def test_concat_mla_q_zero_tokens():
    """Smoke-test an empty batch: the op must accept zero tokens.

    There is nothing to compare, so the test passes as long as the call
    does not raise.
    """
    num_heads, nope_dim, rope_dim = 128, 512, 64
    opts = {"dtype": torch.bfloat16, "device": "cuda"}

    nope = torch.empty(0, num_heads, nope_dim, **opts)
    rope = torch.empty(0, num_heads, rope_dim, **opts)
    out = torch.empty(0, num_heads, nope_dim + rope_dim, **opts)

    ops.concat_mla_q(nope, rope, out)


@pytest.mark.parametrize("num_tokens", [1, 64])
def test_concat_mla_q_values_preserved(num_tokens):
    """Verify exact bit-level preservation (no computation, pure copy).

    Compares raw int16 bits to avoid NaN != NaN issues from IEEE 754.

    The bit patterns are drawn uniformly from the full int16 range. An int16
    ``torch.arange`` cannot hold more than 2**16 distinct values, and a
    single token's nope slice already spans 128 * 512 == 2**16 elements. So
    an arange-based pattern cannot give different tokens different nope data.
    A kernel that reads or writes the wrong token would then go unnoticed.
    """
    torch.manual_seed(42)
    num_heads, nope_dim, rope_dim = 128, 512, 64
    i16 = torch.iinfo(torch.int16)

    def random_bits(last_dim):
        # ``high`` is exclusive, so every int16 bit pattern can occur,
        # including those that reinterpret as bf16 inf/NaN.
        return torch.randint(
            i16.min,
            i16.max + 1,
            (num_tokens, num_heads, last_dim),
            dtype=torch.int16,
            device="cuda",
        )

    # Stay in int16 for bit-exact comparison.
    ql_nope_bits = random_bits(nope_dim)
    q_pe_bits = random_bits(rope_dim)

    # Reinterpret (not convert) the same bits as bf16 for the kernel.
    ql_nope = ql_nope_bits.view(torch.bfloat16)
    q_pe = q_pe_bits.view(torch.bfloat16)

    q_out = torch.empty(
        num_tokens, num_heads, nope_dim + rope_dim, dtype=torch.bfloat16, device="cuda"
    )
    ops.concat_mla_q(ql_nope, q_pe, q_out)

    out_bits = q_out.view(torch.int16)

    assert torch.equal(out_bits[..., :nope_dim], ql_nope_bits)

    assert torch.equal(out_bits[..., nope_dim:], q_pe_bits)