Commit 4f285b35 authored by Tri Dao

FlashAttention-2 release

parent 6d48e14a
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
}
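Each of these tiny translation units pins down one explicit specialization of run_mha_fwd_ for a single (element type, head dimension) pair, so the expensive kernel instantiations pulled in through flash_fwd_launch_template.h build in parallel instead of in one monolithic file. Below is a minimal sketch of the pattern with hypothetical names (run_kernel, common.h, kernel_hdim64_fp16.cu); the real declarations live in the FlashAttention headers and are not part of this diff.

// --- Sketch: splitting explicit specializations across .cu files (hypothetical names) ---
// common.h: declare the dispatch template once, with no definition.
//     template<typename T, int HeadDim>
//     void run_kernel(const Params &params, cudaStream_t stream);
//
// kernel_hdim64_fp16.cu: define exactly one specialization per file, so each
// head-dimension / dtype combination compiles in its own translation unit.
//     template<>
//     void run_kernel<cutlass::half_t, 64>(const Params &params, cudaStream_t stream) {
//         run_kernel_hdim64<cutlass::half_t>(params, stream);
//     }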
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
//     using elem_type = cutlass::half_t;
//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
//         run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 128, 4, false, false, elem_type>, Is_dropout>(params, stream);
//         // For dropout there might be a lot of register spilling?
//         // These two are very slow due to register spilling
//         // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 128, 4, false, elem_type>>(params, stream);
//         // run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 256, 4, false, elem_type>>(params, stream);
//         // This one is slightly slower
//         // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 64, 4, false, elem_type>>(params, stream);
//     });
// }
template<>
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
}
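The commented-out variant above dispatches through BOOL_SWITCH, which turns a runtime condition (here params.p_dropout < 1.f) into a constexpr bool so the kernel is instantiated for both branches at compile time. The macro itself is not part of this diff; the sketch below shows the usual shape of such a switch, assuming it matches the one shipped in the FlashAttention sources (static_switch.h).

// Sketch of a BOOL_SWITCH-style macro (assumption: mirrors FlashAttention's
// static_switch.h, which is not shown in this diff). The immediately-invoked
// lambda makes CONST_NAME a compile-time constant inside either branch.
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

It is used exactly as in the commented-out code: BOOL_SWITCH(cond, Is_dropout, [&] { /* code that uses Is_dropout as a template argument */ });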
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
//     using elem_type = cutlass::bfloat16_t;
//     if (params.p_dropout == 1.f) {
//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
//     } else {
//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
//     }
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
//     using elem_type = cutlass::half_t;
//     if (params.p_dropout == 1.f) {
//         // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
//         // Using block size (64 x 256) is 27% slower for seqlen=2k
//         // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, elem_type>, false>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, false>(params, stream);
//     } else {
//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, true>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, true>(params, stream);
//     }
// }
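// Reading of the trait arguments above (an inference from the tuning comments about
// warp counts and block sizes, not stated in this diff): Flash_fwd_kernel_traits<
// head_dim, kBlockM, kBlockN, num_warps, Is_Q_in_regs, Share_Q_K_smem, elem_type>.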
template<>
void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
//     using elem_type = cutlass::bfloat16_t;
//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
//     });
// }
template<>
void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
}
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

// template<>
// void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
//     using elem_type = cutlass::half_t;
//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, true, elem_type>, Is_dropout>(params, stream);
//         // This 3rd one is good for H100, A100, and A6000
//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, true, elem_type>, Is_dropout>(params, stream);
//         // These two are always slower
//         // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, elem_type>>(params, stream);
//         // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, elem_type>>(params, stream);
//     });
// }
template<>
void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
}
/* Copyright (c) 2022, Tri Dao.
 */

#include "fmha.h"
#include "fmha_block_dgrad_kernel_1xN_loop.h"

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
__global__ void fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
    fmha::compute_block_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
}
template<typename Kernel_traits>
void run_fmha_block_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_t stream) {
    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
    constexpr int smem_size_dp_sum = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
    using Smem_tile_s = fmha::Smem_tile_mma_transposed<typename Kernel_traits::Cta_tile_p>;
    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
    static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2);
    static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
    static_assert(smem_size_dp_sum == 16 * 4 * 2);
    constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum;
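    // Accounting for the dynamic smem total above: Q appears twice (presumably
    // double-buffered), V once or twice depending on Kernel_traits::V_IN_REGS,
    // plus the dQ tile, two transposed S tiles, and the dP row-sum tile.
    // Note that smem_size_softmax is computed above but not included in this sum.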
    bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
    bool is_causal = params.is_causal;
    auto kernel = is_dropout
        ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false>)
        : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false>);
    constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
    if (params.seqlen_k == blocksize_c) {
        kernel = is_dropout
            ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/1> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/1>)
            : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/1> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/1>);
    } else if (params.seqlen_k == blocksize_c * 2) {
        kernel = is_dropout
            ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/2> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/2>)
            : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
    }
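    // The loop_steps=1 / loop_steps=2 specializations above pin the outer-loop trip
    // count at compile time when seqlen_k is exactly one or two key blocks
    // (presumably so the loop can be fully unrolled).
    // Requesting more than 48 KB of dynamic shared memory requires an explicit
    // opt-in via cudaFuncSetAttribute, hence the check below.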
    if( smem_size_dq_dk_dv >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(
            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
    }
    dim3 grid(params.b, params.h);
    kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
    FMHA_CHECK_CUDA(cudaPeekAtLastError());
}

void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
    if (params.d == 16) {
        using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
        run_fmha_block_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
    } else if (params.d == 32) {
        using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u>;
        run_fmha_block_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
    } else if (params.d == 64) {
        using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
        run_fmha_block_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
    }
}