/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
// modify from: https://github.com/Dao-AILab/flash-attention

#pragma once

#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>

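// Dimension indices (total tokens, heads, head dim) used to index QKV tensor shapes.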
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM     = 1;
constexpr int D_DIM     = 2;

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
    using index_t = size_t;
    // The QKV matrices.
    void* __restrict__ q_ptr;
    void* __restrict__ k_ptr;
    void* __restrict__ v_ptr;

    // Optional batched K/V inputs: one base pointer per batch entry, plus an offset applied to each.
    void** __restrict__ k_batched_ptr = nullptr;
    void** __restrict__ v_batched_ptr = nullptr;
    size_t k_batched_offset           = 0;
    size_t v_batched_offset           = 0;

    // The stride between rows of the Q, K and V matrices.
    index_t q_batch_stride;
    index_t k_batch_stride;
    index_t v_batch_stride;
    index_t q_row_stride;
    index_t k_row_stride;
    index_t v_row_stride;
    index_t q_head_stride;
    index_t k_head_stride;
    index_t v_head_stride;

    // The number of heads.
    int h, h_k;
    // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
    // different from nheads (query).
    int h_h_k_ratio;  // precompute h / h_k
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_fwd_params: public Qkv_params {

    // The O matrix (output).
    void* __restrict__ o_ptr;

    // The stride between rows of O.
    index_t o_batch_stride;
    index_t o_row_stride;
    index_t o_head_stride;

    // The pointer to the P matrix (softmax probabilities).
    void* __restrict__ p_ptr;

    // The dimensions: batch size, Q/K sequence lengths, head dimension, and their rounded counterparts used by the kernel.
    int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded;

    // The scaling factors for the kernel: the softmax scale, and the same scale pre-multiplied by log2(e).
    float scale_softmax;
    float scale_softmax_log2;

    // array of length b+1 holding starting offset of each sequence.
    int* __restrict__ cu_seqlens_q;
    int* __restrict__ cu_seqlens_k;

    // array of length b with actual length of each sequence
    int* __restrict__ actual_seqlen_q;
    int* __restrict__ actual_seqlen_k;

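    // Optional block mask (block-sparse attention).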
    void* __restrict__ blockmask;

    bool is_bf16;
    bool is_causal;

    // Enable seqlen-aware layouts for Q (input) and O (output).
    bool q_enable_seqlen;
    bool o_enable_seqlen;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

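// Launch the flash-attention forward kernel for element type T and head dimension Headdim.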
template<typename T, int Headdim>
void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream);
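
////////////////////////////////////////////////////////////////////////////////////////////////////

// Usage sketch (illustrative only, not part of this header's API): how a caller might fill
// Flash_fwd_params for a dense, fixed-length batch and dispatch the forward kernel. The element
// type (half), head dimension (128) and the packed [b, seqlen, h, d] row-major layout assumed
// below are examples; the actual instantiations and layouts live in the accompanying .cu files.
//
//   Flash_fwd_params params{};
//   params.q_ptr = q;  params.k_ptr = k;  params.v_ptr = v;  params.o_ptr = o;
//   params.b = batch;  params.h = num_heads;  params.h_k = num_kv_heads;
//   params.h_h_k_ratio = params.h / params.h_k;
//   params.seqlen_q = seqlen_q;  params.seqlen_k = seqlen_k;  params.d = 128;
//   // Strides are in elements for a packed [b, seqlen, h, d] layout.
//   params.q_row_stride   = params.h * params.d;
//   params.o_row_stride   = params.h * params.d;
//   params.k_row_stride   = params.h_k * params.d;
//   params.v_row_stride   = params.h_k * params.d;
//   params.q_head_stride  = params.d;
//   params.k_head_stride  = params.d;
//   params.v_head_stride  = params.d;
//   params.o_head_stride  = params.d;
//   params.q_batch_stride = params.seqlen_q * params.q_row_stride;
//   params.o_batch_stride = params.seqlen_q * params.o_row_stride;
//   params.k_batch_stride = params.seqlen_k * params.k_row_stride;
//   params.v_batch_stride = params.seqlen_k * params.v_row_stride;
//   params.scale_softmax      = 1.f / sqrtf(float(params.d));
//   params.scale_softmax_log2 = params.scale_softmax * float(M_LOG2E);
//   params.is_bf16   = false;
//   params.is_causal = true;
//   // (seqlen_*_rounded, cu_seqlens_*, actual_seqlen_* and the *_enable_seqlen flags are set as the kernels require.)
//   run_mha_fwd_<half, 128>(params, /*stream=*/nullptr);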