fmha_kernel.h 7.04 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <algorithm>
#include <cstdint>
#include <tuple>

#include <philox.cuh>

#include <fmha.h>
#include <fmha/utils.h>
#include <fmha/smem_tile.h>
#include <fmha/gmem_tile.h>
#include <fmha/mask.h>
#include <fmha/softmax.h>

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Per-CTA bookkeeping derived from params.cu_seqlens and params.h for the block identified by
// (bidb, bidh). THREADS_PER_CTA only enters the computation of tidx_global.
template<int THREADS_PER_CTA>
struct BlockInfoPadded {

    // Params must provide an indexable `cu_seqlens` and an integer `h`.
    //   bidb: index into cu_seqlens (presumably the batch index -- confirm against callers).
    //   bidh: head index; added to the per-sequence and per-thread offsets below.
    //   tidx: thread index within the CTA.
    // Reads cu_seqlens[bidb] and cu_seqlens[bidb + 1], so both entries must be valid.
    template<typename Params>
    __device__ BlockInfoPadded(const Params &params,
                               const int bidb,
                               const int bidh,
                               const int tidx)
        : bidh(bidh), bidb(bidb), h(params.h) {  // listed in member declaration order

        // Offset at which this sequence starts.
        const int seq_begin = params.cu_seqlens[bidb];

        sum_s = seq_begin;
        // The length is the distance to the next entry of cu_seqlens.
        actual_seqlen = params.cu_seqlens[bidb + 1] - seq_begin;
        // The block index.
        bidx = seq_begin * params.h + bidh;

        // Thread index that is unique across all (bidb, bidh) pairs.
        const int cta_linear = bidb * params.h + bidh;
        tidx_global = cta_linear * THREADS_PER_CTA + tidx;
    }

    // True when start_col is at or past the end of this sequence, i.e. nothing is left to process.
    __device__ bool stop_early(const int start_col = 0) const {
        return !(start_col < actual_seqlen);
    }

    // cu_seqlens[bidb + 1] - cu_seqlens[bidb].
    int actual_seqlen;
    // sum_s * h + bidh.
    int bidx;
    // cu_seqlens[bidb].
    int sum_s;
    // Head index passed to the constructor.
    int bidh;
    // Index into cu_seqlens passed to the constructor.
    int bidb;
    // (bidb * h + bidh) * THREADS_PER_CTA + tidx.
    int tidx_global;
    // Copy of params.h.
    int h;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Splits the steps needed to cover a sequence into CHUNKS contiguous ranges and describes the
// range that belongs to chunk `bidc`: loop_offset_ is its first step, num_steps_ its length.
template<int CHUNKS, typename Cta_tile>
struct Noloop_traits{
    // Interpretation of Cta_tile dims, i.e. Cta_tile_p:
    enum{ STEP = Cta_tile::M };
    enum{ SEQLEN = Cta_tile::N };

    // bidc:  index of the chunk handled by this object.
    // binfo: only binfo.actual_seqlen is read.
    template<typename Block_info>
    inline __device__ Noloop_traits(const int bidc, const Block_info& binfo)
        : bidc_(bidc) {
        // CHUNKS is used as a divisor below.
        static_assert(CHUNKS > 0, "Noloop_traits requires at least one chunk");

        const int seqlen = binfo.actual_seqlen;
        // Number of STEP-sized steps needed to cover the sequence (rounded up).
        const int steps = (seqlen  + STEP - 1) / STEP;
        // Steps given to each chunk (rounded up), so trailing chunks may get fewer or none.
        const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS;

        const int step_begin = bidc_ * steps_per_chunk;
        const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk);
        // Clamp to zero for chunks that start at or past the last step.
        const int actual_steps = max(0, step_end - step_begin);
        loop_offset_ = step_begin;
        num_steps_ = actual_steps;

    }

    // Calls move() on every tile once per step that precedes this chunk (loop_offset_ times).
    // Calling it without any tile is a no-op.
    template<typename ... Tiles>
    inline __device__ void move_all(Tiles & ... tiles) const {
        using expand_type = int[];
        for( int s = 0; s < loop_offset_; s++ ) {
            // The leading 0 keeps the array non-empty when the pack is empty; the void casts
            // discard whatever move() returns and silence unused-value warnings.
            (void) expand_type{ 0, ((void) tiles.move(), 0)... };
        }
    }

    // Index associated with dK for this chunk: even slots, interleaved with dV.
    inline __device__ int get_idx_dk() const {
        //return bidc_;
        return bidc_ * 2 + 0;
    }

    // Index associated with dV for this chunk: odd slots, interleaved with dK.
    inline __device__ int get_idx_dv() const {
        //return CHUNKS + bidc_;
        return bidc_ * 2 + 1;
    }

    // Converts the chunk-local loop counter l to a position in the outer sequence.
    // Marked const like the other accessors so it can be used on const traits objects.
    inline __device__ int offset_loop_count(const int l) const {
        return (loop_offset_ + l) * STEP;
    }

    // Chunk index given to the constructor.
    const uint32_t bidc_;
    // First step of this chunk.
    int loop_offset_;
    // Number of steps in this chunk (0 if the chunk is empty).
    int num_steps_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Computes how heads_total heads are distributed over total_ctas CTAs.
//
// Every CTA first processes num_full_heads = heads_total / total_ctas whole heads. The remaining
// heads_last_wave = heads_total % total_ctas heads are split into steps that are shared between
// groups of CTAs.
//
// total_ctas is used as a divisor and therefore must be non-zero.
//
// Returns the tuple
//   (num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread).
template<typename Kernel_traits>
std::tuple<int , int, int, int, int, int> work_dist(const int total_ctas, const int heads_total) {

    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;

    // Number of steps needed to process one head.
    constexpr int STEPS_PER_HEAD = Cta_tile_p::N / Cta_tile_p::M;

    const int num_full_heads = heads_total / total_ctas;
    const int heads_last_wave = heads_total % total_ctas;

    int num_main_groups = 0;
    int main_steps = 0;
    int rest_steps = 0;
    if( heads_last_wave > 0 ) {
        // Number of CTA groups that process within heads.
        num_main_groups = total_ctas / heads_last_wave;
        // Remaining CTAs that process between heads.
        const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups);
        if( rest_ctas == 0 ) {
            // We have exactly "num_main_groups" CTAs to process each of the remaining heads.
            main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups;
            // main_steps > 0 here as long as STEPS_PER_HEAD > 0, so the divisions below are safe.
            num_main_groups = STEPS_PER_HEAD / main_steps;
            rest_steps = STEPS_PER_HEAD % main_steps;
        } else {
            // Ideal number of steps if we could load-balance as evenly as possible.
            const int steps_ideal = (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas;
            // Iterations that a "rest" CTA has to do at most.
            const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas;
            // Find the first step distribution, s.t. the maximum work of the "rest" CTAs is less
            // than the work of the main CTAs.
            for( main_steps = steps_ideal; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++ ) {
                const int candidate_rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
                const int max_rest_total_steps = candidate_rest_steps * max_rest_iters;
                if( max_rest_total_steps < main_steps ) {
                    break;
                }
            }
            // Whatever the main groups do not cover is left to the "rest" CTAs.
            rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups;
        }
    }

    // Largest number of steps a single CTA may execute.
    const int max_steps = STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps);
    // NOTE(review): the factor 8 is presumably the number of elements a thread holds per MMA -- confirm.
    const int elts_per_thread_per_step = Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8;
    const int elts_per_thread = max_steps * elts_per_thread_per_step;

    // make_tuple instead of a braced return: the latter needs a non-explicit std::tuple
    // constructor, which older standard libraries do not provide.
    return std::make_tuple(num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha