# ruff: noqa
import itertools
import tilelang
from tilelang import language as T
import torch
from utils import generate_random_cu_seqlens, per_custom_dims_cast_to_fp8


def display_error_message(msg):
    print(f"\033[31mWARNING: {msg}\033[0m")


def compute_correlation(a, b, label="tensor"):
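    """Global similarity score 2 * <a, b> / (|a|^2 + |b|^2); equals 1 iff a == b elementwise."""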
    a, b = a.data.double(), b.data.double()
    norm_sum = (a * a + b * b).sum()
    if norm_sum == 0:
        display_error_message(f"{label} all zero")
        return 1
    correlation = 2 * (a * b).sum() / norm_sum
    return correlation


def validate_tensor_match(a, b, tolerance=1e-8, tensor_name="tensor", should_raise=True):
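    """Check that `a` and `b` share the same non-finite pattern and values, and that their
    correlation-based relative difference on the finite entries is within `tolerance`."""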
    a_finite = torch.isfinite(a)
    b_finite = torch.isfinite(b)
    if not torch.all(a_finite == b_finite):
        display_error_message(f"{tensor_name} Error: isfinite mask mismatch")
        if should_raise:
            assert False
    if not torch.isclose(
        a.masked_fill(a_finite, 0),
        b.masked_fill(b_finite, 0),
        rtol=0,
        atol=0,
        equal_nan=True,
    ).all():
        display_error_message(f"{tensor_name} Error: nonfinite value mismatch")
        if should_raise:
            assert False
    a = a.masked_fill(~a_finite, 0)
    b = b.masked_fill(~b_finite, 0)
    correlation = compute_correlation(a, b, tensor_name)
    difference = 1.0 - correlation
    if not (0 <= difference <= tolerance):
        display_error_message(f"{tensor_name} Error: {difference}")
        if should_raise:
            assert False
    return difference


def get_configs():
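    """Cartesian product of tile, pipeline, and thread settings, usable as an autotuning config grid."""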
    iter_params = dict(
        block_N=[32, 64, 128],
        num_stages=[0, 1, 2],
        threads=[128, 256],
        block_Q=[1, 2, 4],
    )
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


class SupplyProg:
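    """Tensor supplier that caches one random CUDA tensor per (shape, dtype) key, so
    repeated profiling or autotuning calls see the same inputs instead of fresh allocations."""
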
    def __init__(self):
        self.tensors_dict = {}

    def get_key(self, shape, dtype) -> str:
        return f"{shape}-{dtype}"

    def supply_prog(self, params):
        shapes = [p.shape for p in params]
        dtypes = [p.dtype for p in params]
        tensor_list = []
        for shape, dtype in zip(shapes, dtypes):
            key = self.get_key(shape, dtype)
            if key not in self.tensors_dict:
                self.tensors_dict[key] = torch.randn(shape, dtype=dtype, device="cuda")
            tensor_list.append(self.tensors_dict[key])
        return tensor_list


supply_prog = SupplyProg()


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def mqa_attn_return_logits(
    heads,
    index_dim,
    block_N=256,
    num_stages=3,
    threads=512,
    block_Q=None,
):
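    """Build the FP8 MQA indexer-logits kernel.

    For each query token it computes relu(q_h . k) per head, scales by the per-head
    weight and the per-key FP8 dequant scale, and sums over heads into
    Logits[seq_len, seq_len_kv]. Only key blocks inside the union of the queries'
    [CuSeqLenKS, CuSeqLenKE) windows are visited; untouched positions are masked
    separately (see clean_logits_).
    """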
    if block_Q is None:
        block_Q = 128 // heads
    dtype = "float8_e4m3"
    accum_dtype = "float"
    index_dtype = "int32"

    seq_len = T.dynamic("seq_len")
    seq_len_kv = T.dynamic("seq_len_kv")

    index_q_shape = [seq_len * heads, index_dim]
    index_k_shape = [seq_len_kv, index_dim]
    index_k_scale_shape = [seq_len_kv]
    logits_shape = [seq_len, seq_len_kv]

    @T.prim_func
    def mqa_attn_return_logits_kernel(
        IndexQ: T.Tensor(index_q_shape, dtype),  # type: ignore
        IndexK: T.Tensor(index_k_shape, dtype),  # type: ignore
        IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype),  # type: ignore
        Logits: T.Tensor(logits_shape, accum_dtype),  # type: ignore
        Weights: T.Tensor([seq_len, heads], accum_dtype),  # type: ignore
        CuSeqLenKS: T.Tensor([seq_len], index_dtype),  # type: ignore
        CuSeqLenKE: T.Tensor([seq_len], index_dtype),  # type: ignore
    ):
        with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx:
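            # One block per group of block_Q query tokens; the Q tiles for all heads of
            # those tokens are staged together in shared memory.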
            index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype)
            index_k_shared = T.alloc_shared([block_N, index_dim], dtype)
            index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype)
            s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype)
            s_reshaped = T.reshape(s, (block_N, block_Q, heads))
            logits = T.alloc_fragment([block_N, block_Q], accum_dtype)
            weights = T.alloc_fragment([block_Q, heads], accum_dtype)

            seq_len_i = bx * block_Q

            cu_k_s_min = T.alloc_local([1], index_dtype)
            cu_k_e_max = T.alloc_local([1], index_dtype)

            # Take the union of the [start, end) key windows of all queries in this
            # block: min over starts, max over ends (both clamped to seq_len_kv).
            cu_k_s_min[0] = 2147483647
            cu_k_e_max[0] = -2147483648

            for bq_i in T.serial(block_Q):
                cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv))
            for bq_i in T.serial(block_Q):
                cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv))

            T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared)
            T.copy(Weights[seq_len_i, 0], weights)

            for nbn_i in T.Pipelined(T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages):
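                # Stream one block_N tile of FP8 keys and their dequant scales.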
                T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared)
                T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment)

                T.gemm(
                    index_k_shared,
                    index_q_shared,
                    s,
                    transpose_B=True,
                    clear_accum=True,
                    policy=T.GemmWarpPolicy.FullCol,
                )

                # relu, then apply the per-head weight and the per-key dequant scale.
                for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
                    s_reshaped[bn_i, bq_i, h_i] = (
                        T.max(s_reshaped[bn_i, bq_i, h_i], 0) * weights[bq_i, h_i]
                    ) * index_k_scale_fragment[bn_i]

                # Sum over heads and write this tile of logits back to global memory.
                T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)

                for bq_i, bn_i in T.Parallel(block_Q, block_N):
                    Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = logits[bn_i, bq_i]

    return mqa_attn_return_logits_kernel


@tilelang.jit
def clean_logits_(
    threads: int = 512,
    block_K: int = 4096,
):
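    """Build a kernel that sets Logits[i, j] to -inf for every j outside
    [CuSeqLenKS[i], CuSeqLenKE[i])."""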
    seq_len = T.dynamic("seq_len")
    seq_len_kv = T.dynamic("seq_len_kv")

    dtype = "float"
    indices_dtype = "int32"

    @T.prim_func
    def clean_logits_kernel(
        Logits: T.Tensor([seq_len, seq_len_kv], dtype),  # type: ignore
        CuSeqLenKS: T.Tensor([seq_len], indices_dtype),  # type: ignore
        CuSeqLenKE: T.Tensor([seq_len], indices_dtype),  # type: ignore
    ):
        with T.Kernel(seq_len, threads=threads) as bx:
            tx = T.thread_binding(0, threads, thread="threadIdx.x")
            cu_k_s = T.alloc_local([1], indices_dtype)
            cu_k_e = T.alloc_local([1], indices_dtype)
            cu_k_s[0] = CuSeqLenKS[bx]
            cu_k_e[0] = CuSeqLenKE[bx]

            for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)):
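                # Threads cooperatively sweep a block_K-wide slice of row bx,
                # masking everything outside the valid [ks, ke) window.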
                for k_i in T.serial(block_K // threads):
                    idx = n_i * block_K + k_i * threads + tx
                    if idx < cu_k_s[0] or idx >= cu_k_e[0]:
                        Logits[bx, idx] = -T.infinity(dtype)

    return clean_logits_kernel


def mqa_attn_return_logits_interface(q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True):
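    """Host-side wrapper: JIT-compiles the kernels, launches the FP8 logits kernel,
    and optionally masks out-of-window positions with clean_logits_."""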
    seq_len, heads, index_dim = q.shape
    seq_len_kv = kv.shape[0]

    clean_logits_kernel = clean_logits_()

    mqa_attn_return_logits_kernel = mqa_attn_return_logits(heads=heads, index_dim=index_dim)
    logits = torch.empty([seq_len, seq_len_kv], device=q.device, dtype=torch.float32)
    mqa_attn_return_logits_kernel(
        q.view(seq_len * heads, index_dim),
        kv,
        kv_scales,
        logits,
        weights,
        cu_seqlen_ks,
        cu_seqlen_ke,
    )
    if clean_logits:
        clean_logits_kernel(logits, cu_seqlen_ks, cu_seqlen_ke)
    return logits


def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor):
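    """fp32 PyTorch reference: logits[m, n] = sum_h weights[m, h] * relu(q[m, h] . k[n]),
    masked to -inf outside [ks[m], ke[m]). Also returns the number of valid (m, n) pairs."""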
    k = kv
    q = q.float()
    k = k.float()

    seq_len_kv = kv.shape[0]
    mask_lo = torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
    mask_hi = torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
    mask = mask_lo & mask_hi

    score = torch.einsum("mhd,nd->hmn", q, k)
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    cost = mask.sum()
    return logits, cost


def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1):
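    """Compare the TileLang FP8 indexer against the fp32 reference, then profile and benchmark it."""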
    # Fix the random seed so the generated inputs and reported numbers are reproducible.
    torch.manual_seed(0)
    q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16)
    kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16)
    weights = torch.randn(S, H, device="cuda", dtype=torch.float32)
    p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1)

    ks, ke = generate_random_cu_seqlens(per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048)

    logits_ref, cost_ref = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)

    # Quantize the inputs: q is cast directly to e4m3, kv gets one fp8 scale per key token.
    q_fp8 = q.to(torch.float8_e4m3fn)
    kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False)

    logits_tl = mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
    diff = validate_tensor_match(logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False)

    print(f"diff: {diff}")

    from tilelang.profiler import do_bench

    def logits_fn():
        return mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)

    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
        logits_fn()

    print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50))

    logits_ms = do_bench(logits_fn, warmup=100, rep=100)
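    # Each valid (q, k) pair costs H * D multiply-adds, i.e. 2 * H * D FLOPs.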
    logits_flops = 2 * cost_ref * H * D
    logits_tflops = logits_flops / (logits_ms * 1e-3) / 1e12
    print(f"logits_tflops: {logits_tflops}, logits_ms: {logits_ms}")
    print(f"cost_ref: {cost_ref}")


if __name__ == "__main__":
    test_fp8_lighting_indexer()