Commit 26f4b5fb authored by Woosuk Kwon's avatar Woosuk Kwon
Browse files

Merge branch 'main' into Dao-AILab/main

parents 5018ac6a 12375706
Pipeline #2015 failed with stages
in 0 seconds
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include "static_switch.h"
#include "flash.h"
#include "flash_bwd_preprocess_kernel.h"
#include "flash_bwd_kernel.h"
// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif
// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
#if defined(ARCH_SUPPORTS_FLASH)
flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
template<bool Clear_dQaccum=true, typename Kernel_traits>
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
}
template<typename Kernel_traits>
__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
flash::clear_dKVaccum<Kernel_traits>(params);
}
template<typename Kernel_traits>
__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
flash::convert_dQ<Kernel_traits>(params, nsplits);
}
template<typename Kernel_traits>
__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
flash::convert_dKV<Kernel_traits>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid_m(num_m_block, params.b, params.h);
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
int gridDimx = num_n_block;
if (params.deterministic) {
auto dprops = at::cuda::getCurrentDeviceProperties();
gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h);
}
dim3 grid_n(gridDimx, params.b, params.h);
if (!params.deterministic) {
flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
} else {
flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
// We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
// a multiple of kBlockN, we'll need to apply mask in the loop.
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
});
auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
}
kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout, Is_causal>(params, stream);
#endif
}
template<typename T, bool Is_causal>
void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 32;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
} else { // 96 KB
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
}
});
}
template<typename T, bool Is_causal>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// Changing AtomLayoutMdQ from 2 to 4 takes the same time
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
// This is slightly faster. We want to split M more so we need fewer registers to store LSE.
if (max_smem_per_block >= 144 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// This has a lot of register spilling
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
} else {
// if (params.h == params.h_k) {
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream);
// } else {
// }
}
});
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
// M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
}
template<typename T, bool Is_causal>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 116 * 1024) {
if constexpr(!Is_dropout) { // 92KB
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
} else { // 116 KB
// This is faster for dropout since we don't have many registers to spare
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
}
});
}
template<typename T, bool Is_causal>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
// This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
// Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
if (max_smem_per_block >= 144 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
} else {
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout, Is_causal>(params, stream);
}
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);
});
}
template<typename T, bool Is_causal>
void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 160;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 116 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
}
});
}
template<typename T, bool Is_causal>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 192;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 136 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout, Is_causal>(params, stream);
}
});
}
template<typename T, bool Is_causal>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 176 * 1024) { // H100
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout, Is_causal>(params, stream);
} else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
if constexpr (!Is_dropout) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false, Is_causal>(params, stream);
}
}
});
}
This diff is collapsed.
......@@ -159,6 +159,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
typename Kernel_traits::SmemLayoutKV{});
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
......@@ -579,16 +580,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// We move K and V to the last block.
const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size;
const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size;
const index_t row_offset_k = block_table == nullptr
? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
: block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
: (bidh / params.h_h_k_ratio) * params.k_head_stride; // block addresses are later resolved per-thread
const index_t row_offset_v = block_table == nullptr
? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
: block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
: (bidh / params.h_h_k_ratio) * params.v_head_stride;
Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
......@@ -602,7 +604,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.v_row_stride, _1{}));
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQ{});
Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
......@@ -610,15 +611,30 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_Q;
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV;
auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
Tensor tKgK_ = gmem_thr_copy_KV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK_ = gmem_thr_copy_KV.partition_D(sK);
Tensor tVgV_ = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV_ = gmem_thr_copy_KV.partition_D(sV);
Tensor tKgK = make_tensor(tKgK_.data(), reshape_thread_tile(tKgK_.layout()));
Tensor tKsK = make_tensor(tKsK_.data(), reshape_thread_tile(tKsK_.layout()));
Tensor tVgV = make_tensor(tVgV_.data(), reshape_thread_tile(tVgV_.layout()));
Tensor tVsV = make_tensor(tVsV_.data(), reshape_thread_tile(tVsV_.layout()));
if (block_table != nullptr) {
tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride);
tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
}
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
......@@ -656,8 +672,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
// Repeat the partitioning with identity layouts
Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tKVcKV_ = gmem_thr_copy_KV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
Tensor tKVcKV = make_tensor(tKVcKV_.data(), reshape_thread_tile(tKVcKV_.layout()));
// Allocate predicate tensors for k
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
......@@ -674,11 +691,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Prologue
// Copy from Knew to K, optionally apply rotary embedding.
typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
if constexpr (Append_KV) {
typename Kernel_traits::GmemTiledCopyRotcossinPaged gmem_tiled_copy_rotary;
auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyRotcossinContPaged gmem_tiled_copy_rotary_cont;
auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
// Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
// gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
// We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
......@@ -695,10 +713,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.rotary_dim / 2, _1{}));
Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
Tensor tRgCos_ = gmem_thr_copy_rotary.partition_S(gCos);
Tensor tRgSin_ = gmem_thr_copy_rotary.partition_S(gSin);
Tensor tRgCosCont_ = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
Tensor tRgSinCont_ = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
Tensor tRgCos = make_tensor(tRgCos_.data(), reshape_thread_tile(tRgCos_.layout()));
Tensor tRgSin = make_tensor(tRgSin_.data(), reshape_thread_tile(tRgSin_.layout()));
Tensor tRgCosCont = make_tensor(tRgCosCont_.data(), reshape_flatten_thread_tile(tRgCosCont_.layout()));
Tensor tRgSinCont = make_tensor(tRgSinCont_.data(), reshape_flatten_thread_tile(tRgSinCont_.layout()));
// if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
// if (cute::thread(8, 0)) { print_tensor(gCos); }
// if (cute::thread(0, 0)) { print_tensor(tRgCos); }
......@@ -721,8 +746,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
+ row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.vnew_row_stride, _1{}));
Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV_new;
auto gmem_thr_copy_KV_new = gmem_tiled_copy_KV_new.get_thread_slice(tidx);
Tensor tKgKnew_ = gmem_thr_copy_KV_new.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K)
Tensor tVgVnew_ = gmem_thr_copy_KV_new.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
auto tKgKnew = make_tensor(tKgKnew_.data(), reshape_thread_tile(tKgKnew_.layout()));
auto tVgVnew = make_tensor(tVgVnew_.data(), reshape_thread_tile(tVgVnew_.layout()));
const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
auto tKgK_data = tKgK.data();
......@@ -762,14 +792,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
if (n_block > n_block_copy_min) {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
const int offset_diff = block_table_offset_next - block_table_offset_cur;
tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride);
}
}
}
......@@ -782,9 +808,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Read Q from gmem to smem, optionally apply rotary embedding.
if (!Append_KV || params.rotary_dim == 0) {
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM);
} else {
typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
// If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
// We do this by setting the row stride of gCos / gSin to 0.
......@@ -819,7 +849,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
int n_block = n_block_max - 1;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN);
cute::cp_async_fence();
......@@ -858,17 +888,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else {
const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
} else {
// Clear the smem tiles to account for predicated off loads
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
cute::cp_async_fence();
......@@ -897,13 +924,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride);
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
......@@ -940,13 +964,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) {
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
} else {
const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size,
block_table, params.v_batch_stride, params.v_row_stride);
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence();
flash::gemm(
......@@ -964,13 +985,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
if (block_table == nullptr) {
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
} else {
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
block_table, params.k_batch_stride, params.k_row_stride);
}
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
cute::cp_async_fence();
......
......@@ -131,6 +131,17 @@ struct Flash_fwd_kernel_traits : public Base {
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
// from how many rows does each thread have to fetch
static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow);
// Here we assign a contiguous tile to each thread, rather than a 1x8 row every
// (kNThreads / kGmemThreadsPerRow) rows, ensuring that the elements assigned to each thread
// do not cross a page boundary. This way, each thread need only fetch 1 page index per
// mainloop iteration. R>udimentary testing shows no slowdown.
using GmemTiledCopyQKVPaged = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{},
......@@ -156,6 +167,14 @@ struct Flash_fwd_kernel_traits : public Base {
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
using GmemTiledCopyRotcossinPaged = decltype(
make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape<Int<kGmemRowsPerThread>, _4>, Stride<_4, _1>>{})); // Val layout, 4 vals per load
using GmemTiledCopyRotcossinContPaged = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load
};
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue.
......
......@@ -291,6 +291,53 @@ void cp_async_wait() {
////////////////////////////////////////////////////////////////////////////////////////////////////
// resolves offset of a slice of a paged kv copy from gmem.
// assumes that the tensor has already been positioned at the correct head.
template <typename Kernel_traits>
__forceinline__ __device__
int64_t resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
const int* block_table, const int page_stride, const int row_stride) {
constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad;
constexpr int kBlockN = Kernel_traits::kBlockN;
const int64_t col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;
const int64_t block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
const int64_t global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN;
const int64_t page_offset = global_row_offset % page_block_size;
const int64_t virtual_page_idx = global_row_offset / page_block_size;
return ((int64_t) block_table[virtual_page_idx]) * ((int64_t) page_stride)
+ page_offset * ((int64_t) row_stride)
+ col_offset;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Layout reshape function. Given a layout with modes ((v1, v2), m, k), returns (v1, v2, k),
// where v2 may be a tuple itself, in the case of swizzled smem-backed thread tiles. This ensures
// that paged and non-paged copies result in equivalently shaped, if not necessarily strided, tensors.
template <class Shape, class Stride>
__forceinline__ __device__
auto reshape_thread_tile(Layout<Shape, Stride> l) {
return make_layout(append(get<0>(l.shape()), get<2>(l.shape())),
append(get<0>(l.stride()), get<2>(l.stride())));
}
// reshapes and flattens the thread tile layout. A separate function is needed for the case where
// one of the modes of l is a layout itself and must be flattened, as opposed to keeping it intact
// for the case of swizzled layouts
template <class Shape, class Stride>
__forceinline__ __device__
auto reshape_flatten_thread_tile(Layout<Shape, Stride> l) {
auto mode_0 = filter(flatten(get<0>(l)));
return make_layout(append(mode_0.shape(), get<2>(l.shape())),
append(mode_0.stride(), get<2>(l.stride())));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
......
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
class IndexFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(
rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
).reshape(-1, *other_shape)
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
grad_output = rearrange(grad_output, "b ... -> b (...)")
grad_input = torch.zeros(
[ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device,
dtype=grad_output.dtype,
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
# grad_input[indices] = grad_output
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis = IndexFirstAxis.apply
class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
output = torch.zeros(
first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
return output
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
grad_values = grad_output[indices]
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
return grad_values, None, None
index_put_first_axis = IndexPutFirstAxis.apply
class IndexFirstAxisResidual(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
output = input[indices]
# We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
# memory format to channel_first. In other words, input might not be contiguous.
# If we don't detach, Pytorch complains about output being a view and is being modified inplace
return output, input.detach()
@staticmethod
def backward(ctx, grad_output, grad_residual):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
assert grad_residual.shape[1:] == other_shape
grad_input = grad_residual
# grad_input[indices] += grad_output
indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
indices = indices.expand_as(grad_output)
grad_input.scatter_add_(0, indices, grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
index_first_axis_residual = IndexFirstAxisResidual.apply
def unpad_input(hidden_states, attention_mask):
"""
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
"""
Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
```
[
[2, 3, 0, 0, 0, 0],
[3, 2, 0, 0, 0, 0],
[6, 0, 0, 0, 0, 0]
]
```
, which refers to the 3D-attention mask:
```
[
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0],
[0, 0, 0, 0, 0, 1]
],
[
[1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1]
]
]
```.
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
"""
length = attention_mask_in_length.sum(dim=-1)
seqlen = attention_mask_in_length.size(-1)
attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1)
real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def pad_input(hidden_states, indices, batch, seqlen):
"""
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
batch: int, batch size for the padded sequence.
seqlen: int, maximum sequence length for the padded sequence.
Return:
hidden_states: (batch, seqlen, ...)
"""
dim = hidden_states.shape[-1]
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, "(b s) ... -> b s ...", b=batch)
This diff is collapsed.
# [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
# for benchmarking.
# We fixed a few dtype cast to make it work for bf16
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
"""
import pytest
import torch
import triton
import triton.language as tl
@triton.jit
def _fwd_kernel(
Q,
K,
V,
sm_scale,
TMP,
L,
M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
Z,
H,
N_CTX,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
t_ptrs = TMP + off_hz * N_CTX + offs_m
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs)
# loop over k, v and update accumulator
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kn)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, trans_b=True)
qk *= sm_scale
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
tl.store(t_ptrs, acc_scale)
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + start_n * stride_vk)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# write back l and m
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(l_ptrs, l_i)
tl.store(m_ptrs, m_i)
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_DMODEL)
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
@triton.jit
def _bwd_preprocess(
Out,
DO,
L,
NewDO,
Delta,
BLOCK_M: tl.constexpr,
D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
# load
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
denom = tl.load(L + off_m).to(tl.float32)
# compute
do = do / denom[:, None]
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
tl.store(Delta + off_m, delta)
@triton.jit
def _bwd_kernel(
Q,
K,
V,
sm_scale,
Out,
DO,
DQ,
DK,
DV,
L,
M,
D,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
Z,
H,
N_CTX,
num_block,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
# offset pointers for batch/head
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_qz + off_h * stride_qh
V += off_z * stride_qz + off_h * stride_qh
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
DK += off_z * stride_qz + off_h * stride_qh
DV += off_z * stride_qz + off_h * stride_qh
for start_n in range(0, num_block):
lo = start_n * BLOCK_M
# initialize row/col offsets
offs_qm = lo + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
offs_m = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_DMODEL)
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
m_ptrs = M + off_hz * N_CTX
# initialize dv amd dk
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# k and v stay in SRAM throughout
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
# loop over rows
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, k, trans_b=True)
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(p.to(do.dtype), do, trans_a=True)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, v, trans_b=True)
# compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
dk += tl.dot(ds.to(q.dtype), q, trans_a=True)
# # compute dq
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
dq += tl.dot(ds.to(k.dtype), k)
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
# # increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale):
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
tmp = torch.empty(
(q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
q,
k,
v,
sm_scale,
tmp,
L,
m,
o,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
k.stride(1),
k.stride(2),
k.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
o.stride(0),
o.stride(1),
o.stride(2),
o.stride(3),
q.shape[0],
q.shape[1],
q.shape[2],
BLOCK_M=BLOCK,
BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk,
num_warps=num_warps,
num_stages=1,
)
ctx.save_for_backward(q, k, v, o, L, m)
ctx.BLOCK = BLOCK
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, l, m = ctx.saved_tensors
do = do.contiguous()
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](
o,
do,
l,
do_scaled,
delta,
BLOCK_M=ctx.BLOCK,
D_HEAD=ctx.BLOCK_DMODEL,
)
# NOTE: kernel currently buggy for other values of `num_warps`
num_warps = 8
_bwd_kernel[(ctx.grid[1],)](
q,
k,
v,
ctx.sm_scale,
o,
do_scaled,
dq,
dk,
dv,
l,
m,
delta,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
k.stride(1),
k.stride(2),
k.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
q.shape[0],
q.shape[1],
q.shape[2],
ctx.grid[0],
BLOCK_M=ctx.BLOCK,
BLOCK_N=ctx.BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL,
num_warps=num_warps,
num_stages=1,
)
return dq.to(q.dtype), dk, dv, None
attention = _attention.apply
import math
import hydra
import torch
import torch.nn as nn
from einops import rearrange
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
from flash_attn.flash_blocksparse_attn_interface import (
convert_blockmask,
flash_blocksparse_attn_func,
)
class FlashBlocksparseAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_temp: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.1)
"""
def __init__(
self,
sparsity_config,
softmax_temp=None,
attention_dropout=0.0,
max_seq_length=2048,
device=None,
dtype=None,
):
super().__init__()
self.sparsity_config = hydra.utils.instantiate(sparsity_config)
self.softmax_temp = softmax_temp
self.dropout_p = attention_dropout
# initialize sparse layout and register as buffer
max_seq_length = ((max_seq_length + 256 - 1) // 256) * 256
layout = self.sparsity_config.make_layout(max_seq_length)
self.register_buffer("layout", layout)
blockmask_converted = convert_blockmask(self.layout, causal=False)
self.register_buffer("blockmask_converted", blockmask_converted)
# logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}')
def forward(
self,
qkv,
attn_mask=None,
key_padding_mask=None,
causal=False,
cu_seqlens=None,
max_s=None,
need_weights=False,
convert_mask=True,
):
"""Implements the multihead softmax attention.
Arguments
---------
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
attn_mask: An implementation of BaseMask that encodes where each
query can attend to
key_padding_mask: An implementation of BaseMask that encodes how
many query each sequence in the batch consists of
"""
assert not need_weights
assert attn_mask is None
assert qkv.dtype == torch.float16
assert qkv.is_cuda
if cu_seqlens is None:
batch_size = qkv.shape[0]
seqlen = qkv.shape[1]
# Convert mask to take a subset
seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
assert seqlen_rounded // 16 <= self.layout.shape[0], (
seqlen_rounded // 256 <= self.layout.shape[1]
)
blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
if key_padding_mask is None:
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = seqlen
cu_seqlens = torch.arange(
0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device
)
output = flash_blocksparse_attn_func(
qkv,
cu_seqlens,
blockmask,
self.dropout_p if self.training else 0.0,
max_s,
softmax_scale=self.softmax_temp,
causal=causal,
)
output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
else:
key_padding_mask_bool = key_padding_mask.bool_matrix
nheads = qkv.shape[-2]
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
output_unpad = flash_blocksparse_attn_func(
x_unpad,
cu_seqlens,
blockmask,
self.dropout_p if self.training else 0.0,
max_s,
softmax_scale=self.softmax_temp,
causal=causal,
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
),
"b s (h d) -> b s h d",
h=nheads,
)
else:
assert max_s is not None
seqlen = max_s
# Convert mask to take a subset
seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
assert seqlen_rounded // 16 <= self.layout.shape[0], (
seqlen_rounded // 256 <= self.layout.shape[1]
)
blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
if convert_mask:
output = flash_blocksparse_attn_func(
qkv,
cu_seqlens,
blockmask,
self.dropout_p if self.training else 0.0,
max_s,
softmax_scale=self.softmax_temp,
causal=causal,
)
else:
output = flash_blocksparse_attn_func(
qkv,
cu_seqlens,
self.blockmask_converted,
self.dropout_p if self.training else 0.0,
max_s,
softmax_scale=self.softmax_temp,
causal=causal,
convert_mask=False,
)
return output, None
class FlashBlocksparseMHA(nn.Module):
def __init__(
self,
embed_dim,
num_heads,
sparsity_config,
bias=True,
batch_first=True,
attention_dropout=0.0,
causal=False,
max_seq_length=2048,
device=None,
dtype=None,
**kwargs,
) -> None:
assert batch_first
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.embed_dim = embed_dim
self.causal = causal
self.num_heads = num_heads
assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
self.head_dim = self.embed_dim // num_heads
assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"
self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
self.inner_attn = FlashBlocksparseAttention(
sparsity_config,
attention_dropout=attention_dropout,
max_seq_length=max_seq_length,
**factory_kwargs,
)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
def forward(
self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None, need_weights=False
):
qkv = self.Wqkv(x)
qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
context, attn_weights = self.inner_attn(
qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
)
return self.out_proj(rearrange(context, "b s h d -> b s (h d)")), attn_weights
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
import flash_attn_cuda
import torch
import torch.nn as nn
def convert_blockmask(blockmask, causal):
"""Convert from the 0-1 format to the format used by the CUDA code.
0 means the block is skipped.
nonzero means the block is not skipped.
Argument:
blockmask: (row, col): a 0-1 tensor
Return:
blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row
indices of the nonzero blocks, padded with -1 to reach length @row.
The indices are multiplied by 4, with the smallest bit used to encode whether
it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is
the last nonzero in its row..
"""
assert not causal
# TD [2022-05-13]: The indexing and sorting is very tricky
nrow, ncol = blockmask.shape
# Sort does not support bool on CUDA
blockmask = blockmask.to(dtype=torch.uint8)
nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True)
nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0)
last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1]
last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row
]
first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0]
first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row
]
nonzero_idx = nonzero_sorted_rowidx * 4
nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2
nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1
nonzero_idx[nonzero_val == 0] = -1
return nonzero_idx.T.contiguous().to(dtype=torch.int32)
def _flash_blocksparse_attn_forward(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax
):
context, softmax_lse, *rest = flash_attn_cuda.fwd_block(
qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax, None
)
# if context.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
S_dmask = rest[0] if return_softmax else None
return context, softmax_lse, S_dmask
def _flash_blocksparse_attn_backward(
dout,
qkv,
out,
S_dmask,
softmax_lse,
cu_seqlens,
blockmask,
dropout_p,
max_s,
softmax_scale,
causal,
):
dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(
dout,
qkv,
out,
S_dmask,
softmax_lse,
cu_seqlens,
blockmask,
dropout_p,
softmax_scale,
max_s,
causal,
None,
)
# if dqkv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
return dqkv
class FlashBlocksparseAttnFun(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
qkv,
cu_seqlens,
blockmask,
dropout_p,
max_s,
softmax_scale,
causal=causal,
return_softmax=False,
)
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
ctx.dropout_p = dropout_p
ctx.max_s = max_s
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return context
@staticmethod
def backward(ctx, dout):
qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
# S_dmask is None, temporarily use another tensor just to get it running
dqkv = _flash_blocksparse_attn_backward(
dout,
qkv,
context,
context,
softmax_lse,
cu_seqlens,
blockmask,
ctx.dropout_p,
ctx.max_s,
ctx.softmax_scale,
ctx.causal,
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None, None
# We duplicate code to return both the output and the softmax for testing
# Returning both makes backward a bit slower, so we want to keep using the other version for speed.
class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
# Save rng_state because the backward pass is gonna regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
qkv,
cu_seqlens,
blockmask,
dropout_p,
max_s,
softmax_scale,
causal=causal,
return_softmax=True,
)
ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
ctx.dropout_p = dropout_p
ctx.max_s = max_s
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return context, S_dmask, softmax_lse
@staticmethod
def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):
qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dqkv = _flash_blocksparse_attn_backward(
dout,
qkv,
context,
S_dmask,
softmax_lse,
cu_seqlens,
blockmask,
ctx.dropout_p,
ctx.max_s,
ctx.softmax_scale,
ctx.causal,
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None
def flash_blocksparse_attn_func(
qkv,
cu_seqlens,
blockmask,
dropout_p,
max_s,
softmax_scale=None,
causal=False,
return_attn_probs=False,
convert_mask=True,
):
"""dropout_p should be set to 0.0 during evaluation"""
func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS
if convert_mask:
blockmask = convert_blockmask(blockmask, causal=causal)
return func.apply(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal)
# [2022-10-23] Copied from https://github.com/NVIDIA/apex/blob/master/apex/transformer/functional/fused_softmax.py
# for benchmarking.
# We added support for seqlen=2k and seqlen=4k
# coding=utf-8
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex._autocast_utils import _cast_if_autocast_enabled
from apex.transformer.enums import AttnMaskType
from fused_softmax_lib import (
scaled_masked_softmax_backward,
scaled_masked_softmax_forward,
scaled_masked_softmax_get_batch_per_block,
scaled_upper_triang_masked_softmax_backward,
scaled_upper_triang_masked_softmax_forward,
)
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax_backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None
def scaled_upper_triang_masked_softmax(inputs, _, scale):
b, np, sq, sk = inputs.size()
assert sq == sk, "causal mask is only for self attention"
# Reshaping input to 3D tensor (attn_batches, sq, sk)
inputs = inputs.view(-1, sq, sk)
args = _cast_if_autocast_enabled(inputs, scale)
with torch.cuda.amp.autocast(enabled=False):
probs = ScaledUpperTriangMaskedSoftmax.apply(*args)
return probs.view(b, np, sq, sk)
# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`.
# Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context.
# So I needed to manually write two `torch.autograd.Function` inheritances.
# Fused operation which performs following three operations in sequence
# 1. Scale the tensor.
# 2. Apply the mask.
# 3. Perform softmax.
class ScaledMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, mask, scale):
scale_t = torch.tensor([scale])
softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None
def scaled_masked_softmax(inputs, mask, scale):
# input is 4D tensor (b, np, sq, sk)
args = _cast_if_autocast_enabled(inputs, mask, scale)
with torch.cuda.amp.autocast(enabled=False):
return ScaledMaskedSoftmax.apply(*args)
class FusedScaleMaskSoftmax(torch.nn.Module):
"""
fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: flag to indicate if input in fp16 data format.
input_in_bf16: flag to indicate if input in bf16 data format.
attn_mask_type: attention mask type (pad or causal)
scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax in performed at fp32 precision.
scale: scaling factor used in input tensor scaling.
"""
def __init__(
self,
input_in_fp16,
input_in_bf16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super().__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
if self.input_in_fp16 and self.input_in_bf16:
raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.")
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
if not (self.scale is None or softmax_in_fp32):
raise RuntimeError("softmax should be in fp32 when scaled")
if self.scaled_masked_softmax_fusion:
if self.attn_mask_type == AttnMaskType.causal:
self.fused_softmax_func = scaled_upper_triang_masked_softmax
elif self.attn_mask_type == AttnMaskType.padding:
self.fused_softmax_func = scaled_masked_softmax
else:
raise ValueError("Invalid attn_mask_type.")
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
else:
return self.forward_torch_softmax(input, mask)
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (
self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16
and (
self.attn_mask_type == AttnMaskType.causal
or (self.attn_mask_type == AttnMaskType.padding and mask is not None)
)
and 16 < sk <= 8192 # sk must be 16 ~ 8192
and sq % 4 == 0 # sq must be divisor of 4
and sk % 4 == 0 # sk must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
):
if 0 <= sk <= 8192:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type == AttnMaskType.causal:
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
def forward_fused_softmax(self, input, mask):
# input.shape = [b, np, sq, sk]
scale = self.scale if self.scale is not None else 1.0
return self.fused_softmax_func(input, mask, scale)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
@staticmethod
def get_batch_per_block(sq, sk, b, np):
return scaled_masked_softmax_get_batch_per_block(sq, sk, b, np)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment