Commit 34e67b1e authored by zhangshao's avatar zhangshao
Browse files

first commit

parents
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization of run_mha_bwd_ for cutlass::bfloat16_t, head dimension 32, flag = true.
// The flag is forwarded as the Is_causal template argument of run_mha_bwd_hdim32.
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 32, true>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::bfloat16_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim32<Element, kCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization of run_mha_bwd_ for cutlass::bfloat16_t, head dimension 32, flag = false.
// The flag is forwarded as the Is_causal template argument of run_mha_bwd_hdim32.
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 32, false>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::bfloat16_t;
    constexpr bool kCausal = false;
    run_mha_bwd_hdim32<Element, kCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization of run_mha_bwd_ for cutlass::half_t, head dimension 32, flag = true.
// The flag is forwarded as the Is_causal template argument of run_mha_bwd_hdim32.
template<>
void run_mha_bwd_<cutlass::half_t, 32, true>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::half_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim32<Element, kCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization of run_mha_bwd_ for cutlass::half_t, head dimension 32, flag = false.
// The flag is forwarded as the Is_causal template argument of run_mha_bwd_hdim32.
template<>
void run_mha_bwd_<cutlass::half_t, 32, false>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::half_t;
    constexpr bool kCausal = false;
    run_mha_bwd_hdim32<Element, kCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization of run_mha_bwd_ for cutlass::bfloat16_t, head dimension 512, flag = true.
// The flag is forwarded as the second template argument of run_mha_bwd_hdim512
// (presumably its causal switch -- confirm against that launcher).
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 512, true>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::bfloat16_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim512<Element, kCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization of run_mha_bwd_ for cutlass::bfloat16_t, head dimension 512, flag = false.
// The flag is forwarded as the second template argument of run_mha_bwd_hdim512
// (presumably its causal switch -- confirm against that launcher).
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 512, false>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::bfloat16_t;
    constexpr bool kCausal = false;
    run_mha_bwd_hdim512<Element, kCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization of run_mha_bwd_ for cutlass::half_t, head dimension 512, flag = true.
// The flag is forwarded as the second template argument of run_mha_bwd_hdim512
// (presumably its causal switch -- confirm against that launcher).
template<>
void run_mha_bwd_<cutlass::half_t, 512, true>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::half_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim512<Element, kCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization of run_mha_bwd_ for cutlass::half_t, head dimension 512, flag = false.
// The flag is forwarded as the second template argument of run_mha_bwd_hdim512
// (presumably its causal switch -- confirm against that launcher).
template<>
void run_mha_bwd_<cutlass::half_t, 512, false>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::half_t;
    constexpr bool kCausal = false;
    run_mha_bwd_hdim512<Element, kCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization of run_mha_bwd_ for cutlass::bfloat16_t, head dimension 64, flag = true.
// The flag is forwarded as the Is_causal template argument of run_mha_bwd_hdim64.
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::bfloat16_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim64<Element, kCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization of run_mha_bwd_ for cutlass::bfloat16_t, head dimension 64, flag = false.
// The flag is forwarded as the Is_causal template argument of run_mha_bwd_hdim64.
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::bfloat16_t;
    constexpr bool kCausal = false;
    run_mha_bwd_hdim64<Element, kCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization of run_mha_bwd_ for cutlass::half_t, head dimension 64, flag = true.
// The flag is forwarded as the Is_causal template argument of run_mha_bwd_hdim64.
template<>
void run_mha_bwd_<cutlass::half_t, 64, true>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::half_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim64<Element, kCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization of run_mha_bwd_ for cutlass::half_t, head dimension 64, flag = false.
// The flag is forwarded as the Is_causal template argument of run_mha_bwd_hdim64.
template<>
void run_mha_bwd_<cutlass::half_t, 64, false>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::half_t;
    constexpr bool kCausal = false;
    run_mha_bwd_hdim64<Element, kCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization of run_mha_bwd_ for cutlass::bfloat16_t, head dimension 96, flag = true.
// The flag is forwarded as the second template argument of run_mha_bwd_hdim96
// (presumably its causal switch -- confirm against that launcher).
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::bfloat16_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim96<Element, kCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization of run_mha_bwd_ for cutlass::bfloat16_t, head dimension 96, flag = false.
// The flag is forwarded as the second template argument of run_mha_bwd_hdim96
// (presumably its causal switch -- confirm against that launcher).
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::bfloat16_t;
    constexpr bool kCausal = false;
    run_mha_bwd_hdim96<Element, kCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization of run_mha_bwd_ for cutlass::half_t, head dimension 96, flag = true.
// The flag is forwarded as the second template argument of run_mha_bwd_hdim96
// (presumably its causal switch -- confirm against that launcher).
template<>
void run_mha_bwd_<cutlass::half_t, 96, true>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::half_t;
    constexpr bool kCausal = true;
    run_mha_bwd_hdim96<Element, kCausal>(params, stream);
}
// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "flash_bwd_launch_template.h"
// Explicit specialization of run_mha_bwd_ for cutlass::half_t, head dimension 96, flag = false.
// The flag is forwarded as the second template argument of run_mha_bwd_hdim96
// (presumably its causal switch -- confirm against that launcher).
template<>
void run_mha_bwd_<cutlass::half_t, 96, false>(Flash_bwd_params &params, cudaStream_t stream) {
    using Element = cutlass::half_t;
    constexpr bool kCausal = false;
    run_mha_bwd_hdim96<Element, kCausal>(params, stream);
}
This source diff could not be displayed because it is too large. You can view the blob instead.
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include "static_switch.h"
#include "flash.h"
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_kernel.h"
// Decide whether the FlashAttention kernel bodies are compiled for the current compilation pass
// and how the kernel parameter struct is qualified.
// Passes with __CUDA_ARCH__ >= 800 define ARCH_SUPPORTS_FLASH and qualify the kernel parameter
// with __grid_constant__; every other pass (including host passes, where __CUDA_ARCH__ is
// undefined) leaves KERNEL_PARAM_MODIFIER empty.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif
// Builds that define DCU_ASM also enable the kernel bodies, independently of __CUDA_ARCH__.
// KERNEL_PARAM_MODIFIER is not touched here, so it keeps whatever the branch above selected.
#if defined(DCU_ASM)
#define ARCH_SUPPORTS_FLASH
#endif
// Single place for the message printed by kernels compiled without ARCH_SUPPORTS_FLASH.
// The expansion already ends with ';', so call sites use the macro without a trailing semicolon.
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
// Declares the head of a __global__ kernel template named `kernelName`.
// The template parameter list is `typename Kernel_traits` followed by the variadic arguments, and
// the kernel takes a single by-value `const Flash_bwd_params params` (qualified with
// KERNEL_PARAM_MODIFIER). The kernel body `{ ... }` must follow the macro invocation.
#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
// Kernel entry point that forwards `params` and its template flags to flash::compute_dq_dk_dv.
// Without ARCH_SUPPORTS_FLASH the body only emits the FLASH_UNSUPPORTED_ARCH message.
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
#ifndef ARCH_SUPPORTS_FLASH
    FLASH_UNSUPPORTED_ARCH
#else
    flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
#endif
}
// Kernel entry point that forwards `params` and its template flags to
// flash::compute_dq_dk_dv_seqk_parallel.
// Without ARCH_SUPPORTS_FLASH the body only emits the FLASH_UNSUPPORTED_ARCH message.
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#ifndef ARCH_SUPPORTS_FLASH
    FLASH_UNSUPPORTED_ARCH
#else
    // Is_causal and Is_local are mutually exclusive: at most one of them may be true.
    static_assert(!Is_causal || !Is_local);
    flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
// Kernel entry point that forwards `params` and its template flags to
// flash::compute_dk_dv_seqk_parallel.
// Without ARCH_SUPPORTS_FLASH the body only emits the FLASH_UNSUPPORTED_ARCH message.
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#ifndef ARCH_SUPPORTS_FLASH
    FLASH_UNSUPPORTED_ARCH
#else
    // Is_causal and Is_local are mutually exclusive: at most one of them may be true.
    static_assert(!Is_causal || !Is_local);
    flash::compute_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
// Kernel entry point that forwards `params` and its template flags to
// flash::compute_dk_dv_trans_seqk_parallel.
// Without ARCH_SUPPORTS_FLASH the body only emits the FLASH_UNSUPPORTED_ARCH message.
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_dv_trans_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#ifndef ARCH_SUPPORTS_FLASH
    FLASH_UNSUPPORTED_ARCH
#else
    // Is_causal and Is_local are mutually exclusive: at most one of them may be true.
    static_assert(!Is_causal || !Is_local);
    flash::compute_dk_dv_trans_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
// Kernel entry point that forwards `params` and its template flags to
// flash::compute_dk_dv_trans_seqk_parallel_16x64.
// Without ARCH_SUPPORTS_FLASH the body only emits the FLASH_UNSUPPORTED_ARCH message.
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_dv_trans_16x64_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#ifndef ARCH_SUPPORTS_FLASH
    FLASH_UNSUPPORTED_ARCH
#else
    // Is_causal and Is_local are mutually exclusive: at most one of them may be true.
    static_assert(!Is_causal || !Is_local);
    flash::compute_dk_dv_trans_seqk_parallel_16x64<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
// Kernel entry point that forwards `params` and its template flags to
// flash::compute_dk_dv_trans_16x64_prefetch.
// Without ARCH_SUPPORTS_FLASH the body only emits the FLASH_UNSUPPORTED_ARCH message.
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_dv_trans_16x64_prefetch, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#ifndef ARCH_SUPPORTS_FLASH
    FLASH_UNSUPPORTED_ARCH
#else
    // Is_causal and Is_local are mutually exclusive: at most one of them may be true.
    static_assert(!Is_causal || !Is_local);
    flash::compute_dk_dv_trans_16x64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
// Kernel entry point that forwards `params` and its template flags to
// flash::compute_dk_dv_trans_16x64_mla_prefetch.
// Without ARCH_SUPPORTS_FLASH the body only emits the FLASH_UNSUPPORTED_ARCH message.
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dk_dv_trans_16x64_mla_prefetch, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#ifndef ARCH_SUPPORTS_FLASH
    FLASH_UNSUPPORTED_ARCH
#else
    // Is_causal and Is_local are mutually exclusive: at most one of them may be true.
    static_assert(!Is_causal || !Is_local);
    flash::compute_dk_dv_trans_16x64_mla_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
// Kernel entry point that forwards `params` and its template flags to
// flash::compute_dq_seqq_parallel.
// Without ARCH_SUPPORTS_FLASH the body only emits the FLASH_UNSUPPORTED_ARCH message.
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_loop_seqq_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#ifndef ARCH_SUPPORTS_FLASH
    FLASH_UNSUPPORTED_ARCH
#else
    // Is_causal and Is_local are mutually exclusive: at most one of them may be true.
    static_assert(!Is_causal || !Is_local);
    flash::compute_dq_seqq_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
// Kernel entry point that forwards `params` and its template flags to
// flash::compute_dq_seqq_parallel_16x64.
// Without ARCH_SUPPORTS_FLASH the body only emits the FLASH_UNSUPPORTED_ARCH message.
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_loop_16x64_seqq_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#ifndef ARCH_SUPPORTS_FLASH
    FLASH_UNSUPPORTED_ARCH
#else
    // Is_causal and Is_local are mutually exclusive: at most one of them may be true.
    static_assert(!Is_causal || !Is_local);
    flash::compute_dq_seqq_parallel_16x64<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
// Kernel entry point that forwards `params` and its template flags to
// flash::compute_dq_seqq_parallel_16x64_prefetch.
// Without ARCH_SUPPORTS_FLASH the body only emits the FLASH_UNSUPPORTED_ARCH message.
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_loop_16x64_prefetch_seqq_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#ifndef ARCH_SUPPORTS_FLASH
    FLASH_UNSUPPORTED_ARCH
#else
    // Is_causal and Is_local are mutually exclusive: at most one of them may be true.
    static_assert(!Is_causal || !Is_local);
    flash::compute_dq_seqq_parallel_16x64_prefetch<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#endif
}
// Kernel entry point for flash::compute_dot_do_o.
// Clear_dQaccum (default true) and Kernel_traits are passed through unchanged; `params` is taken
// by value. Unlike the macro-defined kernels in this file there is no architecture guard here.
template <bool Clear_dQaccum = true, typename Kernel_traits>
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
    flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
}
// Kernel entry point for flash::clear_dKVaccum; `params` is taken by value and forwarded.
template <typename Kernel_traits>
__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
    flash::clear_dKVaccum<Kernel_traits>(params);
}
// Kernel entry point for flash::convert_dQ.
// `nsplits` is forwarded unchanged as the second argument of convert_dQ.
template <typename Kernel_traits>
__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
    flash::convert_dQ<Kernel_traits>(params, nsplits);
}
// Kernel entry point for flash::convert_dKV; `params` is taken by value and forwarded.
template <typename Kernel_traits>
__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
    flash::convert_dKV<Kernel_traits>(params);
}
// Host launcher for the combined dQ/dK/dV backward path.
//
// Launch sequence on `stream` (each launch is followed by C10_CUDA_KERNEL_LAUNCH_CHECK):
//   1. flash_bwd_dot_do_o_kernel<false, Kernel_traits>
//        grid (ceil(seqlen_q / kBlockM), b, h), no dynamic shared memory
//   2. flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel
//        grid (ceil(seqlen_k / kBlockN), b, h), dynamic smem = Kernel_traits::kSmemSize1colblock
//   3. flash_bwd_convert_dq_kernel
//        grid (ceil(seqlen_q / kBlockM), b, h), dynamic smem = Kernel_traits::kSmemdQSize
//
// All kernels use Kernel_traits::kNThreads threads per block.
// The whole body is compiled out when FLASHATTENTION_DISABLE_BACKWARD is defined.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    constexpr int kTileM = Kernel_traits::kBlockM;
    constexpr int kTileN = Kernel_traits::kBlockN;

    // Ceil-divide the sequence lengths by the tile sizes to get the block counts.
    const int m_blocks = (params.seqlen_q + kTileM - 1) / kTileM;
    const int n_blocks = (params.seqlen_k + kTileN - 1) / kTileN;

    // The SM-count-based sizing of the column grid for params.deterministic is disabled in this
    // version, so the column grid always has one block per key/value tile.
    const dim3 row_grid(m_blocks, params.b, params.h);
    const dim3 col_grid(n_blocks, params.b, params.h);

    flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<row_grid, Kernel_traits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Specialize on "even in both M and N" rather than M alone: when seqlen_k is not a multiple
    // of kBlockN the kernel has to apply a mask inside its loop.
    const bool fixed_length = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr;
    const bool is_even_MN = fixed_length && params.seqlen_q % kTileM == 0 && params.seqlen_k % kTileN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    constexpr int kSmemMain = Kernel_traits::kSmemSize1colblock;

    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                        // Template flags are collapsed to limit the number of instantiations:
                        //  - dropout is switched off whenever softcap is on;
                        //  - local is switched off whenever causal is on;
                        //  - the even-MN variant is only used together with even-K, without
                        //    local attention, and for head dims up to 128.
                        constexpr static bool kDropout = Is_dropout && !Is_softcap;
                        constexpr static bool kLocal = Is_local && !Is_causal;
                        constexpr static bool kEvenMN = IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128;
                        auto main_kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<
                            Kernel_traits, kDropout, Is_causal, kLocal, Has_alibi, kEvenMN, IsEvenKConst, Is_softcap>;
                        main_kernel<<<col_grid, Kernel_traits::kNThreads, kSmemMain, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });

    // The conversion kernel is told there is a single split unless params.deterministic is set,
    // in which case it receives the column-grid width. The cudaFuncSetAttribute opt-in for
    // >= 48 KB of dynamic shared memory is not performed in this version.
    const int nsplits = params.deterministic ? n_blocks : 1;
    auto convert_kernel = &flash_bwd_convert_dq_kernel<Kernel_traits>;
    convert_kernel<<<row_grid, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, nsplits);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
#endif
}
// Host launcher for the "separate" backward path: dK/dV and dQ are computed by two kernels.
//
// Launch sequence on `stream` (each launch is followed by C10_CUDA_KERNEL_LAUNCH_CHECK):
//   1. flash_bwd_dot_do_o_kernel<false, Kernel_traits>, grid (ceil(seqlen_q / kBlockM), b, h)
//   2. the dK/dV kernel, grid (ceil(seqlen_k / kBlockN), b, h):
//        BWDTRANS defined   -> flash_bwd_dk_dv_trans_loop_seqk_parallel_kernel<Kernel_trans_traits, ...>
//                              with Kernel_trans_traits::kSmemSizeTrans1colblock dynamic smem
//        BWDTRANS undefined -> flash_bwd_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, ...>
//                              with Kernel_traits::kSmemSize1colblock dynamic smem
//   3. flash_bwd_dq_loop_seqq_parallel_kernel<Kernel_traits, ...>,
//        grid (ceil(seqlen_q / kBlockM), b, h), Kernel_traits::kSmemSize1rowblock dynamic smem
//
// Block counts, thread counts and the head-dim test all come from Kernel_traits.
// NOTE(review): with BWDTRANS the dK/dV kernel is instantiated with Kernel_trans_traits but still
// launched with Kernel_traits::kNThreads -- confirm both traits agree on the thread count.
// The whole body is compiled out when FLASHATTENTION_DISABLE_BACKWARD is defined.
template<typename Kernel_traits, typename Kernel_trans_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd_separate_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    constexpr int kTileM = Kernel_traits::kBlockM;
    constexpr int kTileN = Kernel_traits::kBlockN;

    // Ceil-divide the sequence lengths by the tile sizes to get the block counts.
    const int m_blocks = (params.seqlen_q + kTileM - 1) / kTileM;
    const int n_blocks = (params.seqlen_k + kTileN - 1) / kTileN;

    // SM-count-based grid sizing for params.deterministic is disabled in this version.
    const dim3 row_grid(m_blocks, params.b, params.h);
    const dim3 col_grid(n_blocks, params.b, params.h);

    // Always the <false> (Clear_dQaccum == false) variant; the deterministic-dependent choice
    // between <true> and <false> is disabled in this version.
    flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<row_grid, Kernel_traits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Specialize on "even in both M and N" rather than M alone: when seqlen_k is not a multiple
    // of kBlockN the kernel has to apply a mask inside its loop.
    const bool fixed_length = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr;
    const bool is_even_MN = fixed_length && params.seqlen_q % kTileM == 0 && params.seqlen_k % kTileN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;

#if defined(BWDTRANS)
    constexpr int kSmemDkDv = Kernel_trans_traits::kSmemSizeTrans1colblock;
#else
    constexpr int kSmemDkDv = Kernel_traits::kSmemSize1colblock;
#endif
    constexpr int kSmemDq = Kernel_traits::kSmemSize1rowblock;

    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                        // Template flags are collapsed to limit the number of instantiations:
                        //  - dropout is switched off whenever softcap is on;
                        //  - local is switched off whenever causal is on;
                        //  - the even-MN variant is only used together with even-K, without
                        //    local attention, and for head dims up to 128.
                        constexpr static bool kDropout = Is_dropout && !Is_softcap;
                        constexpr static bool kLocal = Is_local && !Is_causal;
                        constexpr static bool kEvenMN = IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128;

#if defined(BWDTRANS)
                        auto dkdv_kernel = &flash_bwd_dk_dv_trans_loop_seqk_parallel_kernel<
                            Kernel_trans_traits, kDropout, Is_causal, kLocal, Has_alibi, kEvenMN, IsEvenKConst, Is_softcap>;
#else
                        auto dkdv_kernel = &flash_bwd_dk_dv_loop_seqk_parallel_kernel<
                            Kernel_traits, kDropout, Is_causal, kLocal, Has_alibi, kEvenMN, IsEvenKConst, Is_softcap>;
#endif
                        dkdv_kernel<<<col_grid, Kernel_traits::kNThreads, kSmemDkDv, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();

                        auto dq_kernel = &flash_bwd_dq_loop_seqq_parallel_kernel<
                            Kernel_traits, kDropout, Is_causal, kLocal, Has_alibi, kEvenMN, IsEvenKConst, Is_softcap>;
                        dq_kernel<<<row_grid, Kernel_traits::kNThreads, kSmemDq, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });
#endif
}
// Host launcher for the separate backward path using the 16x64 "trans" dK/dV kernel.
//
// Launch sequence on `stream` (each launch is followed by C10_CUDA_KERNEL_LAUNCH_CHECK):
//   1. flash_bwd_dot_do_o_kernel<false, Kernel_traits>, grid (ceil(seqlen_q / kBlockM), b, h)
//   2. flash_bwd_dk_dv_trans_16x64_loop_seqk_parallel_kernel<Kernel_trans_traits, ...>,
//        grid (ceil(seqlen_k / kBlockN), h, b), Kernel_trans_traits::kSmemSizeTrans1colblock smem
//   3. flash_bwd_dq_loop_16x64_seqq_parallel_kernel<Kernel_traits, ...>,
//        grid (ceil(seqlen_q / kBlockM), h, b), Kernel_traits::kSmemSize1rowblock smem
//
// Note that the two main kernels use (x, h, b) grids while the dot(dO, O) kernel uses (x, b, h).
// Block counts and thread counts come from Kernel_traits; the even-MN test is evaluated once per
// traits type because the two may use different tile sizes.
// NOTE(review): the dK/dV kernel is instantiated with Kernel_trans_traits but launched with
// Kernel_traits::kNThreads and a column count based on Kernel_traits::kBlockN -- confirm both
// traits agree on these.
// The whole body is compiled out when FLASHATTENTION_DISABLE_BACKWARD is defined.
template<typename Kernel_traits, typename Kernel_trans_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd_separate_seqk_parallel_trans(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    constexpr int kTileM = Kernel_traits::kBlockM;
    constexpr int kTileN = Kernel_traits::kBlockN;

    // Ceil-divide the sequence lengths by the tile sizes to get the block counts.
    const int m_blocks = (params.seqlen_q + kTileM - 1) / kTileM;
    const int n_blocks = (params.seqlen_k + kTileN - 1) / kTileN;

    // SM-count-based grid sizing for params.deterministic is disabled in this version.
    const dim3 dot_grid(m_blocks, params.b, params.h);
    const dim3 row_grid(m_blocks, params.h, params.b);
    const dim3 col_grid(n_blocks, params.h, params.b);

    flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<dot_grid, Kernel_traits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Specialize on "even in both M and N" rather than M alone: when seqlen_k is not a multiple
    // of kBlockN the kernel has to apply a mask inside its loop.
    const bool fixed_length = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr;
    const bool is_even_MN = fixed_length
        && params.seqlen_q % kTileM == 0
        && params.seqlen_k % kTileN == 0;
    const bool is_even_MN_trans = fixed_length
        && params.seqlen_q % Kernel_trans_traits::kBlockM == 0
        && params.seqlen_k % Kernel_trans_traits::kBlockN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;

    // The dK/dV shared-memory size always comes from the trans traits here (the alternative
    // Kernel_traits::kSmemSize1colblock choice was permanently compiled out).
    constexpr int kSmemDkDv = Kernel_trans_traits::kSmemSizeTrans1colblock;
    constexpr int kSmemDq = Kernel_traits::kSmemSize1rowblock;

    BOOL_SWITCH(is_even_MN_trans, IsEvenMNTransConst, [&] {
        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                        SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                            // Template flags are collapsed to limit the number of instantiations:
                            //  - dropout is switched off whenever softcap is on;
                            //  - local is switched off whenever causal is on;
                            //  - the even-MN variants are only used together with even-K,
                            //    without local attention, and for head dims up to 128.
                            constexpr static bool kDropout = Is_dropout && !Is_softcap;
                            constexpr static bool kLocal = Is_local && !Is_causal;
                            constexpr static bool kEvenMNDkDv = IsEvenMNTransConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128;
                            constexpr static bool kEvenMNDq = IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128;

                            auto dkdv_kernel = &flash_bwd_dk_dv_trans_16x64_loop_seqk_parallel_kernel<
                                Kernel_trans_traits, kDropout, Is_causal, kLocal, Has_alibi, kEvenMNDkDv, IsEvenKConst, Is_softcap>;
                            dkdv_kernel<<<col_grid, Kernel_traits::kNThreads, kSmemDkDv, stream>>>(params);
                            C10_CUDA_KERNEL_LAUNCH_CHECK();

                            auto dq_kernel = &flash_bwd_dq_loop_16x64_seqq_parallel_kernel<
                                Kernel_traits, kDropout, Is_causal, kLocal, Has_alibi, kEvenMNDq, IsEvenKConst, Is_softcap>;
                            dq_kernel<<<row_grid, Kernel_traits::kNThreads, kSmemDq, stream>>>(params);
                            C10_CUDA_KERNEL_LAUNCH_CHECK();
                        });
                    });
                });
            });
        });
    });
#endif
}
// Host launcher for the separate backward path using the 16x64 prefetch kernels.
//
// Launch sequence on `stream` (each launch is followed by C10_CUDA_KERNEL_LAUNCH_CHECK):
//   1. flash_bwd_dot_do_o_kernel<false, Kernel_traits>, grid (ceil(seqlen_q / kBlockM), b, h)
//   2. flash_bwd_dk_dv_trans_16x64_prefetch<Kernel_trans_traits, ...>, grid (n_blocks, h, b)
//   3. flash_bwd_dq_loop_16x64_prefetch_seqq_parallel_kernel<Kernel_traits, ...>, grid (m_blocks, h, b)
//
// Unless NO_CAUSAL_OPT is defined, causal launches use half as many blocks (rounded up) along x:
// always for the dQ grid, and for the dK/dV grid only when Kernel_trans_traits::kHeadDim != 64.
// The dot(dO, O) grid always uses the full row-block count.
// NOTE(review): the dK/dV kernel is instantiated with Kernel_trans_traits but launched with
// Kernel_traits::kNThreads and a column count based on Kernel_traits::kBlockN -- confirm both
// traits agree on these.
// The whole body is compiled out when FLASHATTENTION_DISABLE_BACKWARD is defined.
template<typename Kernel_traits, typename Kernel_trans_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd_separate_prefetch(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    constexpr int kTileM = Kernel_traits::kBlockM;
    constexpr int kTileN = Kernel_traits::kBlockN;

    // Full (non-causal) block counts: ceil-divide the sequence lengths by the tile sizes.
    const int full_m_blocks = (params.seqlen_q + kTileM - 1) / kTileM;
    const int full_n_blocks = (params.seqlen_k + kTileN - 1) / kTileN;
#ifdef NO_CAUSAL_OPT
    const int m_blocks = full_m_blocks;
    const int n_blocks = full_n_blocks;
#else
    const int m_blocks = Is_causal ? (full_m_blocks + 1) >> 1 : full_m_blocks;
    const int n_blocks = (Is_causal && Kernel_trans_traits::kHeadDim != 64) ? (full_n_blocks + 1) >> 1 : full_n_blocks;
#endif

    const dim3 dot_grid(full_m_blocks, params.b, params.h);
    const dim3 row_grid(m_blocks, params.h, params.b);
    const dim3 col_grid(n_blocks, params.h, params.b);

    flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<dot_grid, Kernel_traits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    const bool fixed_length = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr;
    const bool is_even_MN = fixed_length
        && params.seqlen_q % kTileM == 0
        && params.seqlen_k % kTileN == 0;
    const bool is_even_MN_trans = fixed_length
        && params.seqlen_q % Kernel_trans_traits::kBlockM == 0
        && params.seqlen_k % Kernel_trans_traits::kBlockN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;

    // Dynamic shared-memory sizes. For head dim 128 the dK/dV kernel additionally gets
    // kBlockM * kBlockN (of the trans traits) on top of its prefetch size.
    constexpr int kSmemDropout = Kernel_trans_traits::kBlockM * Kernel_trans_traits::kBlockN;
    constexpr int kSmemDkDvBase = Kernel_trans_traits::kSmemPrefetchSize;
    constexpr int kSmemDkDv = (Kernel_trans_traits::kHeadDim == 128) ? (kSmemDkDvBase + kSmemDropout) : (kSmemDkDvBase);
    constexpr int kSmemDq = Kernel_traits::kSmemPrefetchSize;

    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                        // Template flags shared by both kernels:
                        //  - dropout is switched off whenever softcap is on;
                        //  - local is switched off whenever causal is on.
                        constexpr static bool kDropout = Is_dropout && !Is_softcap;
                        constexpr static bool kLocal = Is_local && !Is_causal;

                        // The dK/dV kernel is specialized on the even-MN test of the trans traits.
                        BOOL_SWITCH(is_even_MN_trans, IsEvenMNTransConst, [&] {
                            constexpr static bool kEvenMNDkDv = IsEvenMNTransConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128;
                            auto dkdv_kernel = &flash_bwd_dk_dv_trans_16x64_prefetch<
                                Kernel_trans_traits, kDropout, Is_causal, kLocal, Has_alibi, kEvenMNDkDv, IsEvenKConst, Is_softcap>;
                            dkdv_kernel<<<col_grid, Kernel_traits::kNThreads, kSmemDkDv, stream>>>(params);
                            C10_CUDA_KERNEL_LAUNCH_CHECK();
                        });

                        constexpr static bool kEvenMNDq = IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128;
                        auto dq_kernel = &flash_bwd_dq_loop_16x64_prefetch_seqq_parallel_kernel<
                            Kernel_traits, kDropout, Is_causal, kLocal, Has_alibi, kEvenMNDq, IsEvenKConst, Is_softcap>;
                        dq_kernel<<<row_grid, Kernel_traits::kNThreads, kSmemDq, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });
#endif
}
// Host launcher for the separate backward path using the MLA prefetch dK/dV kernel.
//
// Launch sequence on `stream` (each launch is followed by C10_CUDA_KERNEL_LAUNCH_CHECK):
//   1. flash_bwd_dot_do_o_kernel<false, Kernel_traits>, grid (ceil(seqlen_q / kBlockM), b, h)
//   2. flash_bwd_dk_dv_trans_16x64_mla_prefetch<Kernel_trans_traits, ...>,
//        grid (ceil(seqlen_k / kBlockN), h, b), Kernel_trans_traits::kSmemPrefetchSize smem
//   3. flash_bwd_dq_loop_16x64_prefetch_seqq_parallel_kernel<Kernel_traits, ...>,
//        grid (ceil(seqlen_q / kBlockM), h, b), Kernel_traits::kSmemPrefetchSize smem
//
// The causal grid-halving used by the non-MLA prefetch launcher is disabled here: both grids
// always cover the full block counts.
// NOTE(review): the dK/dV kernel is instantiated with Kernel_trans_traits but launched with
// Kernel_traits::kNThreads and a column count based on Kernel_traits::kBlockN -- confirm both
// traits agree on these.
// The whole body is compiled out when FLASHATTENTION_DISABLE_BACKWARD is defined.
template<typename Kernel_traits, typename Kernel_trans_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd_separate_mla_prefetch(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    constexpr int kTileM = Kernel_traits::kBlockM;
    constexpr int kTileN = Kernel_traits::kBlockN;

    // Ceil-divide the sequence lengths by the tile sizes to get the block counts.
    const int m_blocks = (params.seqlen_q + kTileM - 1) / kTileM;
    const int n_blocks = (params.seqlen_k + kTileN - 1) / kTileN;

    const dim3 dot_grid(m_blocks, params.b, params.h);
    const dim3 row_grid(m_blocks, params.h, params.b);
    const dim3 col_grid(n_blocks, params.h, params.b);

    flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<dot_grid, Kernel_traits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    const bool fixed_length = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr;
    const bool is_even_MN = fixed_length
        && params.seqlen_q % kTileM == 0
        && params.seqlen_k % kTileN == 0;
    const bool is_even_MN_trans = fixed_length
        && params.seqlen_q % Kernel_trans_traits::kBlockM == 0
        && params.seqlen_k % Kernel_trans_traits::kBlockN == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;

    constexpr int kSmemDkDv = Kernel_trans_traits::kSmemPrefetchSize;
    constexpr int kSmemDq = Kernel_traits::kSmemPrefetchSize;

    BOOL_SWITCH(is_even_MN_trans, IsEvenMNTransConst, [&] {
        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                        SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
                            // Template flags are collapsed to limit the number of instantiations:
                            //  - dropout is switched off whenever softcap is on;
                            //  - local is switched off whenever causal is on;
                            //  - the even-MN variants are only used together with even-K,
                            //    without local attention, and for head dims up to 128.
                            constexpr static bool kDropout = Is_dropout && !Is_softcap;
                            constexpr static bool kLocal = Is_local && !Is_causal;
                            constexpr static bool kEvenMNDkDv = IsEvenMNTransConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128;
                            constexpr static bool kEvenMNDq = IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128;

                            auto dkdv_kernel = &flash_bwd_dk_dv_trans_16x64_mla_prefetch<
                                Kernel_trans_traits, kDropout, Is_causal, kLocal, Has_alibi, kEvenMNDkDv, IsEvenKConst, Is_softcap>;
                            dkdv_kernel<<<col_grid, Kernel_traits::kNThreads, kSmemDkDv, stream>>>(params);
                            C10_CUDA_KERNEL_LAUNCH_CHECK();

                            auto dq_kernel = &flash_bwd_dq_loop_16x64_prefetch_seqq_parallel_kernel<
                                Kernel_traits, kDropout, Is_causal, kLocal, Has_alibi, kEvenMNDq, IsEvenKConst, Is_softcap>;
                            dq_kernel<<<row_grid, Kernel_traits::kNThreads, kSmemDq, stream>>>(params);
                            C10_CUDA_KERNEL_LAUNCH_CHECK();
                        });
                    });
                });
            });
        });
    });
#endif
}
// Entry point for the single-traits backward path: forwards to run_flash_bwd_seqk_parallel with
// the same template arguments. Does nothing when FLASHATTENTION_DISABLE_BACKWARD is defined.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
#if !defined(FLASHATTENTION_DISABLE_BACKWARD)
    run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout, Is_causal>(params, stream);
#endif
}
// Entry point for the two-traits ("separate") backward path: forwards to
// run_flash_bwd_separate_seqk_parallel, passing Kernel_dq_traits as its first traits argument and
// Kernel_dkdv_traits as its second. Does nothing when FLASHATTENTION_DISABLE_BACKWARD is defined.
template<typename Kernel_dq_traits, typename Kernel_dkdv_traits, bool Is_dropout, bool Is_causal>
void run_flash_separate_bwd(Flash_bwd_params &params, cudaStream_t stream) {
#if !defined(FLASHATTENTION_DISABLE_BACKWARD)
    run_flash_bwd_separate_seqk_parallel<Kernel_dq_traits, Kernel_dkdv_traits, Is_dropout, Is_causal>(params, stream);
#endif
}
// Backward launcher for head dimension 32.
//
// Dispatches on whether dropout is active (params.p_dropout < 1.f) and then runs one fixed
// kernel configuration:
//   - with BWDSEPARATE defined: the separate dQ / dK-dV path (run_flash_separate_bwd) with a
//     64x64 dQ traits type and a 32x64 trans traits type;
//   - otherwise: the single-traits path (run_flash_bwd) with a 64x32 traits type.
// An earlier variant that chose the configuration from cudaDevAttrMaxSharedMemoryPerBlockOptin
// is not used in this version.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int kHeadDim = 32;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
#ifdef BWDSEPARATE
        using DqTraits = Flash_bwd_kernel_dq_traits<
            kHeadDim,
            /*kBlockM_*/64,
            /*kBlockN_*/64,
            /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/4,
            /*AtomLayoutNdKV*/4,
            /*AtomLayoutMdQ*/4,
            /*Is_V_in_regs_*/false,
            /*No_double_buffer_*/true,
            /*Is_Q_in_regs_*/false,
            /*Share_Q_K_smem_*/true,
            T>;
        using DkDvTraits = Flash_bwd_kernel_trans_traits<
            kHeadDim,
            /*kBlockM_*/32,
            /*kBlockN_*/64,
            /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/4,
            /*AtomLayoutNdKV*/4,
            /*AtomLayoutMdQ*/1,
            /*Is_V_in_regs_*/true,
            /*No_double_buffer_*/true,
            T>;
        run_flash_separate_bwd<DqTraits, DkDvTraits, Is_dropout, Is_causal>(params, stream);
#else
        using Traits = Flash_bwd_kernel_traits<
            kHeadDim,
            /*kBlockM_*/64,
            /*kBlockN_*/32,
            /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/2,
            /*AtomLayoutNdKV*/2,
            /*AtomLayoutMdQ*/1,
            /*Is_V_in_regs_*/true,
            /*No_double_buffer_*/true,
            T>;
        run_flash_bwd<Traits, Is_dropout, Is_causal>(params, stream);
#endif
    });
}
// Launches the backward pass for head dimension 64.
//
// DROPOUT_SWITCH hands the lambda the outcome of `params.p_dropout < 1.f` as
// the compile-time flag Is_dropout. Devices reporting the name "gfx936" or
// "gfx938" take the prefetch trait sets and run_flash_bwd_separate_prefetch;
// every other device takes the 16x64 trait sets and
// run_flash_bwd_separate_seqk_parallel_trans.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
    static constexpr int kHdim = 64;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        const bool use_prefetch_path =
            get_device_name() == "gfx936" || get_device_name() == "gfx938";
        if (use_prefetch_path) {
            using DqTraits = Flash_bwd_kernel_dq_16x64_prefetch_traits<
                kHdim, /*kBlockM_*/128, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4,
                /*Is_V_in_regs_*/false, /*No_double_buffer_*/true,
                /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T, 3>;
            using TransTraits = Flash_bwd_kernel_trans_16x64_prefetch_traits<
                kHdim, /*kBlockM_*/64, /*kBlockN_*/128, /*kNWarps_*/4, T, 3>;
            run_flash_bwd_separate_prefetch<DqTraits, TransTraits, Is_dropout, Is_causal>(params, stream);
        } else {
            using TransTraits = Flash_bwd_kernel_trans_16x64_traits<
                kHdim, /*kBlockM_*/64, /*kBlockN_*/128, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/1,
                /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
            using DqTraits = Flash_bwd_kernel_dq_16x64_traits<
                kHdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4,
                /*Is_V_in_regs_*/false, /*No_double_buffer_*/true,
                /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T>;
            run_flash_bwd_separate_seqk_parallel_trans<DqTraits, TransTraits, Is_dropout, Is_causal>(params, stream);
        }
    });
}
// Launches the backward pass for head dimension 96.
//
// DROPOUT_SWITCH hands the lambda the outcome of `params.p_dropout < 1.f` as
// the compile-time flag Is_dropout. Devices reporting the name "gfx936" or
// "gfx938" use the dim96 prefetch trait sets. Other devices use either the
// split path (BWDSEPARATE builds) or a single Flash_bwd_kernel_traits
// configuration.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
    static constexpr int kHdim = 96;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        const bool use_prefetch_path =
            get_device_name() == "gfx936" || get_device_name() == "gfx938";
        if (use_prefetch_path) {
            using DqTraits = Flash_bwd_kernel_dq_16x64_prefetch_traits_dim96<
                kHdim, /*kBlockM_*/128, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4, T, 3>;
            using TransTraits = Flash_bwd_kernel_trans_16x64_prefetch_traits_dim96<
                kHdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3>;
            run_flash_bwd_separate_prefetch<DqTraits, TransTraits, Is_dropout, Is_causal>(params, stream);
        } else {
#ifdef BWDSEPARATE
            using TransTraits = Flash_bwd_kernel_trans_traits<
                kHdim, /*kBlockM_*/32, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/1,
                /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
            using DqTraits = Flash_bwd_kernel_dq_traits<
                kHdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/4,
                /*Is_V_in_regs_*/false, /*No_double_buffer_*/true,
                /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T>;
            run_flash_separate_bwd<DqTraits, TransTraits, Is_dropout, Is_causal>(params, stream);
#else
            using BwdTraits = Flash_bwd_kernel_traits<
                kHdim, /*kBlockM_*/64, /*kBlockN_*/32, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/2, /*AtomLayoutMdQ*/1,
                /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
            run_flash_bwd<BwdTraits, Is_dropout, Is_causal>(params, stream);
#endif
        }
    });
}
// Launches the backward pass for head dimension 128.
//
// DROPOUT_SWITCH hands the lambda the outcome of `params.p_dropout < 1.f` as
// the compile-time flag Is_dropout. Devices reporting the name "gfx936" or
// "gfx938" use the prefetch trait sets; every other device uses the 16x64
// trait sets. In both cases only the dQ kBlockM value varies, and it is
// computed up front as a compile-time constant.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    static constexpr int kHdim = 128;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        const bool use_prefetch_path =
            get_device_name() == "gfx936" || get_device_name() == "gfx938";
        if (use_prefetch_path) {
            // kBlockM is 64 only when Is_dropout and Is_causal are both true; 128 otherwise.
            constexpr int kDqBlockM = (Is_dropout && Is_causal) ? 64 : 128;
            using DqTraits = Flash_bwd_kernel_dq_16x64_prefetch_traits<
                kHdim, /*kBlockM_*/kDqBlockM, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4,
                /*Is_V_in_regs_*/false, /*No_double_buffer_*/true,
                /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T, 3>;
            using TransTraits = Flash_bwd_kernel_trans_16x64_prefetch_traits<
                kHdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3>;
            run_flash_bwd_separate_prefetch<DqTraits, TransTraits, Is_dropout, Is_causal>(params, stream);
        } else {
            // bfloat16 uses kBlockM = 64 for the dQ kernel; any other element type uses 128.
            constexpr int kDqBlockM = std::is_same_v<T, cutlass::bfloat16_t> ? 64 : 128;
            using DqTraits = Flash_bwd_kernel_dq_16x64_traits<
                kHdim, /*kBlockM_*/kDqBlockM, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4,
                /*Is_V_in_regs_*/false, /*No_double_buffer_*/true,
                /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T>;
            using TransTraits = Flash_bwd_kernel_trans_16x64_traits<
                kHdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/1,
                /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
            run_flash_bwd_separate_seqk_parallel_trans<DqTraits, TransTraits, Is_dropout, Is_causal>(params, stream);
        }
    });
}
// Launches the backward pass for head dimension 160.
//
// DROPOUT_SWITCH hands the lambda the outcome of `params.p_dropout < 1.f` as
// the compile-time flag Is_dropout. Building with BWDSEPARATE selects the
// split path (run_flash_separate_bwd with separate dQ / "trans" trait sets);
// otherwise a single Flash_bwd_kernel_traits configuration is dispatched
// through run_flash_bwd.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
    static constexpr int kHdim = 160;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
#ifdef BWDSEPARATE
        using TransTraits = Flash_bwd_kernel_trans_traits<
            kHdim, /*kBlockM_*/32, /*kBlockN_*/64, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/1,
            /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
        using DqTraits = Flash_bwd_kernel_dq_traits<
            kHdim, /*kBlockM_*/32, /*kBlockN_*/32, /*kNWarps_*/2,
            /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/2, /*AtomLayoutMdQ*/2,
            /*Is_V_in_regs_*/false, /*No_double_buffer_*/true,
            /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T>;
        run_flash_separate_bwd<DqTraits, TransTraits, Is_dropout, Is_causal>(params, stream);
#else
        using BwdTraits = Flash_bwd_kernel_traits<
            kHdim, /*kBlockM_*/64, /*kBlockN_*/32, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/2, /*AtomLayoutMdQ*/4,
            /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
        run_flash_bwd<BwdTraits, Is_dropout, Is_causal>(params, stream);
#endif
    });
}
// Launches the backward pass for the mixed head-dimension configuration
// (Headdim = 192 with a trailing trait argument of 128 -- presumably the V
// head dimension; confirm against the traits definitions).
//
// DROPOUT_SWITCH hands the lambda the outcome of `params.p_dropout < 1.f` as
// the compile-time flag Is_dropout. Only devices reporting the name "gfx936"
// or "gfx938" have a kernel for this configuration. On any other device this
// function raises an error via TORCH_CHECK instead of returning silently
// without launching a kernel, which would leave the gradient buffers
// unwritten.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim192_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 192;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if (get_device_name() == "gfx936" || get_device_name() == "gfx938") {
            using kernel_trans_traits = Flash_bwd_kernel_trans_16x64_prefetch_mla_traits<
                Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4, T, 3, 128>;
            using kernel_traits = Flash_bwd_kernel_dq_16x64_prefetch_traits<
                Headdim, /*kBlockM_*/64, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/4, /*Is_V_in_regs_*/false,
                /*No_double_buffer_*/true, /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T, 3, 128>;
            run_flash_bwd_separate_mla_prefetch<kernel_traits, kernel_trans_traits, Is_dropout, Is_causal>(params, stream);
        } else {
            // No kernel exists for this configuration on other devices. Fail
            // loudly: a silent return would hand back garbage gradients.
            TORCH_CHECK(false,
                        "FlashAttention backward with headdim 192/128 is only supported on "
                        "gfx936 and gfx938 devices, but the current device is ",
                        get_device_name());
        }
    });
}
// Launches the backward pass for head dimension 192.
//
// DROPOUT_SWITCH hands the lambda the outcome of `params.p_dropout < 1.f` as
// the compile-time flag Is_dropout. Building with BWDSEPARATE selects the
// split path (run_flash_separate_bwd with separate dQ / "trans" trait sets);
// otherwise a single Flash_bwd_kernel_traits configuration is dispatched
// through run_flash_bwd.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
    static constexpr int kHdim = 192;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
#ifdef BWDSEPARATE
        using TransTraits = Flash_bwd_kernel_trans_traits<
            kHdim, /*kBlockM_*/32, /*kBlockN_*/64, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/1,
            /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
        using DqTraits = Flash_bwd_kernel_dq_traits<
            kHdim, /*kBlockM_*/32, /*kBlockN_*/32, /*kNWarps_*/2,
            /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/2,
            /*Is_V_in_regs_*/false, /*No_double_buffer_*/true,
            /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T>;
        run_flash_separate_bwd<DqTraits, TransTraits, Is_dropout, Is_causal>(params, stream);
#else
        using BwdTraits = Flash_bwd_kernel_traits<
            kHdim, /*kBlockM_*/64, /*kBlockN_*/32, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/2, /*AtomLayoutMdQ*/4,
            /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
        run_flash_bwd<BwdTraits, Is_dropout, Is_causal>(params, stream);
#endif
    });
}
// Launches the backward pass for head dimension 224.
//
// DROPOUT_SWITCH hands the lambda the outcome of `params.p_dropout < 1.f` as
// the compile-time flag Is_dropout. Building with BWDSEPARATE selects the
// split path (run_flash_separate_bwd with separate dQ / "trans" trait sets);
// otherwise a single Flash_bwd_kernel_traits configuration is dispatched
// through run_flash_bwd.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream) {
    static constexpr int kHdim = 224;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
#ifdef BWDSEPARATE
        using TransTraits = Flash_bwd_kernel_trans_traits<
            kHdim, /*kBlockM_*/32, /*kBlockN_*/64, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/1,
            /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
        using DqTraits = Flash_bwd_kernel_dq_traits<
            kHdim, /*kBlockM_*/32, /*kBlockN_*/32, /*kNWarps_*/2,
            /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/2, /*AtomLayoutMdQ*/2,
            /*Is_V_in_regs_*/false, /*No_double_buffer_*/true,
            /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T>;
        run_flash_separate_bwd<DqTraits, TransTraits, Is_dropout, Is_causal>(params, stream);
#else
        using BwdTraits = Flash_bwd_kernel_traits<
            kHdim, /*kBlockM_*/64, /*kBlockN_*/32, /*kNWarps_*/4,
            /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/2, /*AtomLayoutMdQ*/4,
            /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
        run_flash_bwd<BwdTraits, Is_dropout, Is_causal>(params, stream);
#endif
    });
}
// Launches the backward pass for head dimension 256.
//
// DROPOUT_SWITCH hands the lambda the outcome of `params.p_dropout < 1.f` as
// the compile-time flag Is_dropout. Devices reporting the name "gfx936" or
// "gfx938" use the dim256 prefetch trait sets. Other devices use either the
// split path (BWDSEPARATE builds) or a single Flash_bwd_kernel_traits
// configuration.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
    static constexpr int kHdim = 256;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        const bool use_prefetch_path =
            get_device_name() == "gfx936" || get_device_name() == "gfx938";
        if (use_prefetch_path) {
            using TransTraits = Flash_bwd_kernel_trans_16x64_prefetch_traits_dim256<kHdim, 64, 64, 4, T, 3>;
            using DqTraits = Flash_bwd_kernel_dq_16x64_prefetch_traits_dim256<kHdim, 64, 64, 4, T, 3>;
            run_flash_bwd_separate_prefetch<DqTraits, TransTraits, Is_dropout, Is_causal>(params, stream);
        } else {
#ifdef BWDSEPARATE
            using TransTraits = Flash_bwd_kernel_trans_traits<
                kHdim, /*kBlockM_*/32, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/1,
                /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
            using DqTraits = Flash_bwd_kernel_dq_traits<
                kHdim, /*kBlockM_*/32, /*kBlockN_*/16, /*kNWarps_*/2,
                /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/2,
                /*Is_V_in_regs_*/false, /*No_double_buffer_*/true,
                /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T>;
            run_flash_separate_bwd<DqTraits, TransTraits, Is_dropout, Is_causal>(params, stream);
#else
            using BwdTraits = Flash_bwd_kernel_traits<
                kHdim, /*kBlockM_*/64, /*kBlockN_*/32, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/2, /*AtomLayoutMdQ*/4,
                /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
            run_flash_bwd<BwdTraits, Is_dropout, Is_causal>(params, stream);
#endif
        }
    });
}
// Launches the backward pass for head dimension 512.
//
// DROPOUT_SWITCH hands the lambda the outcome of `params.p_dropout < 1.f` as
// the compile-time flag Is_dropout. Devices reporting the name "gfx936" or
// "gfx938" use the dim512 prefetch trait sets. Other devices use either the
// split path (BWDSEPARATE builds) or a single Flash_bwd_kernel_traits
// configuration.
template<typename T, bool Is_causal>
void run_mha_bwd_hdim512(Flash_bwd_params &params, cudaStream_t stream) {
    static constexpr int kHdim = 512;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        const bool use_prefetch_path =
            get_device_name() == "gfx936" || get_device_name() == "gfx938";
        if (use_prefetch_path) {
            using TransTraits = Flash_bwd_kernel_trans_16x64_prefetch_traits_dim512<kHdim, 64, 64, 4, T, 3>;
            using DqTraits = Flash_bwd_kernel_dq_16x64_prefetch_traits_dim512<kHdim, 64, 64, 4, T, 3>;
            run_flash_bwd_separate_prefetch<DqTraits, TransTraits, Is_dropout, Is_causal>(params, stream);
        } else {
#ifdef BWDSEPARATE
            using TransTraits = Flash_bwd_kernel_trans_traits<
                kHdim, /*kBlockM_*/32, /*kBlockN_*/64, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/4, /*AtomLayoutNdKV*/4, /*AtomLayoutMdQ*/1,
                /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
            using DqTraits = Flash_bwd_kernel_dq_traits<
                kHdim, /*kBlockM_*/32, /*kBlockN_*/16, /*kNWarps_*/2,
                /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/1, /*AtomLayoutMdQ*/2,
                /*Is_V_in_regs_*/false, /*No_double_buffer_*/true,
                /*Is_Q_in_regs_*/false, /*Share_Q_K_smem_*/true, T>;
            run_flash_separate_bwd<DqTraits, TransTraits, Is_dropout, Is_causal>(params, stream);
#else
            using BwdTraits = Flash_bwd_kernel_traits<
                kHdim, /*kBlockM_*/64, /*kBlockN_*/32, /*kNWarps_*/4,
                /*AtomLayoutMSdP_*/2, /*AtomLayoutNdKV*/2, /*AtomLayoutMdQ*/4,
                /*Is_V_in_regs_*/true, /*No_double_buffer_*/true, T>;
            run_flash_bwd<BwdTraits, Is_dropout, Is_causal>(params, stream);
#endif
        }
    });
}
\ No newline at end of file
/***************************************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include "block_info.h"
#include "kernel_traits.h"
#include "utils.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
// Per-row dot product of two identically laid-out rank-3 fragments.
//
// For every row held by this thread, multiplies `do_` and `o` element-wise in
// fp32, sums the products, reduces the partial sum across the THREADS_PER_ROW
// threads via flash::Allreduce, multiplies by `scale`, and lets the thread
// whose index is a multiple of THREADS_PER_ROW store the result to
//   dP_sum(row * gdP_col_stride + threadIdx.x / THREADS_PER_ROW).
template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
                                Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
    static_assert(Layout0::rank == 3, "Only support 3D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());

    // Fold the rank-3 fragment into (mode 1, (mode 0, mode 2)) so each row's
    // values can be walked with a single column index.
    auto folded_layout = make_layout(get<1>(do_.layout()),
                                     make_layout(get<0>(do_.layout()), get<2>(do_.layout())));
    auto do_rows = make_tensor(do_.data(), folded_layout);
    auto o_rows = make_tensor(o.data(), do_rows.layout());

    auto do_f32 = flash::convert_type<float>(do_rows);
    auto o_f32 = flash::convert_type<float>(o_rows);

    flash::SumOp<float> add;
    // Loop-invariant position of this thread within its row group.
    const auto lane_in_row = threadIdx.x % THREADS_PER_ROW;
    const auto row_slot = threadIdx.x / THREADS_PER_ROW;

    #pragma unroll
    for (int row = 0; row < size<0>(do_rows); ++row) {
        // Seed with column 0, then accumulate the remaining columns in order.
        float partial = do_f32(row, 0) * o_f32(row, 0);
        #pragma unroll
        for (int col = 1; col < size<1>(do_rows); col++) {
            partial += do_f32(row, col) * o_f32(row, col);
        }
        partial = flash::Allreduce<THREADS_PER_ROW>::run(partial, add) * scale;
        if (lane_in_row == 0) {
            dP_sum(row * gdP_col_stride + row_slot) = partial;
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
//
// Work assignment: blockIdx.x selects a tile of kBlockM query rows, blockIdx.y the batch entry and
// blockIdx.z the head. Each block
//   1. loads its dO and O tiles (kBlockM x kHeadDimV) from global memory into registers, with
//      out-of-range rows cleared and columns >= params.d_value predicated off,
//   2. calls dot_do_o to write the per-row dot(dO, O) * params.p_dropout into params.dsoftmax_sum,
//   3. when Clear_dQaccum is true, writes zeros over its tile of params.dq_accum_ptr.
// Blocks whose first row lies at or beyond binfo.actual_seqlen_q return immediately.
template<bool Clear_dQaccum=true, typename Kernel_traits, typename Params>
inline __device__ void compute_dot_do_o(const Params &params) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// The block index along the query-row tiles.
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
const BlockInfo binfo(params, bidb);
// Nothing to do for tiles that start past the end of this sequence.
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
// Element offsets of this block's first row in dO, O and dQaccum.
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
+ m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
// dQaccum rows are params.h * params.d_rounded elements apart. When params.cu_seqlens_q is set,
// an extra 128 * bidb rows are skipped (see the padding note below).
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+ (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
// Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
const index_t row_offset_dpsum = (params.unpadded_lse ? (bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb): (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM;
// Global-memory views of this block's tiles. dO and O are kBlockM x kHeadDimV, dQaccum is
// kBlockM x kHeadDim, and dP_sum holds one value per row.
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(params.do_row_stride, _1{}));
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(params.o_row_stride, _1{}));
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.h * params.d_rounded, _1{}));
Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
Shape<Int<kBlockM>>{}, Stride<_1>{});
typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
// TODO: careful, we're zeroing out dQaccum with type float4, but when
// we do atomicAdds, we use type float. The layouts are different. Check this.
typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
// NOTE(review): this `#if 1` region is always compiled in; it looks like a leftover debugging toggle.
#if 1
// Per-thread slices of the global tiles. The same dO copy object partitions both dO and O.
Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
// Identity tensor: gives the (row, column) coordinate of every element this thread copies.
Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);
// Allocate predicate tensors for k
Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdOgdO)));
// Set predicates for k bounds: a column is valid only if its index is below params.d_value.
#pragma unroll
for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d_value;}
// Register fragments that receive the dO and O tiles.
Tensor tdOrdO = make_fragment_like(tdOgdO);
Tensor tdOrO = make_fragment_like(tdOgO);
// The last argument is the number of valid rows left in the sequence from this tile's first row.
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
);
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
);
// By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
// results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
// so that (dP - dP_sum) is on the same scale.
// The stride argument is kNThreads / kGmemThreadsPerRow, i.e. the number of thread groups of
// kGmemThreadsPerRow threads in the block.
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
// Clear_dQaccum is a template parameter, so this test is resolved at compile time.
if (Clear_dQaccum) {
// We're actually not zero'ing out all of dQaccum, but only the part that we're going to
// do atomicAdds on.
Tensor zero = make_fragment_like(tdQgdQaccum);
clear(zero);
cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
}
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Zero-fills one (kBlockN x kHeadDim) tile of both the dK and dV accumulation
// buffers (params.dk_accum_ptr / params.dv_accum_ptr).
//
// blockIdx.x selects the tile of kBlockN key rows, blockIdx.y the batch entry
// and blockIdx.z the head. Tiles that start at or beyond
// binfo.actual_seqlen_k are skipped.
template<typename Kernel_traits, typename Params>
inline __device__ void clear_dKVaccum(const Params &params) {
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const int kv_tile = blockIdx.x;
    const int batch = blockIdx.y;
    const int head = blockIdx.z;
    const int tid = threadIdx.x;

    const BlockInfo binfo(params, batch);
    if (kv_tile * kBlockN >= binfo.actual_seqlen_k) {
        return;
    }

    // Both accumulators share the same indexing, so one offset serves dK and dV.
    const index_t accum_offset =
        ((batch * params.h_k + head) * params.seqlen_k_rounded + kv_tile * kBlockN) * params.d_rounded;

    using TileShape = Shape<Int<kBlockN>, Int<kHeadDim>>;
    using TileStride = Stride<Int<kHeadDim>, _1>;
    Tensor dk_tile = make_tensor(
        make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + accum_offset),
        TileShape{}, TileStride{});
    Tensor dv_tile = make_tensor(
        make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + accum_offset),
        TileShape{}, TileStride{});

    typename Kernel_traits::GmemTiledCopydQaccum tiled_copy;
    auto thr_copy = tiled_copy.get_thread_slice(tid);
    Tensor dk_dst = thr_copy.partition_D(dk_tile);
    Tensor dv_dst = thr_copy.partition_D(dv_tile);

    // A single zeroed register fragment is the source for both stores.
    Tensor zeros = make_fragment_like(dk_dst);
    clear(zeros);
    cute::copy(tiled_copy, zeros, dk_dst);
    cute::copy(tiled_copy, zeros, dv_dst);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert dQ from dQaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_k.
//
// Work assignment: blockIdx.x selects a tile of kBlockM query rows, blockIdx.y the batch entry and
// blockIdx.z the head. Each block
//   1. sums `nsplits` copies of its dQaccum tile, consecutive copies being
//      params.dq_accum_split_stride elements apart,
//   2. multiplies the sum by params.scale_softmax_rp_dropout and converts it to Element,
//   3. stages the converted tile through dynamic shared memory (layout SmemLayoutdQ),
//   4. writes it to params.dq_ptr, skipping rows beyond binfo.actual_seqlen_q and columns >= params.d.
// Blocks whose first row lies at or beyond binfo.actual_seqlen_q return immediately.
//
// nsplits: number of dQaccum slices to add together.
template<typename Kernel_traits, typename Params>
inline __device__ void convert_dQ(const Params &params, const int nsplits) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
// The block index along the query-row tiles.
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
const BlockInfo binfo(params, bidb);
// Nothing to do for tiles that start past the end of this sequence.
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
// Element offsets of this block's first row in dQ (output) and dQaccum (input).
const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
+ m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
// When params.cu_seqlens_q is set, an extra 128 * bidb rows are skipped in dQaccum.
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+ (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
// Global-memory views of the output tile and of the first accumulator slice.
Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.dq_row_stride, _1{}));
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.h * params.d_rounded, _1{}));
// Shared-memory staging tile, placed at the start of the dynamic shared-memory buffer.
Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutdQ{});
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
// The register -> shared copy is tiled to match the accumulator layout of TiledMmadQ.
typename Kernel_traits::TiledMmadQ tiled_mma_dq;
auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum);
Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
// The element-wise accumulation below relies on both fragments holding the same number of values.
CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum);
clear(acc_dq);
// Sum the splits: load one slice, add it in, then advance the source pointer to the next slice.
for (int s = 0; s < nsplits; ++s) {
cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
#pragma unroll
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
}
// Apply the scale factor params.scale_softmax_rp_dropout to the summed accumulator.
#pragma unroll
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
// Convert acc_dq from the accumulator type to Element (fp16/bf16)
Tensor rdQ = flash::convert_type<Element>(acc_dq);
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
// Every thread's shared-memory writes must be visible before the tile is read back with the
// global-memory copy partitioning.
__syncthreads();
Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
// Identity tensor: gives the (row, column) coordinate of every element this thread stores.
Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
// Column predicate: only columns with index below params.d are written.
#pragma unroll
for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_q.
template<typename Kernel_traits, typename Params>
inline __device__ void convert_dKV(const Params &params) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
const int n_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
const BlockInfo binfo(params, bidb);
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded
+ n_block * kBlockN) * params.d_rounded;
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dk_row_stride, _1{}));
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dv_row_stride, _1{}));
Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
Stride<Int<kHeadDim>, _1>{});
Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
Stride<Int<kHeadDim>, _1>{});
Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutdKV{});
Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);
Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum));
CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));
Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
#pragma unroll
for (int i = 0; i < size(acc_dk); ++i) {
acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
}
#pragma unroll
for (int i = 0; i < size(acc_dv); ++i) {
acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;
}
// Convert acc_dk and acc_dv from fp32 accumulators to Element
Tensor rdK = flash::convert_type<Element>(acc_dk);
Tensor rdV = flash::convert_type<Element>(acc_dv);
Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N)
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
__syncthreads();
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
#pragma unroll
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
} // namespace flash
\ No newline at end of file
// Copyright (c) 2026, Attnmask extension.
// Splitting the different head dimensions to different files to speed up compilation.
#include "flash_fwd_attnmask_launch_template.h"
// Explicit instantiation of run_mha_fwd_attnmask_ for element type
// cutlass::bfloat16_t with template arguments <128, true>, so that this
// translation unit emits the code for exactly this one specialization.
// NOTE(review): 128 is presumably the head dimension and `true` a compile-time
// boolean switch; the template itself is presumably defined in
// flash_fwd_attnmask_launch_template.h -- confirm both against that header.
template void run_mha_fwd_attnmask_<cutlass::bfloat16_t, 128, true>(
Flash_fwd_params_attnmask &params, cudaStream_t stream);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment