/****************************************************************************** * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * ******************************************************************************/ #include #include #include "static_switch.h" #include "fp16_switch.h" #include "fmha.h" #include "fmha_fprop_kernel_1xN.h" // Find the number of splits that maximizes the occupancy. 
// For example, if we have batch * n_heads = 48 and we have 108 SMs, having 2 splits
// (efficiency = 0.89) is better than having 3 splits (efficiency = 0.67). However, we
// also don't want too many splits as that would incur more HBM reads/writes.
// So we find the best efficiency, then find the smallest number of splits that gets 95%
// of the best efficiency.
//
// Efficiency of a candidate is n_waves / ceil(n_waves), where
// n_waves = (batch_nheads * num_splits) / (num_SMs * ctas_per_sm).
//
// Parameters:
//   batch_nheads  number of CTAs launched per split (the caller passes batch * heads)
//   num_SMs       multiprocessor count of the device
//   ctas_per_sm   CTAs of the kernel that can be resident on one SM
//   max_splits    largest split count to consider (candidates are 1..max_splits)
// Returns the chosen number of splits; 1 when there is nothing to search over or when
// no candidate passes the threshold.
int num_splits_heuristic_fwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) {
    // Degenerate inputs: an empty candidate range, or no resident CTA slots (which would
    // make n_waves a division by zero). Fall back to a single split.
    const int cta_slots = num_SMs * ctas_per_sm;
    if (max_splits < 1 || cta_slots <= 0) { return 1; }

    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        float n_waves = float(batch_nheads * num_splits) / cta_slots;
        float eff = n_waves / ceil(n_waves);
        // printf("num_splits = %d, eff = %f\n", num_splits, eff);
        if (eff > max_efficiency) { max_efficiency = eff; }
        efficiency.push_back(eff);
    }
    // Smallest split count that is within 95% of the best efficiency seen above.
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (efficiency[num_splits - 1] > 0.95 * max_efficiency) {
            // printf("num_splits chosen = %d\n", num_splits);
            return num_splits;
        }
    }
    return 1;
}

// NOTE(review): the template parameter lists, explicit template arguments and the
// <<<...>>> launch configuration below were missing from the copy of this file that was
// reviewed. They are restored here to match the surrounding usage (Kernel_traits plus the
// dropout / causal / return-softmax switches; grid, Kernel_traits::THREADS, smem_size and
// the launch stream) -- confirm against the repository version before merging.

// Kernel entry point: forwards the parameters to fmha::device_1xN_loop, specialised on
// the kernel traits and on the three compile-time switches.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
__global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
    fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
}

// Configures or launches the forward kernel for one Kernel_traits instantiation.
//
// When `configure` is true nothing is launched: the function only computes
// launch_params.elts_per_thread from the tile shape and the sequence lengths and returns.
// Otherwise it selects the kernel specialisation from is_dropout / is_causal /
// return_softmax, raises the dynamic shared-memory limit when needed, picks num_splits
// automatically if the caller left it <= 0, and launches on a
// (b, h, num_splits) grid with Kernel_traits::THREADS threads per block.
template<typename Kernel_traits>
void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
                              const bool configure) {
    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
    // Number of blocksize_c-wide steps needed to cover seqlen_k (ceil-div).
    const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;

    if (configure) {
        using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
        constexpr int M = Kernel_traits::Cta_tile_p::M;
        size_t STEPS = (launch_params.params.seqlen_q + M - 1) / M;
        constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
        constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
        size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
        launch_params.elts_per_thread = elts_per_head;
        return;
    }

    constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
    // Don't need smem_size_softmax_lse if we're not looping
    const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
        + (loop_steps > 1 ? smem_size_softmax_lse : 0);

    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
    // https://github.com/kokkos/kokkos-kernels/issues/349
    // https://github.com/HazyResearch/flash-attention/issues/21
    BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
        auto kernel = launch_params.params.is_causal
            ? (launch_params.return_softmax
               ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
               : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, true, false>)
            : (launch_params.return_softmax
               ? &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, true>
               : &fmha_fprop_fp16_sm80_loop_kernel<Kernel_traits, IsDropoutConst, false, false>);
        // Dynamic shared memory of 48 KB or more has to be opted into per kernel.
        if( smem_size >= 48 * 1024 ) {
            FMHA_CHECK_CUDA(cudaFuncSetAttribute(
                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        // Automatically set num_splits to maximize occupancy
        if (launch_params.params.num_splits <= 0) {
            int ctas_per_sm = 0;
            // Check the query result: on failure ctas_per_sm would otherwise be used
            // without ever having been written.
            FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
            auto dprops = at::cuda::getCurrentDeviceProperties();
            // printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount);
            constexpr int M = Kernel_traits::Cta_tile_p::M;
            // Never consider more splits than there are M-row blocks in seqlen_q.
            // The parentheses matter: `seqlen_q + M - 1 / M` would evaluate to
            // `seqlen_q + M` because `1 / M` binds first.
            const int num_row_blocks = (launch_params.params.seqlen_q + M - 1) / M;
            launch_params.params.num_splits = num_splits_heuristic_fwd(
                launch_params.params.b * launch_params.params.h, dprops->multiProcessorCount,
                ctas_per_sm,
                /*max_splits=*/std::min(30, num_row_blocks)
            );
        }
        dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
        kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
            launch_params.params);
        // Launch-configuration errors are only visible through the last-error state.
        FMHA_CHECK_CUDA(cudaPeekAtLastError());
    });
}

// Public entry point: picks a Kernel_traits instantiation from the head dimension
// (params.d in {16, 32, 64, 128}), params.seqlen_k and, for d == 128, the device's
// compute capability, then forwards to run_fmha_fp16_sm80_loop_.
//
// Combinations that match no branch (any other d, or d == 32 / 64 with seqlen_k that is
// neither 128 nor >= 256) fall through without configuring or launching anything, so the
// caller is expected to have validated / rounded these values beforehand.
void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
                        const bool configure) {
    FP16_SWITCH(launch_params.params.is_bf16, [&] {
        auto dprops = at::cuda::getCurrentDeviceProperties();
        if (launch_params.params.d == 16) {
            if( launch_params.params.seqlen_k == 128 ) {
                using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            } else if( launch_params.params.seqlen_k == 256 ) {
                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            } else {
                // TD [2022-05-15] 512 gives wrong results right now
                // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u, elem_type>;
                using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            }
        } else if (launch_params.params.d == 32) {
            if( launch_params.params.seqlen_k == 128 ) {
                using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            } else if( launch_params.params.seqlen_k >= 256 ) {
                using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            }
        } else if (launch_params.params.d == 64) {
            if( launch_params.params.seqlen_k == 128 ) {
                using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            } else if( launch_params.params.seqlen_k >= 256 ) {
                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            }
        } else if (launch_params.params.d == 128) {
            if( launch_params.params.seqlen_k == 128 ) {
                using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
            } else {
                if (dprops->major == 8 && dprops->minor == 0) {
                    // TD [2022-06-05] Keep K in registers to reduce register spilling
                    // Gives about 6% speedup compared to using block size 128.
                    using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
                } else {  // Need to use the same block size as backward
                    using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
                    run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
                }
            }
        }
        // Earlier dispatch variants, kept for reference:
        // if (launch_params.params.d == 64) {
        //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
        //     // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u, elem_type>;
        //     // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u, elem_type>;
        //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
        //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        // }
        // if (launch_params.params.d == 64) {
        //     if( launch_params.params.seqlen_k == 128 ) {
        //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        //     } else if( launch_params.params.seqlen_k >= 256 ) {
        //         if (dprops->major == 8 && dprops->minor >= 0) {
        //             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        //         } else if (dprops->major == 7 && dprops->minor == 5) {
        //             if (launch_params.is_dropout) { // Need to use the same block size as backward
        //                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        //             } else {
        //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        //             }
        //         }
        //     }
        // }
        // if (launch_params.params.d == 128) {
        //     if( launch_params.params.seqlen_k == 128 ) {
        //         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        //     } else {
        //         if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
        //             // TD [2022-06-05] Keep K in registers to reduce register spilling
        //             // Gives about 6% speedup compared to using block size 128.
        //             using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        //         } else {  // Need to use the same block size as backward
        //             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
        //         }
        //     }
        // }
    });
}