import pytest
import torch

from test.utils import supports_bfloat16

try:
    from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLTextConfig
    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding
    from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb

    IS_QWEN_AVAILABLE = True
except Exception:
    IS_QWEN_AVAILABLE = False

from liger_kernel.ops import LigerQwen2VLMRopeFunction
from liger_kernel.transformers.functional import liger_qwen2vl_mrope
from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
from liger_kernel.utils import infer_device

device = infer_device()


@pytest.mark.skipif(not IS_QWEN_AVAILABLE, reason="Qwen2-VL is not available in this version of transformers.")
@pytest.mark.parametrize("bsz", [1, 2])
@pytest.mark.parametrize("seq_len", [128, 131])
@pytest.mark.parametrize("num_q_heads, num_kv_heads", [(64, 8), (28, 4), (12, 2)])
@pytest.mark.parametrize(
    "head_dim, mrope_section",
    [
        (128, [16, 24, 24]),
        (96, [16, 16, 16]),
        (64, [8, 12, 12]),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-5, 1e-5),
        pytest.param(
            torch.bfloat16,
            1e-1,
            1e-5,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
        ),
    ],
)
def test_correctness(bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section, dtype, atol, rtol):
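    """Compare Liger's multimodal RoPE against the Hugging Face reference.

    Both the rotated q/k outputs and the gradients w.r.t. q/k must match
    `apply_multimodal_rotary_pos_emb` from transformers within the given tolerances.
    """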
    rotary_emb = Qwen2VLRotaryEmbedding(config=Qwen2VLTextConfig(head_dim=head_dim), device=device)

    _tensor_q = torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device).transpose(1, 2).to(dtype)

    _tensor_k = torch.randn((bsz, seq_len, num_kv_heads, head_dim), device=device).transpose(1, 2).to(dtype)

    q1 = _tensor_q.clone().requires_grad_(True)
    k1 = _tensor_k.clone().requires_grad_(True)

    q2 = _tensor_q.clone().requires_grad_(True)
    k2 = _tensor_k.clone().requires_grad_(True)

    # NOTE: the three leading rows of pos_ids stand in for the temporal/height/width position ids;
    # the distribution differs from the real one and is only meant to exercise the op for correctness
    pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view(3, bsz, seq_len)
    cos, sin = rotary_emb(k1, pos_ids)

    # validate forward pass
    hf_q, hf_k = apply_multimodal_rotary_pos_emb(q1, k1, cos, sin, mrope_section)
    tt_q, tt_k = liger_multimodal_rotary_pos_emb(q2, k2, cos, sin, mrope_section)
    torch.testing.assert_close(hf_q, tt_q, atol=atol, rtol=rtol)
    torch.testing.assert_close(hf_k, tt_k, atol=atol, rtol=rtol)

    # validate backward pass
    dq, dk = torch.randn_like(hf_q), torch.randn_like(hf_k)

    q1_grad, k1_grad = torch.autograd.grad((hf_q, hf_k), (q1, k1), (dq, dk), allow_unused=True)
    q2_grad, k2_grad = torch.autograd.grad((tt_q, tt_k), (q2, k2), (dq.clone(), dk.clone()), allow_unused=True)

    torch.testing.assert_close(q1_grad, q2_grad, atol=atol, rtol=rtol)
    torch.testing.assert_close(k1_grad, k2_grad, atol=atol, rtol=rtol)


@pytest.mark.skipif(not IS_QWEN_AVAILABLE, reason="Qwen2-VL is not available in this version of transformers.")
@pytest.mark.parametrize(
    "bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section",
    [
        (1, 2, 2, 2, 8, [2, 1, 1]),
        (1, 2, 1, 2, 8, [2, 1, 1]),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-5, 1e-5),
        (torch.bfloat16, 1e-1, 1e-5),
    ],
)
def test_functional_correctness(bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section, dtype, atol, rtol):
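    """Check that the functional API `liger_qwen2vl_mrope` and a direct
    `LigerQwen2VLMRopeFunction.apply` call produce identical outputs and gradients."""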
    _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype)
    _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype)

    q1 = _q.clone().requires_grad_(True)
    q2 = _q.clone().requires_grad_(True)

    k1 = _k.clone().requires_grad_(True)
    k2 = _k.clone().requires_grad_(True)

    rotary_emb = Qwen2VLRotaryEmbedding(config=Qwen2VLTextConfig(head_dim=head_dim), device=device)

    pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view(3, bsz, seq_len)
    cos, sin = rotary_emb(k1, pos_ids)

    # validate forward pass
    functional_q, functional_k = liger_qwen2vl_mrope(q1, k1, cos, sin, mrope_section)
    class_q, class_k = LigerQwen2VLMRopeFunction.apply(q2, k2, cos, sin, mrope_section)

    torch.testing.assert_close(functional_q, class_q, atol=atol, rtol=rtol)
    torch.testing.assert_close(functional_k, class_k, atol=atol, rtol=rtol)

    # validate backward pass
    dq, dk = torch.randn_like(functional_q), torch.randn_like(functional_k)

    dq1, dk1 = dq.clone(), dk.clone()
    dq2, dk2 = dq.clone(), dk.clone()

    q1_grad, k1_grad = torch.autograd.grad(
        (functional_q, functional_k),
        (q1, k1),
        (dq1, dk1),
        allow_unused=True,
    )

    q2_grad, k2_grad = torch.autograd.grad(
        (class_q, class_k),
        (q2, k2),
        (dq2, dk2),
        allow_unused=True,
    )

    torch.testing.assert_close(q1_grad, q2_grad, atol=atol, rtol=rtol)
    torch.testing.assert_close(k1_grad, k2_grad, atol=atol, rtol=rtol)