test_chunk_gated_vllm.py 3.71 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# SPDX-License-Identifier: MIT

import pytest
import torch

from aiter.ops.triton.fla.vllm.chunk_delta_h import (
    chunk_gated_delta_rule_fwd_h,
    prepare_chunk_indices,
)
from aiter.ops.triton.fla.vllm.chunk_o import chunk_fwd_o
from op_tests.triton_tests.utils.chunk_delta_h_vllm_ref import (
    chunk_gated_delta_rule_fwd_h as chunk_gated_delta_rule_fwd_h_ref,
)
from op_tests.triton_tests.utils.chunk_o_vllm_ref import (
    chunk_fwd_o as chunk_fwd_o_ref,
)


@pytest.mark.parametrize("use_varlen", [False, True])
@pytest.mark.parametrize("state_index_mode", ["none", "identity", "random"])
def test_chunk_gated_fwd_h_and_o_against_ref(use_varlen: bool, state_index_mode: str):
    """Compare the aiter chunked gated-delta forward kernels with their references.

    Runs ``chunk_gated_delta_rule_fwd_h`` followed by ``chunk_fwd_o`` from
    ``aiter.ops.triton.fla.vllm`` and the reference versions from
    ``op_tests.triton_tests.utils`` on the same seeded random inputs, then
    requires the hidden states ``h`` and the outputs ``o`` to match exactly
    (``torch.equal``).

    Args:
        use_varlen: When true, ``cu_seqlens = [0, 37, 93, total_t]`` and the
            chunk indices derived from it are passed to the kernels;
            otherwise both are ``None``.
        state_index_mode: How ``initial_state_indices`` is supplied.
            ``"none"``: the kernel under test receives ``None`` while the
            reference receives ``arange(n_seq)``. ``"random"``: both receive
            the first ``n_seq`` entries of a random permutation over
            ``n_seq + 3`` state rows. Any other value (``"identity"``): both
            receive ``arange(n_seq)``.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is required")

    torch.manual_seed(42)
    device = torch.device("cuda")

    # Problem sizes. total_t is deliberately not a multiple of chunk_size
    # (130 vs 64) — presumably to cover a trailing partial chunk.
    bsz, total_t = 1, 130
    h_dim, hg = 16, 8
    k_dim = v_dim = 128
    chunk_size = 64

    def scaled_randn(shape, dtype, factor):
        """Draw a standard-normal tensor on ``device`` and scale it."""
        return torch.randn(shape, device=device, dtype=dtype) * factor

    # The draw order below (q, k, w, u, g, then the state tensors) determines
    # the values produced under the fixed seed, so it must not be reordered.
    q = scaled_randn((bsz, total_t, hg, k_dim), torch.float16, 0.2)
    k = scaled_randn((bsz, total_t, hg, k_dim), torch.float16, 0.2)
    w = scaled_randn((bsz, total_t, h_dim, k_dim), torch.float16, 0.2)
    u = scaled_randn((bsz, total_t, h_dim, v_dim), torch.float16, 0.2)
    g = scaled_randn((bsz, total_t, h_dim), torch.float32, 0.05)

    cu_seqlens = (
        torch.tensor([0, 37, 93, total_t], device=device, dtype=torch.long)
        if use_varlen
        else None
    )
    n_seq = bsz if cu_seqlens is None else len(cu_seqlens) - 1

    if state_index_mode == "random":
        # More state rows than sequences, addressed through a shuffled subset.
        state_rows = n_seq + 3
        shuffled = torch.randperm(state_rows, device=device, dtype=torch.int64)
        indices_ref = shuffled[:n_seq].to(torch.int32)
        indices_cur = indices_ref
    else:
        state_rows = n_seq
        indices_ref = torch.arange(n_seq, device=device, dtype=torch.int32)
        # In "none" mode only the kernel under test is given no indices; the
        # reference is still called with the explicit identity mapping.
        indices_cur = None if state_index_mode == "none" else indices_ref

    initial_state_cur = scaled_randn(
        (state_rows, h_dim, v_dim, k_dim), torch.float32, 0.02
    )
    # Separate copy so the two implementations never share a state buffer.
    initial_state_ref = initial_state_cur.clone()

    chunk_indices = (
        None if cu_seqlens is None else prepare_chunk_indices(cu_seqlens, chunk_size)
    )

    # Arguments common to both hidden-state passes.
    fwd_h_shared = dict(
        k=k,
        w=w,
        u=u,
        g=g,
        save_new_value=True,
        cu_seqlens=cu_seqlens,
    )

    h_cur, v_new_cur, _final_state = chunk_gated_delta_rule_fwd_h(
        initial_state=initial_state_cur,
        initial_state_indices=indices_cur,
        output_final_state=True,
        chunk_size=chunk_size,
        chunk_indices=chunk_indices,
        use_exp2=False,
        transpose_state_layout=True,
        **fwd_h_shared,
    )

    h_ref, v_new_ref = chunk_gated_delta_rule_fwd_h_ref(
        gk=None,
        initial_state=initial_state_ref,
        initial_state_indices=indices_ref,
        **fwd_h_shared,
    )

    # Arguments common to both output passes; each pass consumes the h and
    # v_new produced by its own hidden-state pass.
    fwd_o_shared = dict(
        q=q,
        k=k,
        g=g,
        scale=k_dim ** -0.5,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )

    o_cur = chunk_fwd_o(
        v=v_new_cur,
        h=h_cur,
        chunk_indices=chunk_indices,
        use_exp2=False,
        transpose_state_layout=True,
        **fwd_o_shared,
    )

    o_ref = chunk_fwd_o_ref(
        v=v_new_ref,
        h=h_ref,
        **fwd_o_shared,
    )

    # Exact match is required, not a tolerance-based comparison.
    assert torch.equal(h_cur, h_ref)
    assert torch.equal(o_cur, o_ref)