flash.h 2.46 KB
Newer Older
q.yao's avatar
q.yao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
// Modified from: https://github.com/Dao-AILab/flash-attention

#pragma once

#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

// Compile-time integer tags 0, 1 and 2.
// NOTE(review): the names suggest they select the "total", "head" and
// "head-dim" axes of a tensor layout -- confirm against the call sites, which
// are not visible in this header.
constexpr int TOTAL_DIM{0};
constexpr int H_DIM{1};
constexpr int D_DIM{2};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Plain parameter bundle describing the Q, K and V inputs: raw pointers,
// optional batched-pointer tables, strides and head counts.
//
// Every member carries a default initializer so that a default-constructed
// object (`Qkv_params p;`) starts from a fully defined all-zero / null state,
// the same state `Qkv_params p{};` produces.
struct Qkv_params {
    // Type used for all stride fields.
    // NOTE(review): this is a 32-bit unsigned type, so a stride larger than
    // 2^32 - 1 cannot be represented -- confirm tensor sizes stay below that.
    using index_t = uint32_t;

    // The QKV matrices.
    void* __restrict__ q_ptr = nullptr;
    void* __restrict__ k_ptr = nullptr;
    void* __restrict__ v_ptr = nullptr;

    // batched ptr inputs.
    // NOTE(review): presumably an alternative to k_ptr / v_ptr where each
    // entry points at one batch item's K / V, with the *_offset fields applied
    // on top -- confirm against the kernels that read them.
    void** __restrict__ k_batched_ptr = nullptr;
    void** __restrict__ v_batched_ptr = nullptr;
    int k_batched_offset              = 0;
    int v_batched_offset              = 0;

    // The stride between rows of the Q, K and V matrices
    // (plus the matching per-batch and per-head strides).
    index_t q_batch_stride = 0;
    index_t k_batch_stride = 0;
    index_t v_batch_stride = 0;
    index_t q_row_stride   = 0;
    index_t k_row_stride   = 0;
    index_t v_row_stride   = 0;
    index_t q_head_stride  = 0;
    index_t k_head_stride  = 0;
    index_t v_head_stride  = 0;

    // The number of heads.
    // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k (h_k)
    // could be different from nheads (query, h).
    int h   = 0;
    int h_k = 0;
    // Precomputed h / h_k. Must be filled in by the caller; the default of 0
    // only marks "not set yet".
    int h_h_k_ratio = 0;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Parameter bundle for the forward pass: extends Qkv_params with the output
// tensor, problem dimensions, softmax scaling factors and mode flags.
//
// Every member carries a default initializer so that a default-constructed
// object (`Flash_fwd_params p;`) starts from a fully defined all-zero / null /
// false state, the same state `Flash_fwd_params p{};` produces.
struct Flash_fwd_params: public Qkv_params {

    // The O matrix (output).
    void* __restrict__ o_ptr = nullptr;

    // The stride between rows of O
    // (plus the matching per-batch and per-head strides).
    index_t o_batch_stride = 0;
    index_t o_row_stride   = 0;
    index_t o_head_stride  = 0;

    // The pointer to the P matrix.
    void* __restrict__ p_ptr = nullptr;

    // The dimensions.
    // NOTE(review): presumably b = batch size, d = head dimension, and the
    // *_rounded values are the sequence lengths padded up to a tile multiple --
    // confirm against the code that fills them in.
    int b                = 0;
    int seqlen_q         = 0;
    int seqlen_k         = 0;
    int d                = 0;
    int seqlen_q_rounded = 0;
    int seqlen_k_rounded = 0;

    // The scaling factors for the kernel.
    float scale_softmax      = 0.0f;
    float scale_softmax_log2 = 0.0f;

    // array of length b+1 holding starting offset of each sequence.
    int* __restrict__ cu_seqlens_q = nullptr;
    int* __restrict__ cu_seqlens_k = nullptr;

    // NOTE(review): not documented in this header; meaning must be confirmed
    // against the kernels that read it.
    void* __restrict__ blockmask = nullptr;

    bool is_bf16   = false;
    bool is_causal = false;

    // enable output seqlen
    bool q_enable_seqlen = false;
    bool o_enable_seqlen = false;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Declaration only -- the definition is not part of this header.
//
// Template parameters:
//   T       - a type parameter (presumably the element type -- confirm).
//   Headdim - an integer parameter (presumably the head dimension -- confirm).
// Parameters:
//   params  - forward-pass parameter bundle, passed by non-const reference.
//   stream  - CUDA stream handle handed to the implementation.
template<class T, int Headdim>
void run_mha_fwd_(Flash_fwd_params& params,
                  cudaStream_t      stream);