fmha_bwd_launch_template.h
// Copyright (c) 2022, Tri Dao.

#pragma once

#include "static_switch.h"
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_loop.h"

// Pick whether we should parallelize across seqlen_k (num_splits > 1) or not (num_splits=1).
// Parallelizing will have better occupancy, but has some overhead due to having to zero out
// dq_tmp and having to copy dq_tmp to dq.
inline int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen,
                                    int blocksize, bool is_causal) {
    float n_waves_1 = float(batch_nheads) / (num_SMs * ctas_per_sm);
    float eff_1 = n_waves_1 / ceil(n_waves_1);
    int num_splits_parallel = seqlen / blocksize;
    float n_waves_parallel = float(batch_nheads * num_splits_parallel) / (num_SMs * ctas_per_sm);
    float eff_parallel_raw = n_waves_parallel / ceil(n_waves_parallel);
    float discount_factor;
    if (!is_causal) {
        discount_factor = 1.f + float(blocksize) / seqlen;
    } else {  // For causal, parallelizing seems to help with load-balancing as well
        // For example, if headdim=128, seqlen >= 1280 always prefers parallel
        if (seqlen / blocksize >= 10) return num_splits_parallel;
        discount_factor = 1.f + 0.5 * float(blocksize) / seqlen;
    }
    float eff_parallel = eff_parallel_raw / discount_factor;
    return eff_1 >= eff_parallel ? 1 : num_splits_parallel;
}
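
// Worked example of the heuristic above (illustrative numbers, not from the
// original source): batch_nheads = 8, num_SMs = 108, ctas_per_sm = 1,
// seqlen = 2048, blocksize = 128, non-causal.
//   n_waves_1 = 8 / 108 ~= 0.074, so eff_1 ~= 0.074 (most SMs idle without splitting);
//   num_splits_parallel = 2048 / 128 = 16, n_waves_parallel = 128 / 108 ~= 1.19,
//   eff_parallel_raw = 1.19 / ceil(1.19) ~= 0.59, discount_factor = 1 + 128/2048 = 1.0625,
//   eff_parallel ~= 0.56 > eff_1, so the heuristic returns 16 splits.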

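// Thin launch wrapper around fmha::compute_dot_do_o, which precomputes the
// row-wise dot product of dO and O used by the softmax backward. It is launched
// as a separate pass only on the seqlen-parallel path in run_fmha_bwd_loop below.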
template<typename Kernel_traits>
__global__ void fmha_bwd_dot_do_o_kernel(FMHA_dgrad_params params) {
    fmha::compute_dot_do_o<Kernel_traits>(params);
}

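// Main backward kernel: each CTA loops over the key/value blocks of one
// (batch, head). loop_steps = -1 means "loop over all blocks"; the launcher below
// specializes loop_steps to 1 or 2 when seqlen_k is exactly one or two blocks,
// presumably so the trip count is known at compile time.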
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
__global__ void fmha_bwd_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
    fmha::compute_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
}

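// Sequence-parallel variant: blocks of seqlen_k are handled by separate CTAs
// (gridDim.z = num_splits), with dq accumulated through the dq_tmp buffer
// mentioned in the heuristic comment at the top of this file.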
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
__global__ void fmha_bwd_q_dk_dv_loop_seqparallel_kernel(FMHA_dgrad_params params) {
    fmha::compute_dq_dk_dv_seqparallel<Kernel_traits, Is_dropout, Is_causal>(params);
}

template<typename Kernel_traits>
void run_fmha_bwd_loop(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;

    using Smem_tile_s = fmha::Smem_tile_mma_transposed<typename Kernel_traits::Cta_tile_p>;
    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
    static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2);
    static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);

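    // Total dynamic smem: two Q-sized tiles, one or two V tiles (one suffices when
    // V is kept in registers), one dQ tile, and two transposed-S tiles.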
    constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2;
    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
    // printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);

    bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
    BOOL_SWITCH(is_dropout, IsDropoutConst, ({
        auto kernel = params.is_causal
            ? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true>
            : &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false>;
        if (params.seqlen_k == blocksize_c) {
            kernel = params.is_causal
                ? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/1>
                : &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/1>;
        } else if (params.seqlen_k == blocksize_c * 2) {
            kernel = params.is_causal
                ? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/2>
                : &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/2>;
        }
        auto kernel_seqparallel = params.is_causal
            ? &fmha_bwd_q_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, true>
            : &fmha_bwd_q_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, false>;
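        // CUDA limits dynamic shared memory to 48 KB per block by default; kernels
        // that need more must opt in via cudaFuncAttributeMaxDynamicSharedMemorySize.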
        if( smem_size_dq_dk_dv >= 48 * 1024 ) {
            FMHA_CHECK_CUDA(cudaFuncSetAttribute(
                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
            FMHA_CHECK_CUDA(cudaFuncSetAttribute(
                kernel_seqparallel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
        }
        // Automatically set num_splits to maximize occupancy
        if (params.num_splits <= 0) {
            int ctas_per_sm;
            cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size_dq_dk_dv);
            auto dprops = at::cuda::getCurrentDeviceProperties();
            // printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount);
            constexpr int M = Kernel_traits::Cta_tile_p::M;
            // We don't want more than 10 splits due to numerical error.
            // Numerical error on dk/dv scales as sqrt(num_splits).
            params.num_splits = num_splits_heuristic_bwd(
                params.b * params.h, dprops->multiProcessorCount,
                ctas_per_sm, params.seqlen_k, blocksize_c, params.is_causal
            );
        }
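        // A configure-only call just fills in params.num_splits (e.g. so the caller
        // can size the dq accumulation buffer) without launching anything.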
        if (configure) return;
        if (params.num_splits == 1) {
            dim3 grid(params.b, params.h, params.num_splits);
            kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
        } else {
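            // One CTA per 128 rows of dO/O: ceil-divide seqlen_q by 128.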
            dim3 grid_dot(params.b, params.h, (params.seqlen_q + 128 - 1) / 128);
            fmha_bwd_dot_do_o_kernel<Kernel_traits><<<grid_dot, Kernel_traits::THREADS, 0, stream>>>(params);
            int num_splits = params.seqlen_k / blocksize_c;  // seqlen_k is divisible by blocksize_c
            dim3 grid(params.b, params.h, num_splits);
            kernel_seqparallel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
        }
        FMHA_CHECK_CUDA(cudaPeekAtLastError());
    }));
}
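
// Example call site (a minimal sketch; the Kernel_traits instantiation and the
// configure-then-launch convention shown here are assumptions based on this
// launcher's signature, not code from this file):
//
//   using Kernel_traits = FMHA_kernel_traits</*...*/>;  // hypothetical instantiation
//   run_fmha_bwd_loop<Kernel_traits>(params, stream, /*configure=*/true);   // sets params.num_splits only
//   // ... allocate / zero the dq accumulation buffer from params.num_splits ...
//   run_fmha_bwd_loop<Kernel_traits>(params, stream, /*configure=*/false);  // launches the kernels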