// attention.cu
#include "zgemm.h"
#include "attention.cuh"

#ifndef M_LOG2E
#define M_LOG2E 1.4426950408889634074
#endif

namespace nunchaku::kernels {

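// Fused scaled dot-product attention over FP16 Q/K/V: per batch and head, o = softmax(q k^T * scale) v.
// Q/K/V use the packed [Batch, Head, Tokens, HEAD_DIM] layout the tiled kernel expects, while the
// output is written in the linear [Batch, TokensQ, Head * HEAD_DIM] layout and may be FP16 or BF16.
// `scale` is the softmax scaling factor, typically 1 / sqrt(HEAD_DIM).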
void attention_fp16(Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
                    Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
                    Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
                    Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
                    float scale) {
    int sizeBatch   = q.shape[0];
    int numHeads    = q.shape[1];
    int numTokensQ  = q.shape[2];
    int headDim     = q.shape[3];
    int numTokensKV = k.shape[2];

    assert(o.ndims() == 3);
    assert(o.shape[0] == sizeBatch);
    assert(o.shape[1] == numTokensQ);
    assert(o.shape[2] == numHeads * headDim);

    spdlog::trace("attention_fp16: B={} H={} NQ={} NK={}", sizeBatch, numHeads, numTokensQ, numTokensKV);
    spdlog::trace("q at {}", q.data_ptr());
    spdlog::trace("k at {}", k.data_ptr());
    spdlog::trace("v at {}", v.data_ptr());
    spdlog::trace("o at {}", o.data_ptr());
    spdlog::trace("scale={}", scale);

    dispatchBool(o.scalar_type() == Tensor::BF16, [&]<bool bf16out>() {
#ifndef __INTELLISENSE__
        using Attention = typename nunchaku::kernels::Attention<AttentionFP16Config<bf16out>>;
#else
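        // Give the IDE's IntelliSense parser a concrete instantiation; nvcc only compiles the branch above.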
        using Attention = typename nunchaku::kernels::Attention<AttentionFP16Config<true>>;
#endif
        using GEMM = typename Attention::GEMM;

        assert(isTypeMatch<typename Attention::half_t>(q.scalar_type()));
        assert(isTypeMatch<typename Attention::half_t>(k.scalar_type()));
        assert(isTypeMatch<typename Attention::half_t>(v.scalar_type()));
        assert(isTypeMatch<typename Attention::epilogue_half_t>(o.scalar_type()));

        int shmem = 0;

        // The kernel evaluates softmax with exp2 rather than exp; since exp(x) = exp2(x * log2(e)),
        // fold log2(e) into the scale once here instead of multiplying per element in the kernel.
        scale *= M_LOG2E;

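        // The kernel tiles the query sequence in BLOCK_M rows and steps over the KV sequence in
        // WARP_K chunks with a compile-time head dimension, so the shapes must divide evenly.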
        assert(numTokensQ % Attention::BLOCK_M == 0);
        assert(numTokensKV % Attention::WARP_K == 0);
        assert(headDim == Attention::HEAD_DIM);

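        // Launch helper: one thread block per (query tile, head, batch) cell; the Epilogue type
        // determines how the attention result is written back to global memory.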
        auto launch = [&]<typename Epilogue>(Epilogue::Arguments args) {
            dim3 grid(numTokensQ / Attention::BLOCK_M, numHeads, sizeBatch);
            using packed_q_t = typename Attention::packed_q_t;
            using packed_k_t = typename Attention::packed_k_t;
            using packed_v_t = typename Attention::packed_v_t;

            auto func = invoke_kernel<typename Attention::attention_fp16_kernel<Epilogue>,
                                      const packed_q_t *,
                                      const packed_k_t *,
                                      const packed_v_t *,
                                      float,
                                      int,
                                      int,
                                      typename Epilogue::Arguments,
                                      bool>;

            shmem = std::max(shmem, Attention::template attention_fp16_kernel<Epilogue>::SHMEM_SIZE);

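            // Requesting more dynamic shared memory than the default per-block limit (48 KB on
            // current architectures) requires an explicit opt-in; asking below the limit is harmless.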
            if (shmem >= 24 * 1024) {
                checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
            }

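            // One block of GEMM::WARP_SIZE * GEMM::NUM_WARPS threads per grid cell, with `shmem`
            // bytes of dynamic shared memory, on the caller's current CUDA stream.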
            func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(q.data_ptr<packed_q_t>(),
                                                                                             k.data_ptr<packed_k_t>(),
                                                                                             v.data_ptr<packed_v_t>(),
                                                                                             scale,
                                                                                             numTokensQ,
                                                                                             numTokensKV,
                                                                                             args,
                                                                                             false);
            checkCUDA(cudaGetLastError());
        };

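        // Default epilogue: write the result into `o`, viewed as a (Batch * TokensQ) x (Heads * HEAD_DIM) matrix.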
        launch.template operator()<typename GEMM::EpilogueDefault>(typename GEMM::EpilogueDefault::Arguments{
            .out     = o.data_ptr<typename GEMM::half_t>(),
            .actualM = sizeBatch * numTokensQ,
            .actualN = numHeads * headDim,
        });
    });
}

}; // namespace nunchaku::kernels