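"""Triton RoPE (rotary position embedding) kernel for Llama 4.

Query/key head dimensions are treated as interleaved (real, imag) pairs and
rotated in place by a complex multiply with a precomputed ``freqs_cis`` table.
The backward pass reuses the forward kernel with the conjugate rotation
(``imag_sign=-1.0``) instead of materializing a conjugated frequency tensor.
"""
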
from typing import Optional

import torch
import triton
import triton.language as tl


def _cast_and_contiguous(q, k, freqs_complex):
    # Compute in q's dtype (fp32 stays fp32; bf16/fp16 stay low precision
    # for performance) and align k to it.
    if k.dtype != q.dtype:
        k = k.to(q.dtype)

    # .contiguous() is a no-op for already-contiguous tensors, so the
    # kernel's stores may update the caller's q and k in place.
    q = q.contiguous()
    k = k.contiguous()
    freqs_complex = freqs_complex.contiguous()
    return q, k, freqs_complex


@triton.jit
def _llama4_rope_kernel(
    q_ptr,
    k_ptr,
    freqs_complex_ptr,
    q_row_stride,
    k_row_stride,
    q_head_stride,
    k_head_stride,
    freqs_row_stride,
    seq_len,
    imag_sign,
    head_dim_half: tl.constexpr,
    n_q_heads: tl.constexpr,
    n_k_heads: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    H100-optimized RoPE kernel with improved parallelization across heads and dimensions.
    Grid: (batch*seq, head)
    """
    # 2D grid
    pid_bs = tl.program_id(0)  # over batch*seq
    pid_h = tl.program_id(1)  # over heads

    # Grid dim 0 is exactly batch_size * seq_len, so no bounds check is needed
    batch_idx = pid_bs // seq_len
    seq_idx = pid_bs % seq_len

    # Base pointers for this (batch, seq) position
    base_offset = batch_idx * seq_len + seq_idx
    q_base = q_ptr + base_offset * q_row_stride
    k_base = k_ptr + base_offset * k_row_stride
    freq_base = seq_idx * freqs_row_stride

    # Tiling over dim/2
    for d_start in tl.static_range(0, head_dim_half, BLOCK_SIZE):
        d_indices = d_start + tl.arange(0, BLOCK_SIZE)
        mask_d = d_indices < head_dim_half

        # Compute offsets for the block
        freq_offsets = d_indices[:, None] * 2 + tl.arange(0, 2)[None, :]
        # Load the block
        freqs_complex = tl.load(freqs_complex_ptr + freq_base + freq_offsets, mask=mask_d[:, None], other=0.0)
        freqs_real, freqs_imag = tl.split(freqs_complex)
        freqs_imag = freqs_imag * imag_sign

        # Each program rotates at most one query head, selected by pid_h
        if pid_h < n_q_heads:
            q_head_ptr = q_base + pid_h * q_head_stride
            q_real = tl.load(q_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
            q_imag = tl.load(q_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)

            # Complex multiply with FMAs: (a+ib)(c+id) = (ac - bd) + i(ad + bc)
            new_q_real = tl.math.fma(q_real, freqs_real, -(q_imag * freqs_imag))
            new_q_imag = tl.math.fma(q_real, freqs_imag, q_imag * freqs_real)

            tl.store(q_head_ptr + d_indices * 2, new_q_real, mask=mask_d)
            tl.store(q_head_ptr + d_indices * 2 + 1, new_q_imag, mask=mask_d)

        # Each program rotates at most one key head, selected by pid_h
        if pid_h < n_k_heads:
            k_head_ptr = k_base + pid_h * k_head_stride
            k_real = tl.load(k_head_ptr + d_indices * 2, mask=mask_d, other=0.0)
            k_imag = tl.load(k_head_ptr + d_indices * 2 + 1, mask=mask_d, other=0.0)

            new_k_real = tl.math.fma(k_real, freqs_real, -(k_imag * freqs_imag))
            new_k_imag = tl.math.fma(k_real, freqs_imag, k_imag * freqs_real)

            tl.store(k_head_ptr + d_indices * 2, new_k_real, mask=mask_d)
            tl.store(k_head_ptr + d_indices * 2 + 1, new_k_imag, mask=mask_d)


def _select_kernel_meta(head_dim_half: int):
    # Heuristic tuning for block size and num_warps
    if head_dim_half >= 256:
        return 128, 8
    if head_dim_half >= 96:
        return 128, 4
    if head_dim_half >= 48:
        return 64, 4
    if head_dim_half >= 24:
        return 32, 2
    return 16, 2


def llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE: Optional[int] = None, imag_sign: float = 1.0):
    # Save original dtype for casting back
    original_dtype = q.dtype

    batch_size, seq_len, n_q_heads, head_dim = q.shape
    _, _, n_k_heads, _ = k.shape
    assert head_dim % 2 == 0, "RoPE requires an even head_dim"
    head_dim_half = head_dim // 2
    if freqs_cis.is_complex():
        freqs_cis = freqs_cis.reshape(-1, freqs_cis.shape[-1])
        if freqs_cis.shape[0] > seq_len:
            freqs_cis = freqs_cis[:seq_len]
        freqs_cis = torch.view_as_real(freqs_cis)

    # Align k's dtype with q's and make tensors contiguous (copies only when needed)
    q, k, freqs_cis = _cast_and_contiguous(q, k, freqs_cis)

    # H100-optimized meta-params
    if BLOCK_SIZE is None:
        BLOCK_SIZE, num_warps = _select_kernel_meta(head_dim_half)
    else:
        # Provide a default num_warps if caller pins BLOCK_SIZE
        _, num_warps = _select_kernel_meta(head_dim_half)

    # 2D grid: one program per (batch, seq, head)
    n_heads_max = max(n_q_heads, n_k_heads)
    grid = (batch_size * seq_len, n_heads_max)

    # Launch kernel
    _llama4_rope_kernel[grid](
        q,
        k,
        freqs_cis,
        q.stride(1),
        k.stride(1),
        q.stride(2),
        k.stride(2),
        freqs_cis.stride(0),
        seq_len,
        imag_sign,
        head_dim_half,
        n_q_heads,
        n_k_heads,
        BLOCK_SIZE,
        num_warps=num_warps,
        num_stages=2,
    )

    # Cast back to the original dtype (a no-op today, since the compute dtype
    # is q's own dtype; kept as a safeguard)
    if q.dtype != original_dtype:
        q = q.to(original_dtype)
    if k.dtype != original_dtype:
        k = k.to(original_dtype)

    return q, k


class LigerLlama4RopeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, freqs_cis, BLOCK_SIZE: Optional[int] = None):
        q_out, k_out = llama4_rope_forward(q, k, freqs_cis, BLOCK_SIZE, imag_sign=1.0)
        # freqs_cis carries no gradient of its own; detach before saving
        ctx.save_for_backward(freqs_cis.detach())
        ctx.BLOCK_SIZE = BLOCK_SIZE
        return q_out, k_out

    @staticmethod
    def backward(ctx, dq, dk):
        (freqs_cis,) = ctx.saved_tensors
        BLOCK_SIZE = getattr(ctx, "BLOCK_SIZE", None)
        # Use imag_sign=-1.0 for conjugate without materializing a new tensor
        dq_out, dk_out = llama4_rope_forward(dq, dk, freqs_cis, BLOCK_SIZE, imag_sign=-1.0)
        return dq_out, dk_out, None
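

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch comparing the Triton kernel against a plain
# PyTorch reference that applies the same interleaved (real, imag) rotation.
# The shapes, the 10000.0 frequency base, and the CUDA-device assumption are
# illustrative choices for this check, not requirements of the kernel API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    B, S, H_Q, H_K, D = 2, 16, 8, 4, 64
    q = torch.randn(B, S, H_Q, D, device="cuda")
    k = torch.randn(B, S, H_K, D, device="cuda")

    # Standard RoPE frequency table as a complex tensor of shape (S, D // 2)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2, device="cuda").float() / D))
    angles = torch.outer(torch.arange(S, device="cuda").float(), inv_freq)
    freqs_cis = torch.polar(torch.ones_like(angles), angles)

    def reference_rope(x, fc):
        # View the last dim as D // 2 complex pairs, rotate, and flatten back
        xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        return torch.view_as_real(xc * fc[None, :, None, :]).flatten(-2)

    q_ref, k_ref = reference_rope(q, freqs_cis), reference_rope(k, freqs_cis)
    # Pass clones: the kernel writes its results into its inputs in place
    q_out, k_out = LigerLlama4RopeFunction.apply(q.clone(), k.clone(), freqs_cis)
    torch.testing.assert_close(q_out, q_ref, rtol=1e-4, atol=1e-4)
    torch.testing.assert_close(k_out, k_ref, rtol=1e-4, atol=1e-4)
    print("llama4_rope smoke test passed")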