/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/cluster_launch.hpp"

#include "static_switch.h"
#include "flash.h"
#include "tile_scheduler.hpp"
#include "flash_fwd_kernel.h"
#include "kernel_traits.h"
#include "seq_len.h"
#include "utils.h"


template<typename Kernel_traits, bool Is_causal, typename Seqlen_traits>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    using Element = typename Kernel_traits::Element;
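    // ElementO is the output element type: FP16 when the inputs are FP8 (e4m3),
    // otherwise the same type as the inputs.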
    using ElementO = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(cutlass::half_t{}, Element{}));
    using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
    using ClusterShape = typename Kernel_traits::ClusterShape_MNK;

    // print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{});
    using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal, Seqlen_traits>;
    using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits, Seqlen_traits>;
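    // Scheduler selection: variable-length sequences use the single-tile scheduler;
    // fixed-length non-causal uses the static persistent scheduler, while causal uses
    // the dynamic persistent scheduler. The thread count handed to the dynamic
    // scheduler excludes one warpgroup (the producer warps of the warp-specialized kernel).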
    using Scheduler = std::conditional_t<
        Seqlen_traits::kUseVarSeqLen, 
        flash::SingleTileScheduler,
        std::conditional_t<!Is_causal,
            flash::StaticPersistentTileScheduler,
            flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup>
    >>;
    // using Scheduler = flash::SingleTileScheduler;
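    // Seqlen_traits bundles the sequence-length bookkeeping: total token counts,
    // the fixed sequence lengths, and (for varlen) the cumulative-length and
    // seqused arrays.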
    Seqlen_traits seqlen_traits_q(
        params.total_q, params.seqlen_q, params.cu_seqlens_q);
    Seqlen_traits seqlen_traits_k(
        params.total_k, params.seqlen_k, params.cu_seqlens_k, params.seqused_k);
    typename CollectiveMainloop::Params mainloop_params =
        CollectiveMainloop::to_underlying_arguments({
            static_cast<Element const*>(params.q_ptr),
            seqlen_traits_q.get_gmem_layout(
                params.seqlen_q, params.d, params.h, params.b, 
                params.q_row_stride, params.q_head_stride, params.q_batch_stride
            ),  // layout_Q
            static_cast<Element const*>(params.k_ptr),
            seqlen_traits_k.get_gmem_layout(
                params.seqlen_k, params.d, params.h_k, params.b, 
                params.k_row_stride, params.k_head_stride, params.k_batch_stride
            ),  // layout_K
            static_cast<Element const*>(params.v_ptr),
            seqlen_traits_k.get_gmem_layout(
                params.seqlen_k, params.d, params.h_k, params.b, 
                params.v_row_stride, params.v_head_stride, params.v_batch_stride
            ),  // layout_V
            params.scale_softmax_log2
        });
    typename CollectiveEpilogue::Params epilogue_params =
        CollectiveEpilogue::to_underlying_arguments({
            static_cast<Element*>(params.o_ptr),
            seqlen_traits_q.get_gmem_layout(
                params.seqlen_q, params.d, params.h, params.b,
                params.o_row_stride, params.o_head_stride, params.o_batch_stride
            ),  // layout_O
            static_cast<float*>(params.softmax_lse_ptr),
            seqlen_traits_q.get_lse_gmem_layout(
                params.seqlen_q, params.h, params.b
            )  // layout_LSE
        });

    // Grid size along M, rounded up to a multiple of the cluster size in that
    // dimension so every cluster launched is complete.
    int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
    num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
    typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b, params.tile_count_semaphore};
    typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);

    // Get the ptr to kernel function.
    void *kernel;
    kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler, Seqlen_traits>;
    int smem_size = sizeof(typename Kernel_traits::SharedStorage);
    // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
    // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
    // int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
    // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
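    // Requesting more than 48 KB of dynamic shared memory requires an explicit opt-in.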
    if (smem_size >= 48 * 1024) {
        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }

    int device;
    cudaGetDevice(&device);
    int multiprocessor_count;
    cudaError status_ = cudaDeviceGetAttribute(
        &multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
    if (status_ != cudaSuccess) {
        CHECK_CUDA(status_);
    }
    dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count);
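    // Launch configuration: one warp is 32 threads, so the CTA size is kNWarps * 32;
    // the thread-block cluster shape comes from the kernel traits and is handed to
    // the CUTLASS cluster launcher together with the dynamic shared memory size.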
    static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
    dim3 block_dims(ctaSize);
    dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
    cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
    cutlass::launch_kernel_on_cluster(
        launch_params, kernel, mainloop_params, epilogue_params, 
        scheduler_params, seqlen_traits_q, seqlen_traits_k);
    CHECK_CUDA_KERNEL_LAUNCH();
}

template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
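    // The Flash_fwd_kernel_traits arguments used below are, in order: head dimension,
    // M and N tile sizes, warp count, pipeline stage count, a boolean flag, cluster
    // size along M (2 only when clusters are enabled), and the input element type.
    // See kernel_traits.h for the authoritative parameter names.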
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
            run_flash_fwd<
                Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, 1, T>, 
                Is_causal, Seqlen_traits
            >(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
            // Only use Cluster if the number of tiles along seqlen_q is even, and neither causal nor variable-length
            BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
                if constexpr (is_same_v<T, cutlass::float_e4m3_t>) {
                    //run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 3, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
                    //run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 12, 4, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
                    //run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 12, 4, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
                } else {  
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
                }  
            });
        });
    });
}

template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
            // Only use Cluster if the number of tiles along seqlen_q is even, and neither causal nor variable-length
            BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
                if constexpr (is_same_v<T, cutlass::float_e4m3_t>) {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 12, 3, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream); 
                } else {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
                }  
            });
        });
    });
}
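
// Usage sketch (illustrative only, not part of this header): the calling code is
// expected to pick one of the launchers above based on the head dimension in
// Flash_fwd_params, along the lines of
//
//   template<typename T>
//   void run_mha_fwd_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
//       if (params.d <= 64)       { run_mha_fwd_hdim64<T>(params, stream); }
//       else if (params.d <= 128) { run_mha_fwd_hdim128<T>(params, stream); }
//       else                      { run_mha_fwd_hdim256<T>(params, stream); }
//   }
//
// run_mha_fwd_dispatch is a hypothetical name used for illustration; the actual
// dispatch lives in the calling translation unit.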