Commit a4f148b6 authored by Tri Dao

Fix masking of bwd when seqlen is not divisible by 128

parent 184b992d
@@ -415,7 +415,7 @@ inline __device__ void convert_dKV(const Params &params) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
 inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
     using Element = typename Kernel_traits::Element;
@@ -436,7 +436,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;
     constexpr bool Double_buffer = !Kernel_traits::No_double_buffer;
-    const BlockInfo</*Varlen=*/!Is_even_M> binfo(params, bidb);
+    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
     if (n_block * kBlockN >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;
     int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
@@ -668,10 +668,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         #pragma unroll
         for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
         // Clear_OOB_K must be false since we don't want to write zeros to gmem
-        flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
             gmem_thr_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
         );
-        flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
             gmem_thr_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
         );
         return;
@@ -687,7 +687,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (Kernel_traits::Is_V_in_regs) {
         // Clear the smem tiles to account for predicated off loads
-        flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>(
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
            gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
         );
         flash::cp_async_fence();
@@ -697,18 +697,18 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor tdOrO = make_fragment_like(tdOgO);
     if (!Is_first) {
         // Clear the smem tiles to account for predicated off loads
-        flash::copy<Is_even_M, Is_even_K, /*Clear_OOB_MN=*/true>(
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
            gmem_thr_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
         );
     } else {
-        flash::copy<Is_even_M, Is_even_K, /*Clear_OOB_MN=*/true>(
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
            gmem_thr_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
         );
-        flash::copy<Is_even_M, Is_even_K, /*Clear_OOB_MN=*/true>(
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
            gmem_thr_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
         );
     }
-    flash::copy<Is_even_M, Is_even_K, /*Clear_OOB_MN=*/true>(
+    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
        gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
     );
@@ -722,7 +722,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     for (int mi = 0; mi < size(lse); ++mi) {
         // Using uint32_t row makes it 10us slower on d=128, not sure why.
         const int row = get<0>(taccScS_row(mi));
-        lse(mi) = Is_even_M || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
+        lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
     }
     // Tensor tKrK = make_fragment_like(tKsK);
@@ -730,11 +730,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // copy(gmem_thr_copy_QKV, tKgK, tKrK);
     // // if (cute::thread(1, 0)) { print(tKrK); }
-    flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>(
+    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
        gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
     if (!Kernel_traits::Is_V_in_regs) {
-        flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>(
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
            gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
         );
     }
@@ -783,15 +783,26 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
         // if (cute::thread(32, 0)) { print(scores); }
-        // We don't need to mask out the elements beyond actual_seqlen_k, because acc_s would
-        // be some finite value for those indices. In the end when we multiply with K to get dQ,
-        // the corresponding values of K would be 0, so the result would still be correct.
-        // Putting this causal masking right after acc_s is *much* slower for some reason.
-        if (Is_causal && m_block * kBlockM < (n_block + 1) * kBlockN) {
-            flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                                     binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
-                                     // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
-                                     AtomLayoutMS * 16);
+        // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
+        // actual_seqlen_k, because acc_s would be some finite value for those indices.
+        // In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
+        // so the result would still be correct.
+        // However, it's possible that the values in acc_s are so large that they overflow
+        // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
+        // So we need to mask out the elements beyond actual_seqlen_k.
+        if (!Is_causal) {
+            if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
+                flash::apply_mask(scores, binfo.actual_seqlen_k,
+                                  n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);
+            }
+        } else {
+            // Putting this causal masking right after acc_s is *much* slower for some reason.
+            if (m_block * kBlockM < (n_block + 1) * kBlockN) {
+                flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
+                                         binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
+                                         // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
+                                         AtomLayoutMS * 16);
+            }
         }
         // if (cute::thread(32, 0)) { print(scores); }
         // Compute the exponential value.
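[Annotation] To make the comment in the hunk above concrete, here is a minimal standalone PyTorch sketch of the failure mode; it is not the kernel code, and the 300.0 garbage value and the 100/128 lengths are made-up illustration numbers. An unmasked score in a padded key column that exceeds the row's log-sum-exp blows up under exp(), so the P * dP product turns into Inf when cast to fp16; masking that column to -inf makes its contribution exactly zero.

import torch

torch.manual_seed(0)
actual_seqlen_k, padded_seqlen_k = 100, 128
scores = torch.randn(1, padded_seqlen_k) * 10.0        # one row of Q @ K^T (fp32 accumulator)
scores[:, actual_seqlen_k:] = 300.0                    # stand-in garbage in the padded columns
lse = torch.logsumexp(scores[:, :actual_seqlen_k], dim=-1, keepdim=True)  # LSE over valid columns only

dP = torch.randn(1, padded_seqlen_k)
p_unmasked = torch.exp(scores - lse)                   # padded columns overflow the float range
dS_fp16 = (p_unmasked * dP).to(torch.float16)
print(torch.isfinite(dS_fp16).all())                   # tensor(False): Inf/NaN would reach dQ

scores[:, actual_seqlen_k:] = float("-inf")            # what masking those columns does
p_masked = torch.exp(scores - lse)                     # padded columns are exactly 0
print(torch.isfinite((p_masked * dP).to(torch.float16)).all())  # tensor(True)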
@@ -978,7 +989,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
     #pragma unroll
     for (int m = 0; m < size<1>(tdQgdQ); ++m) {
-        if (Is_even_M || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
+        if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
            copy(gmem_thr_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _));
         }
     }
@@ -1044,10 +1055,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     #pragma unroll
     for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
-    flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_thr_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
-    flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_thr_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
     );
@@ -1487,7 +1498,7 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
 inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
     const int n_block = blockIdx.x;
@@ -1496,7 +1507,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
     // The block index for the head.
     const int bidh = blockIdx.z;
-    compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+    compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
 }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -23,9 +23,9 @@ __global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
     flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K>(params);
 }
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K>
 __global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
-    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K>(params);
+    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K>(params);
 }
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K>
@@ -53,17 +53,17 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
     flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
-    // We also use is_even_M to set Unpadded in the BlockInfo constructor, so we need to check
-    // for cu_seqlens_q as well.
-    const bool is_even_M = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0;
+    // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
+    // a multiple of kBlockN, we'll need to apply mask in the loop.
+    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
     // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
     BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
-        BOOL_SWITCH(is_even_M, IsEvenMConst, [&] {
+        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMConst, IsEvenKConst>;
-                // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMConst, true>;
+                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst>;
+                // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
                 if (smem_size_dq_dk_dv >= 48 * 1024) {
                     C10_CUDA_CHECK(cudaFuncSetAttribute(
                         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
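[Annotation] The new dispatch predicate can be read as a small host-side check; the sketch below is Python for illustration only, and kBlockM = kBlockN = 128 is an assumed example value (the real tile sizes come from Kernel_traits and depend on the head dimension).

# Hypothetical stand-in for the is_even_MN check in run_flash_bwd_seqk_parallel above.
def is_even_mn(seqlen_q, seqlen_k, cu_seqlens_q=None, cu_seqlens_k=None,
               kBlockM=128, kBlockN=128):
    return (cu_seqlens_q is None and cu_seqlens_k is None
            and seqlen_q % kBlockM == 0 and seqlen_k % kBlockN == 0)

print(is_even_mn(256, 256))  # True  -> fast specialization, no masking needed in the K/V loop
print(is_even_mn(256, 200))  # False -> padded key columns must be masked (this commit's fix)
print(is_even_mn(97, 97))    # False -> both Q rows and K columns need predication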
@@ -102,7 +102,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
         BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
             BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                 auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenNConst, IsEvenKConst>;
-                // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, true, true>;
+                // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
                 if (smem_size_dq_dk_dv >= 48 * 1024) {
                     C10_CUDA_CHECK(cudaFuncSetAttribute(
                         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
...
@@ -117,15 +117,18 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
 }
 template <typename Engine, typename Layout>
-inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k) {
+inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k,
+                                  const uint32_t col_idx_offset_ = 0) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const uint32_t lane_id = threadIdx.x % 32;
+    const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
     #pragma unroll
     for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+        const uint32_t col_idx_base = col_idx_offset + nj * 8;
         #pragma unroll
         for (int j = 0; j < size<1, 0>(tensor); ++j) {
-            const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2;
+            const uint32_t col_idx = col_idx_base + j;
             if (col_idx >= max_seqlen_k) {
                 // Without the "make_coord" we get wrong results
                 #pragma unroll
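[Annotation] The extra col_idx_offset_ argument lets the caller pass the tile's global starting column, and the index arithmetic implies each group of 4 lanes owns 2 adjacent columns per 8-column MMA tile. Below is a pure-arithmetic sketch of the columns a lane visits, assuming that accumulator layout; thread_cols is a hypothetical helper for illustration, not part of the kernel.

# Mirrors: col_idx = col_idx_offset_ + (lane_id % 4) * 2 + nj * 8 + j
def thread_cols(lane_id, n_tiles_nj, col_idx_offset_=0, cols_per_group=2):
    # cols_per_group corresponds to size<1, 0>(tensor) == 2 in the kernel.
    col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2
    return [col_idx_offset + nj * 8 + j
            for nj in range(n_tiles_nj) for j in range(cols_per_group)]

# With 2 column tiles and offset 0: lane 0 sees columns [0, 1, 8, 9], lane 3 sees [6, 7, 14, 15].
print(thread_cols(lane_id=0, n_tiles_nj=2))
print(thread_cols(lane_id=3, n_tiles_nj=2))
# Any col_idx >= max_seqlen_k is set to -INFINITY so it drops out of the softmax/backward math.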
...
@@ -825,3 +825,97 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
     assert torch.equal(dqkv[:, :, 0], dqkv0[:, :, 0])
     assert torch.equal(dqkv[:, :, 1], dqkv0[:, :, 1])
     assert torch.equal(dqkv[:, :, 2], dqkv0[:, :, 2])
@pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [16, 32, 64])
# @pytest.mark.parametrize('d', [16])
@pytest.mark.parametrize('seqlen', [1, 2, 5, 17, 128])
# @pytest.mark.parametrize('seqlen', [2])
def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
""" We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
in the case where seqlen % 128 != 0.
"""
device = 'cuda'
# set seed
torch.random.manual_seed(0)
batch_size = 2
nheads = 5
q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5
k, v = [torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3 for _ in range(2)]
q.requires_grad_(True)
k.requires_grad_(True)
v.requires_grad_(True)
out = flash_attn_func(q, k, v, causal=causal)
g = torch.randn_like(out)
out.backward(g)
q_pt = q.detach().clone().requires_grad_(True)
k_pt = k.detach().clone().requires_grad_(True)
v_pt = v.detach().clone().requires_grad_(True)
out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
out_pt.backward(g)
q_ref = q.detach().clone().requires_grad_(True)
k_ref = k.detach().clone().requires_grad_(True)
v_ref = v.detach().clone().requires_grad_(True)
out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
out_ref.backward(g)
print(f'dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}')
print(f'dK max diff: {(k.grad - k_ref.grad).abs().max().item()}')
print(f'dV max diff: {(v.grad - v_ref.grad).abs().max().item()}')
print(f'dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}')
print(f'dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}')
print(f'dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}')
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (q_pt.grad - q_ref.grad).abs().max().item() + 1e-3
assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (k_pt.grad - k_ref.grad).abs().max().item() + 1e-3
assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (v_pt.grad - v_ref.grad).abs().max().item() + 1e-3
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [64, 128])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [97, 128, 200, 256])
# @pytest.mark.parametrize('seqlen', [128])
def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
""" We previously had a bug where we were using the wrong strides of dout, which shows up
when dout is not contiguous.
"""
device = 'cuda'
# set seed
torch.random.manual_seed(0)
batch_size = 5
nheads = 2
q, k, v = [torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda",
requires_grad=True)
for _ in range(3)]
out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...")
# So g is not contiguous
g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2]
out.backward(g)
q_pt = q.detach().clone().requires_grad_(True)
k_pt = k.detach().clone().requires_grad_(True)
v_pt = v.detach().clone().requires_grad_(True)
out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
out_pt = rearrange(out_pt, "b s ... -> s b ...")
out_pt.backward(g)
q_ref = q.detach().clone().requires_grad_(True)
k_ref = k.detach().clone().requires_grad_(True)
v_ref = v.detach().clone().requires_grad_(True)
out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
out_ref = rearrange(out_ref, "b s ... -> s b ...")
out_ref.backward(g)
print(f'dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}')
print(f'dK max diff: {(k.grad - k_ref.grad).abs().max().item()}')
print(f'dV max diff: {(v.grad - v_ref.grad).abs().max().item()}')
print(f'dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}')
print(f'dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}')
print(f'dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}')
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (q_pt.grad - q_ref.grad).abs().max().item()
assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (k_pt.grad - k_ref.grad).abs().max().item()
assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (v_pt.grad - v_ref.grad).abs().max().item()
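[Annotation] As a side note on test_flash_attn_bwd_transpose: the strided slice is what makes dout non-contiguous, which is exactly the case the corrected stride handling has to cover. A tiny sanity sketch (illustrative shapes only) confirming that:

import torch

seqlen, batch_size, nheads, d = 128, 5, 2, 64
g = torch.randn(seqlen, 2 * batch_size, nheads, d)[:, ::2]   # same trick as the test above
print(g.shape)            # torch.Size([128, 5, 2, 64])
print(g.is_contiguous())  # False
print(g.stride())         # (1280, 256, 64, 1): the batch stride is doubled, not packed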