block_info.h 6.46 KB
Newer Older
zhangshao's avatar
zhangshao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Per-batch sequence bookkeeping for one batch entry: where its Q and K rows start, how long
// they are, and helpers that turn that into element offsets.
//
// Template parameters:
//   Varlen          - when false, the cu_seqlens_* arrays are never read; lengths come from
//                     params.seqlen_q / params.seqlen_k and both sums are the sentinel -1.
//   Is_Kvcache      - when true, actual_seqlen_q is always params.seqlen_q.
//   USE_BSHD_LAYOUT - selects how the *_offset1 / *_offset2 helpers scale their result
//                     (presumably a batch/seq/head/dim memory layout - confirm with callers).
template<bool Varlen=true, bool Is_Kvcache=false, bool USE_BSHD_LAYOUT = false>
struct BlockInfo {

    // Loads the bookkeeping for batch entry `bidb` from `params`.
    //
    // The initializers below are written in member-declaration order, which is the order in
    // which they execute no matter how the list is spelled.
    template<typename Params>
    __device__ BlockInfo(const Params &params, const int bidb)
        // Cumulative start of this batch's Q rows, or -1 when Q is not addressed that way.
        : sum_s_q((Varlen && params.cu_seqlens_q != nullptr) ? params.cu_seqlens_q[bidb] : -1)
        // Cumulative start of this batch's K rows; -1 as well when cu_seqlens_k holds plain
        // lengths instead of a prefix sum.
        , sum_s_k((Varlen && params.cu_seqlens_k != nullptr && params.is_seqlens_k_cumulative)
                      ? params.cu_seqlens_k[bidb]
                      : -1)
        , actual_seqlen_q((Varlen && params.cu_seqlens_q != nullptr && !Is_Kvcache)
                              ? params.cu_seqlens_q[bidb + 1] - sum_s_q
                              : params.seqlen_q)
        , leftpad_k(params.leftpad_k != nullptr ? params.leftpad_k[bidb] : 0)
        // With a cumulative cu_seqlens_k the length is the difference of two neighbouring
        // entries; otherwise cu_seqlens_k[bidb] already is the length of K.
        , seqlen_k_cache((Varlen && params.cu_seqlens_k != nullptr)
                             ? (params.is_seqlens_k_cumulative
                                    ? params.cu_seqlens_k[bidb + 1] - sum_s_k
                                    : params.cu_seqlens_k[bidb])
                             : params.seqlen_k)
        // No extra length for newly appended K (params.seqlen_knew) is added here.
        , actual_seqlen_k(seqlen_k_cache)
        , nheads(params.h)
        , nheads_k(params.h_k)
    {
    }

    // Offset of this batch's first K row, including the left padding.
    // Uses the cumulative start when there is one, otherwise bidb * batch_stride.
    template <typename index_t>
    __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
        return sum_s_k != -1 ? uint32_t(sum_s_k + leftpad_k) * row_stride
                             : bidb * batch_stride + leftpad_k * row_stride;
    }

    // Batch-level Q offset: bidb * batch_stride without a cumulative start; otherwise
    // sum_s_q * row_stride, additionally scaled by nheads unless USE_BSHD_LAYOUT.
    inline __device__  int q_offset1(const int batch_stride, const int row_stride, const int bidb) const {
        if (sum_s_q == -1) {
            return bidb * batch_stride;
        }
        const uint32_t packed_rows = uint32_t(sum_s_q) * row_stride;
        return USE_BSHD_LAYOUT ? packed_rows : packed_rows * nheads;
    }

    // Batch-level K offset; same scheme as q_offset1 but with sum_s_k and nheads_k.
    // leftpad_k is not applied here.
    inline __device__  int k_offset1(const int batch_stride, const int row_stride, const int bidb) const {
        if (sum_s_k == -1) {
            return bidb * batch_stride;
        }
        const uint32_t packed_rows = uint32_t(sum_s_k) * row_stride;
        return USE_BSHD_LAYOUT ? packed_rows : packed_rows * nheads_k;
    }

    // Like k_offset1, except that the non-BSHD branch scales by nheads instead of nheads_k.
    // NOTE(review): the intended destination buffer is not visible here - confirm with callers.
    inline __device__  int k_offset1_write(const int batch_stride, const int row_stride, const int bidb) const {
        if (sum_s_k == -1) {
            return bidb * batch_stride;
        }
        const uint32_t packed_rows = uint32_t(sum_s_k) * row_stride;
        return USE_BSHD_LAYOUT ? packed_rows : packed_rows * nheads;
    }

    // Head-level Q offset: bidh * head_stride for BSHD or when there is no cumulative start,
    // otherwise additionally scaled by this batch's Q length.
    inline __device__  int q_offset2(const int head_stride, const int bidh) const {
        if (USE_BSHD_LAYOUT || sum_s_q == -1) {
            return bidh * head_stride;
        }
        return uint32_t(actual_seqlen_q) * head_stride * bidh;
    }

    // Head-level K offset; same scheme as q_offset2 but with sum_s_k and actual_seqlen_k.
    inline __device__  int k_offset2(const int head_stride, const int bidh) const {
        if (USE_BSHD_LAYOUT || sum_s_k == -1) {
            return bidh * head_stride;
        }
        return uint32_t(actual_seqlen_k) * head_stride * bidh;
    }

    // Declaration order is initialization order: actual_seqlen_q reads sum_s_q,
    // seqlen_k_cache reads sum_s_k, and actual_seqlen_k reads seqlen_k_cache, so each of
    // those must stay below the member it depends on.
    const int sum_s_q;
    const int sum_s_k;
    const int actual_seqlen_q;
    const int leftpad_k;
    const int seqlen_k_cache;
    int actual_seqlen_k;
    const int nheads;
    const int nheads_k;

};


// Simplified BlockInfo for traditional varlen forward inference.
// Reduced variant of the per-batch bookkeeping that always reads both cu_seqlens arrays:
// there is no null check, no Varlen / Is_Kvcache switch and no left padding.
//
// USE_BSHD_LAYOUT selects how the *_offset1 / *_offset2 helpers scale their result
// (presumably a batch/seq/head/dim memory layout - confirm with callers).
template<bool USE_BSHD_LAYOUT=false>
struct SimplifyBlockInfo {

    // Loads the bookkeeping for batch entry `bidb`. Both params.cu_seqlens_q and
    // params.cu_seqlens_k are dereferenced unconditionally, so they must be valid.
    //
    // NOTE(review): sum_s_k is taken from cu_seqlens_k[bidb] even when
    // params.is_seqlens_k_cumulative is false, in which case that entry is a length rather
    // than a start offset - confirm this struct is only used with cumulative K lengths.
    template<typename Params>
    __device__ SimplifyBlockInfo(const Params &params, const int bidb)
        : sum_s_q(params.cu_seqlens_q[bidb])
        , sum_s_k(params.cu_seqlens_k[bidb])
        , actual_seqlen_q(params.cu_seqlens_q[bidb + 1] - sum_s_q)
        // Cumulative cu_seqlens_k: length is the difference of neighbouring entries.
        // Otherwise cu_seqlens_k[bidb] already stores the length of K.
        , seqlen_k_cache(params.is_seqlens_k_cumulative
                             ? params.cu_seqlens_k[bidb + 1] - sum_s_k
                             : params.cu_seqlens_k[bidb])
        // No extra length for newly appended K (params.seqlen_knew) is added here.
        , actual_seqlen_k(seqlen_k_cache)
        , nheads(params.h)
        , nheads_k(params.h_k)
    {
    }

    // Batch-level Q offset: bidb * batch_stride if sum_s_q is -1; otherwise
    // sum_s_q * row_stride, additionally scaled by nheads unless USE_BSHD_LAYOUT.
    inline __device__  int q_offset1(const int batch_stride, const int row_stride, const int bidb) const {
        if (sum_s_q == -1) {
            return bidb * batch_stride;
        }
        const uint32_t packed_rows = uint32_t(sum_s_q) * row_stride;
        return USE_BSHD_LAYOUT ? packed_rows : packed_rows * nheads;
    }

    // Batch-level K offset; same scheme as q_offset1 but with sum_s_k and nheads_k.
    inline __device__  int k_offset1(const int batch_stride, const int row_stride, const int bidb) const {
        if (sum_s_k == -1) {
            return bidb * batch_stride;
        }
        const uint32_t packed_rows = uint32_t(sum_s_k) * row_stride;
        return USE_BSHD_LAYOUT ? packed_rows : packed_rows * nheads_k;
    }

    // Head-level Q offset: bidh * head_stride for BSHD or when sum_s_q is -1, otherwise
    // additionally scaled by this batch's Q length.
    inline __device__  int q_offset2(const int head_stride, const int bidh) const {
        if (USE_BSHD_LAYOUT || sum_s_q == -1) {
            return bidh * head_stride;
        }
        return uint32_t(actual_seqlen_q) * head_stride * bidh;
    }

    // Head-level K offset; same scheme as q_offset2 but with sum_s_k and actual_seqlen_k.
    inline __device__  int k_offset2(const int head_stride, const int bidh) const {
        if (USE_BSHD_LAYOUT || sum_s_k == -1) {
            return bidh * head_stride;
        }
        return uint32_t(actual_seqlen_k) * head_stride * bidh;
    }

    // Declaration order is initialization order: actual_seqlen_q reads sum_s_q,
    // seqlen_k_cache reads sum_s_k, and actual_seqlen_k reads seqlen_k_cache.
    const int sum_s_q;
    const int sum_s_k;
    const int actual_seqlen_q;
    // (no leftpad_k member in this variant)
    const int seqlen_k_cache;
    int actual_seqlen_k;
    const int nheads;
    const int nheads_k;

};


////////////////////////////////////////////////////////////////////////////////////////////////////

// Per-batch sequence bookkeeping that is filled in by set_params() instead of a constructor.
//
// The default constructor is trivial and leaves every field indeterminate; call set_params()
// before reading anything. After set_params() all four fields hold defined values.
struct SafeDecodeBlockInfo {

    __device__ SafeDecodeBlockInfo() = default;

    // Loads the bookkeeping for batch entry `bidb` from `params`.
    //
    // Template parameters:
    //   Is_Q_varlen     - Q lengths come from the cumulative array params.cu_seqlens_q;
    //                     otherwise every batch entry uses params.seqlen_q.
    //   Is_K_Cumulative - params.cu_seqlens_k is a prefix sum; otherwise cu_seqlens_k[bidb]
    //                     is read directly as the length of K.
    //
    // When a tensor has no cumulative start, the matching sum_s_* field is set to -1 (the
    // same sentinel BlockInfo uses) rather than being left indeterminate, so that a later
    // read of it is well defined.
    template<typename Params, bool Is_Q_varlen, bool Is_K_Cumulative>
    __device__ void set_params(const Params &params, const int bidb) {
        // process Q
        if constexpr (Is_Q_varlen) { // Is_Q_varlen also means Is_Q_Cumulative = true
            this->sum_s_q = params.cu_seqlens_q[bidb];
            this->actual_seqlen_q = params.cu_seqlens_q[bidb + 1] - this->sum_s_q;
        } else {
            this->sum_s_q = -1;  // no cumulative start for Q
            this->actual_seqlen_q = params.seqlen_q;
        }
        // process KV
        if constexpr (Is_K_Cumulative) {
            this->sum_s_k = params.cu_seqlens_k[bidb];
            this->actual_seqlen_k = params.cu_seqlens_k[bidb + 1] - this->sum_s_k;
        } else {
            this->sum_s_k = -1;  // cu_seqlens_k holds plain lengths, not start offsets
            this->actual_seqlen_k = params.cu_seqlens_k[bidb];
        }
    }

    // Cumulative start of this batch's Q / K rows, or -1 when there is none.
    int sum_s_q;
    int sum_s_k;
    // Number of Q / K rows belonging to this batch entry.
    int actual_seqlen_q;
    int actual_seqlen_k;
};

}  // namespace flash