Commit a5a8806d authored by Tri Dao

Split bwd on the seqlen_q dimension

parent 871db479
@@ -241,7 +241,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
-    int blocksize_c = (head_size == 128 && (!is_sm80)) ? 128 : 256;
+    int blocksize_c = head_size == 128 ? 128 : 256;
     // Need to round max_seqlen_k to multiples of blocksize_c
     int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
     if( max_seqlen_k_ <= 128 ) {
@@ -332,6 +332,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         const float softmax_scale,
         const bool zero_tensors,
         const bool is_causal,
+        const int num_splits,
         c10::optional<at::Generator> gen_
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -447,7 +448,22 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                      p_dropout,
                      softmax_scale,
                      is_causal,
-                     /*num_splits=*/1);
+                     num_splits);
+    launch(params, stream, /*configure=*/true);
+    at::Tensor dk_accum, dv_accum;
+    if (params.num_splits > 1) {
+        // dk_accum = torch::zeros({total_k, num_heads, head_size}, opts.dtype(at::kFloat));
+        // dv_accum = torch::zeros({total_k, num_heads, head_size}, opts.dtype(at::kFloat));
+        // params.dk_accum_ptr = dk_accum.data_ptr();
+        // params.dv_accum_ptr = dv_accum.data_ptr();
+        dk.zero_();
+        dv.zero_();
+    } else {
+        // params.dk_accum_ptr = nullptr;
+        // params.dv_accum_ptr = nullptr;
+    }
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
@@ -461,7 +477,12 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         params.philox_args = gen->philox_cuda_state(counter_offset);
     }
-    launch(params, stream);
+    launch(params, stream, /*configure=*/false);
+    // if (params.num_splits > 1) {
+    //     dk.copy_(dk_accum);
+    //     dv.copy_(dv_accum);
+    // }
     return { dq, dk, dv, softmax_d };
 }
......
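Not part of the diff itself, but for orientation: the mha_bwd changes above run the launcher twice, first with `configure=true` so it can fill in `params.num_splits`, then again to actually launch. A minimal standalone sketch of that control flow, using stand-in types (the `DgradParams` struct and `launch()` stub here are hypothetical, not the real FMHA types):

```cpp
#include <cstdio>

struct DgradParams {
    int b, h, seqlen_q;
    int num_splits;  // <= 0 means "pick automatically"
};

// Stand-in for run_fmha_dgrad_fp16_sm80(params, stream, configure): the configure
// pass only decides num_splits; the second pass launches the real kernel.
void launch(DgradParams &params, bool configure) {
    if (params.num_splits <= 0) {
        params.num_splits = 2;  // the real code uses an occupancy heuristic here
    }
    if (configure) { return; }
    std::printf("launch grid (%d, %d, %d)\n", params.b, params.h, params.num_splits);
}

int main() {
    DgradParams params{/*b=*/8, /*h=*/4, /*seqlen_q=*/2048, /*num_splits=*/0};
    launch(params, /*configure=*/true);   // phase 1: fill in params.num_splits
    if (params.num_splits > 1) {
        // dK/dV must start at zero, since the split kernels accumulate into them
        // with atomicAdd -- this is what dk.zero_() / dv.zero_() do above.
    }
    launch(params, /*configure=*/false);  // phase 2: actually run the kernel
    return 0;
}
```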
@@ -140,6 +140,10 @@ struct FMHA_dgrad_params : public FMHA_fprop_params {
     void *__restrict__ dk_ptr;
     void *__restrict__ dv_ptr;
+    // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q dimension
+    // void *__restrict__ dk_accum_ptr;
+    // void *__restrict__ dv_accum_ptr;
     // The stride between rows of the dQ, dK and dV matrices.
     // TD [2022-04-16]: We're using 32-bit indexing to save registers.
     // The code probably won't work for arrays larger than 2GB.
@@ -193,7 +197,7 @@ struct Launch_params{
 void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params);
-void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream);
+void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
 void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
......
@@ -28,6 +28,9 @@
 #pragma once
 #include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <fmha/utils.h>
 namespace fmha {
@@ -41,7 +44,8 @@ template<
     // The number of rows of Q, K or V loaded by this tile.
     int ROWS_,
     // The number of columns.
-    int COLS
+    int COLS,
+    int BYTES_PER_LDGS_ = 16
 >
 struct Gmem_tile_qkv {
@@ -49,7 +53,7 @@ struct Gmem_tile_qkv {
     static constexpr int BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8;
     // The size of each LDG.
-    static constexpr int BYTES_PER_LDG = 16;
+    static constexpr int BYTES_PER_LDG = BYTES_PER_LDGS_;
     // The size of a row in bytes.
     static constexpr int BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8;
@@ -130,6 +134,42 @@ struct Gmem_tile_qkv {
         }
     }
+
+    template <typename elem_type>
+    inline __device__ void atomic_add(const uint4 (&data)[LDGS]) {
+        int row_ = tidx_ / THREADS_PER_ROW;
+        #pragma unroll
+        for( int ii = 0; ii < LDGS; ++ii ) {
+            using elem2_type = typename std::conditional<std::is_same<elem_type, __half>::value, __half2, __nv_bfloat162>::type;
+            // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
+            elem2_type *ptr_ = reinterpret_cast<elem2_type *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
+            if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
+                #pragma unroll
+                for (int jj = 0; jj < 4; ++jj) {
+                    atomicAdd(ptr_ + jj, reinterpret_cast<const elem2_type(&)[4]>(data[ii])[jj]);
+                }
+            }
+        }
+    }
+
+    // Not being used. This only supports converting from fp16 -> fp32 for now (not bf16 -> fp32).
+    inline __device__ void atomic_add_float(const uint4 (&data)[LDGS]) {
+        static_assert(BYTES_PER_ELEMENT == 4);  // Only support fp32
+        int row_ = tidx_ / THREADS_PER_ROW;
+        #pragma unroll
+        for( int ii = 0; ii < LDGS; ++ii ) {
+            // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
+            float *ptr_ = reinterpret_cast<float *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
+            if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
+                #pragma unroll
+                for (int jj = 0; jj < 4; ++jj) {
+                    const float2 data_f = fmha::half2_unpack<__half>(reinterpret_cast<const uint32_t(&)[4]>(data[ii])[jj]);
+                    atomicAdd(ptr_ + jj * 2, data_f.x);
+                    atomicAdd(ptr_ + jj * 2 + 1, data_f.y);
+                }
+            }
+        }
+    }
+
     inline __device__ void move(const int steps = 1) {
         // ptr += (int64_t)ROWS * row_stride_in_bytes * steps;
         ptr += (uint32_t)ROWS * row_stride_in_bytes * steps;
......
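The new `atomic_add` path above accumulates each 16-byte LDG's worth of packed fp16/bf16 data into global memory as four 2-element atomic adds. A stripped-down standalone illustration of that pattern (hypothetical names, fp16 only; the `__half2` atomicAdd overload needs sm_60 or newer):

```cuda
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdio>

// Each thread owns 8 halfs (16 bytes, one "LDG" worth) and atomically adds them into dst.
__global__ void accumulate_half8(__half *dst, const __half *src) {
    uint4 data = *reinterpret_cast<const uint4 *>(src + 8 * threadIdx.x);
    __half2 *out = reinterpret_cast<__half2 *>(dst + 8 * threadIdx.x);
    #pragma unroll
    for (int jj = 0; jj < 4; ++jj) {
        // One uint32 holds two packed halfs; atomicAdd on __half2 adds both lanes at once.
        atomicAdd(out + jj, reinterpret_cast<const __half2(&)[4]>(data)[jj]);
    }
}

int main() {
    const int n = 32 * 8;
    __half *dst, *src;
    cudaMallocManaged(&dst, n * sizeof(__half));
    cudaMallocManaged(&src, n * sizeof(__half));
    for (int i = 0; i < n; ++i) { dst[i] = __float2half(0.f); src[i] = __float2half(1.f); }
    accumulate_half8<<<1, 32>>>(dst, src);
    cudaDeviceSynchronize();
    std::printf("dst[0] = %f\n", __half2float(dst[0]));  // expect 1.0
    return 0;
}
```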
@@ -76,6 +76,12 @@ struct FMHA_kernel_traits {
     using Gmem_tile_do = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
+    // // The global memory tile to store the accumulated dK and dV
+    // // Hack: we set BYTES_PER_LDGS=32 to emulate the access pattern of dK and dV,
+    // // where there are 16 bits per element and 16 bytes per load. In reality we won't
+    // // be issuing any load or store of size 32 bytes.
+    // using Gmem_tile_dkv_accum = fmha::Gmem_tile_qkv<Cta_tile_o, 32, S, D, 32>;
+
     // The global memory tile to store the softmax sum.
     using Gmem_softmax_sum = fmha::Gmem_summary_stats<Cta_tile_p>;
......
@@ -6,13 +6,45 @@
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_loop.h"
+// Find the number of splits that maximizes the occupancy. For example, if we have
+// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
+// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
+// splits as that would incur more HBM reads/writes.
+// Moreover, more than 1 split incurs extra cost of zeroing out dk/dv and doing atomic add
+// instead of just writing.
+// So for num_splits > 1, we divide the efficiency by some factor (e.g. 1.25, depending on seqlen)
+// to account for this. Moreover, more splits means atomic add will be slower.
+int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits,
+                             int seqlen, bool is_causal) {
+    float max_efficiency = 0.f;
+    int best_num_splits = 1;
+    std::vector<float> efficiency;
+    efficiency.reserve(max_splits);
+    float discount_factor = 1.f + 512.0 / seqlen;  // 1.25 for seqlen 2k, 1.125 for 4k.
+    discount_factor *= is_causal ? 1.1 : 1.f;  // causal makes it even slower.
+    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+        float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
+        float eff_raw = n_waves / ceil(n_waves);
+        // Heuristic: each increase in num_splits results in 6% slowdown, up to maybe 8 splits.
+        float eff = num_splits == 1 ? eff_raw : (eff_raw - 0.07 * std::min(num_splits - 2, 6)) / discount_factor;
+        // printf("num_splits = %d, eff_raw = %f, eff = %f\n", num_splits, eff_raw, eff);
+        if (eff > max_efficiency) {
+            max_efficiency = eff;
+            best_num_splits = num_splits;
+        }
+        efficiency.push_back(eff);
+    }
+    // printf("num_splits chosen = %d\n", best_num_splits);
+    return best_num_splits;
+}
+
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
 __global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
     fmha::compute_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
 }
 template<typename Kernel_traits>
-void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_t stream) {
+void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
     constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
     constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
     constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
@@ -46,41 +78,58 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
     }
-        dim3 grid(params.b, params.h);
+        // Automatically set num_splits to maximize occupancy
+        if (params.num_splits <= 0) {
+            int ctas_per_sm;
+            cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size_dq_dk_dv);
+            auto dprops = at::cuda::getCurrentDeviceProperties();
+            // printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount);
+            constexpr int M = Kernel_traits::Cta_tile_p::M;
+            // We don't want more than 10 splits due to numerical error.
+            // Numerical error on dk/dv scales as sqrt(num_splits).
+            params.num_splits = num_splits_heuristic_bwd(
+                params.b * params.h, dprops->multiProcessorCount,
+                ctas_per_sm, /*max_splits=*/std::min(10, (params.seqlen_q + M - 1) / M),
+                params.seqlen_k, params.is_causal
+            );
+        }
+        if (configure) return;
+        dim3 grid(params.b, params.h, params.num_splits);
         kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
         FMHA_CHECK_CUDA(cudaPeekAtLastError());
     });
 }
-void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
+void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
     // work around for MSVC issue
     FP16_SWITCH(params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (params.d == 16) {
             if( params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
             } else if( params.seqlen_k == 256 ) {
                 using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
             } else {
                 // TD [2022-05-15] 512 gives wrong results rn
                 // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 8, 0x08u, elem_type>;
                 using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
             }
         } else if (params.d == 32) {
             if( params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
             } else if( params.seqlen_k >= 256 ) {
                 using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
             }
         } else if (params.d == 64) {
             if( params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
             } else if( params.seqlen_k >= 256 ) {
                 if (dprops->major == 8 && dprops->minor == 0) {
                     // Don't share smem for K & V, and don't keep V in registers
@@ -88,45 +137,18 @@ void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stre
                     // uses more shared memory, which is fine on A100 but not other GPUs.
                     // For other GPUs, we keep V in registers.
                     using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
-                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
                 } else if (dprops->major == 8 && dprops->minor > 0) {
                     using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
-                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
                 } else if (dprops->major == 7 && dprops->minor == 5) {
                     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+                    run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
                 }
             }
         } else if (params.d == 128) {
             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
         }
-        // if (params.d == 64) {
-        //     if (dprops->major == 7 && dprops->minor == 5) {
-        //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-        //         run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        //     } else {
-        //         if( params.seqlen_k == 128 ) {
-        //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
-        //             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        //         } else if( params.seqlen_k >= 256 ) {
-        //             if (dprops->major == 8 && dprops->minor == 0) {
-        //                 // Don't share smem for K & V, and don't keep V in registers
-        //                 // This speeds things up by 2-3% by avoiding register spills, but it
-        //                 // uses more shared memory, which is fine on A100 but not other GPUs.
-        //                 // For other GPUs, we keep V in registers.
-        //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
-        //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        //             } else if (dprops->major == 8 && dprops->minor > 0) {
-        //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
-        //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        //             }
-        //         }
-        //     }
-        // }
-        // if (params.d == 128) {
-        //     using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u_elem_type>;
-        //     run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
-        // }
     });
 }
\ No newline at end of file
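As a sanity check on `num_splits_heuristic_bwd` above, a small standalone driver (re-deriving only the wave-efficiency formula, with `ctas_per_sm` assumed to be 1, seqlen 2k, non-causal) reproduces the numbers from the comment: 48 CTAs on 108 SMs fill 0.44 of a wave unsplit, 0.89 of a wave with 2 splits, but only 0.67 with 3 splits, so 2 splits win even after the discount for splitting:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    // Example from the heuristic's comment: batch * n_heads = 48, 108 SMs.
    const int batch_nheads = 48, num_SMs = 108, ctas_per_sm = 1, seqlen = 2048;
    const float discount = 1.f + 512.f / seqlen;  // 1.25, as in num_splits_heuristic_bwd
    for (int num_splits = 1; num_splits <= 3; ++num_splits) {
        float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
        float eff_raw = n_waves / std::ceil(n_waves);
        float eff = num_splits == 1 ? eff_raw
                                    : (eff_raw - 0.07f * std::min(num_splits - 2, 6)) / discount;
        // Prints roughly: 1 -> 0.44/0.44, 2 -> 0.89/0.71, 3 -> 0.67/0.48, so 2 is chosen.
        std::printf("num_splits=%d  eff_raw=%.2f  eff=%.2f\n", num_splits, eff_raw, eff);
    }
    return 0;
}
```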
@@ -135,6 +135,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     // The thread index.
     const int tidx = threadIdx.x;
+    // How many steps to jump per iteration, which is the same as params.num_splits.
+    const int step_stride = gridDim.z;
     const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
     // if( binfo.stop_early() ) return;
     if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
@@ -184,18 +187,24 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);
     static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
-    const int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
+    int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
+    // We want begin to be a multiple of gridDim.z
+    // This is because the row indices processed by each threadblock must align between the
+    // loop steps, otherwise we have a dependency between the blocks.
+    // For example, threadblock with blockIdx.z == 1 must process row indices that are
+    // k * gridDim.z + 1 for integer k.
+    const int begin_mod_z = begin % gridDim.z;
+    begin = begin_mod_z <= blockIdx.z ? begin - begin_mod_z : begin + gridDim.z - begin_mod_z;
     const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M - begin;
     // Wind gmem tiles to the correct position.
-    gmem_q.move(begin);
-    gmem_do.move(begin);
-    gmem_o.move(begin);
-    gmem_dq.move(begin);
-    gmem_dq_tmp.move(begin);
+    gmem_q.move(begin + blockIdx.z);
+    gmem_do.move(begin + blockIdx.z);
+    gmem_o.move(begin + blockIdx.z);
+    gmem_dq.move(begin + blockIdx.z);
+    gmem_dq_tmp.move(begin + blockIdx.z);
     // TODO: need to move gmem_s if we want the intermediate result for debugging
-    gmem_softmax_lse.move(begin);
-    gmem_softmax_d.move(begin);
+    gmem_softmax_lse.move(begin + blockIdx.z);
+    gmem_softmax_d.move(begin + blockIdx.z);
     if (!Is_first) {
         gmem_k.move(loop_step_idx);
@@ -215,7 +224,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     float p_lse[Mma_tile_p::MMAS_M * 2];
     gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
-    gmem_softmax_lse.move();
     if (!Is_first) { __syncthreads(); }
     // Commit the data for Q, dO, and V to shared memory.
@@ -265,7 +273,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     float dp_sum[Mma_tile_p::MMAS_M * 2];
     gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
-    gmem_softmax_d.move();
     // Commit the data for V to shared memory if it has not been done already.
     if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
@@ -301,9 +308,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dk);
     // Load over the entire sequence length.
-    for( int l = 0; l < steps; l++ ) {
-        const int loop = (begin + l) * Cta_tile_p::M;
-        if( loop >= binfo.actual_seqlen_q )
+    for (int l = blockIdx.z; l < steps; l += step_stride) {
+        if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q)
             break;
         // Load the fragments for V.
@@ -352,9 +358,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         smem_s.store(frag_p);
         // Trigger the load for the next Q values.
-        if( l < steps - 1) {
+        if (l + step_stride < steps) {
             gemm_q_k.smem_q.move_to_next_write_buffer();
-            gmem_q.move();
+            gmem_q.move(step_stride);
             gmem_q.load();
         }
@@ -427,12 +433,12 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         smem_kt.load(frag_kt[0], 0);
         // Trigger the load for the next dO values.
-        if( l < steps - 1) {
+        if (l + step_stride < steps) {
             smem_do.move_to_next_write_buffer();
-            gmem_do.move();
+            gmem_do.move(step_stride);
             gmem_do.load();
             if (Is_first) {
-                gmem_o.move();
+                gmem_o.move(step_stride);
                 gmem_o.load();
             }
         }
@@ -443,7 +449,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         smem_dp.store(frag_p);
         // gmem_s.store(frag_p, mask);
-        // gmem_s.move();
+        // gmem_s.move(step_stride);
         // Declare the accumulators for the 2nd gemm.
         fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N];
@@ -520,7 +526,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         // }
         // __syncthreads();
         // Commit the values for Q and dO into shared memory.
-        if(l < steps - 1) {
+        if (l + step_stride < steps) {
             gmem_q.commit(gemm_q_k.smem_q);
         }
@@ -529,15 +535,16 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         // __syncthreads();
         // Commit the values for Q and dO into shared memory.
-        if(l < steps - 1) {
+        if (l + step_stride < steps) {
             gmem_do.commit(smem_do);
+            gmem_softmax_d.move(step_stride);
             if (Is_first) {
                 dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
                     gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
                 );
             }
+            gmem_softmax_lse.move(step_stride);
             gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
-            gmem_softmax_lse.move();
         }
         typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
@@ -567,9 +574,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         // Make sure dQ is in shared memory.
         __syncthreads();
-        if (l < steps - 1) {
+        if (l + step_stride < steps) {
             gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
-            gmem_softmax_d.move();
         }
         // Load from shared memory.
@@ -590,20 +596,20 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
             // Output the values.
             gmem_dq.template store<elem_type>(dq_out, 0);
             // Move to the next part of the output.
-            gmem_dq.move();
+            gmem_dq.move(step_stride);
         } else {
             // Output the values.
             gmem_dq_tmp.store(dq_out, 0);
         }
         // Move to the next part of the output.
-        if (!(Is_first && Is_last)) { gmem_dq_tmp.move(); }
+        if (!(Is_first && Is_last)) { gmem_dq_tmp.move(step_stride); }
         // // Make sure the data is in shared memory.
         // __syncthreads();
         // Commit the values for Q and dO into shared memory.
-        if(l < steps - 1) {
+        if (l + step_stride < steps) {
             gemm_q_k.smem_q.move_to_next_read_buffer();
             gemm_q_k.reload_q();
             smem_qt.move_to_next_read_buffer();
@@ -652,18 +658,34 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     uint4 dv_out[Smem_tile_dv::NUM_LDS];
     smem_dv.load(dv_out);
     Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts, binfo, tidx, false);
+    // using Gmem_tile_dkv_accum = typename Kernel_traits::Gmem_tile_dkv_accum;
+    // Gmem_tile_dkv_accum gmem_dv_accum(params.dv_accum_ptr, params.h * params.d, params.d, binfo, tidx, false);
+    // static_assert(Gmem_tile_dkv_accum::LDGS == Smem_tile_dv::NUM_LDS);
     if (!Is_first) {
         gmem_dv.move(loop_step_idx);
+        // gmem_dv_accum.move(loop_step_idx);
+    }
+    if (gridDim.z == 1) {
+        gmem_dv.store(dv_out);
+    } else {
+        gmem_dv.template atomic_add<elem_type>(dv_out);
+        // gmem_dv_accum.atomic_add_float(dv_out);
     }
-    gmem_dv.store(dv_out);
     uint4 dk_out[Smem_tile_dk::NUM_LDS];
     smem_dk.load(dk_out);
     Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, binfo, tidx, false);
+    // Gmem_tile_dkv_accum gmem_dk_accum(params.dk_accum_ptr, params.h * params.d, params.d, binfo, tidx, false);
     if (!Is_first) {
         gmem_dk.move(loop_step_idx);
+        // gmem_dk_accum.move(loop_step_idx);
+    }
+    if (gridDim.z == 1) {
+        gmem_dk.store(dk_out);
+    } else {
+        gmem_dk.template atomic_add<elem_type>(dk_out);
+        // gmem_dk_accum.atomic_add_float(dk_out);
     }
-    gmem_dk.store(dk_out);
 }
 ////////////////////////////////////////////////////////////////////////////////////////////////////
......
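The kernel changes above amount to a grid-stride split: blocks that share the same (batch, head) but differ in blockIdx.z take interleaved loop steps l = blockIdx.z, blockIdx.z + gridDim.z, ..., and then combine their partial dK/dV with atomicAdd, which is why the host zeroes dk/dv whenever num_splits > 1. A reduced, self-contained model of that pattern (a toy kernel with assumed names, not the real FMHA tiles):

```cuda
#include <cuda_runtime.h>
#include <cstdio>

__global__ void split_accumulate(const float *per_step, int steps, float *out) {
    const int step_stride = gridDim.z;           // plays the role of params.num_splits
    float partial = 0.f;
    for (int l = blockIdx.z; l < steps; l += step_stride) {
        partial += per_step[l];                  // stand-in for one Cta_tile_p::M chunk of work
    }
    if (threadIdx.x == 0) {
        atomicAdd(out, partial);                 // all z-blocks reduce into the same buffer
    }
}

int main() {
    const int steps = 37;
    float *per_step, *out;
    cudaMallocManaged(&per_step, steps * sizeof(float));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < steps; ++i) per_step[i] = 1.f;
    *out = 0.f;                                  // zero before accumulating, like dk.zero_()
    dim3 grid(1, 1, 4);                          // 4 splits along the "seqlen_q" loop
    split_accumulate<<<grid, 32>>>(per_step, steps, out);
    cudaDeviceSynchronize();
    std::printf("sum = %.0f (expect %d)\n", *out, steps);
    return 0;
}
```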
@@ -162,40 +162,5 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params) {
         //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
         //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         // }
-        // if (launch_params.params.d == 64) {
-        //     if( launch_params.params.seqlen_k == 128 ) {
-        //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-        //     } else if( launch_params.params.seqlen_k >= 256 ) {
-        //         if (dprops->major == 8 && dprops->minor >= 0) {
-        //             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-        //         } else if (dprops->major == 7 && dprops->minor == 5) {
-        //             if (launch_params.is_dropout) { // Need to use the same block size as backward
-        //                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-        //             } else {
-        //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-        //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-        //             }
-        //         }
-        //     }
-        // }
-        // if (launch_params.params.d == 128) {
-        //     if( launch_params.params.seqlen_k == 128 ) {
-        //         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-        //     } else {
-        //         if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
-        //             // TD [2022-06-05] Keep K in registers to reduce register spilling
-        //             // Gives about 6% speedup compared to using block size 128.
-        //             using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
-        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-        //         } else { // Need to use the same block size as backward
-        //             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
-        //         }
-        //     }
-        // }
     });
 }
\ No newline at end of file
@@ -7,10 +7,7 @@ import flash_attn_cuda
 def _get_block_size(device, head_dim, is_dropout):
     assert head_dim in [16, 32, 64, 128]
-    if head_dim in [16, 32, 64]:
-        return 256
-    elif head_dim == 128:
-        return 256 if (torch.cuda.get_device_capability(device) == (8, 0)) else 128
+    return 256 if head_dim in [16, 32, 64] else 128
 def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
@@ -32,11 +29,17 @@ def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
 def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-                         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
+                         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, num_splits=0,
                          generator=None):
+    """
+    num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
+    it will be set by an internal heuristic. Setting this too large (e.g. > 10) could make
+    numerical error of dK and dV larger (scaling as sqrt(num_splits)).
+    This hyperparameter can be tuned for performance, but default value (heuristic) should work fine.
+    """
     softmax_d = flash_attn_cuda.bwd(
         dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, generator)
+        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, num_splits, generator)
     # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
     #     breakpoint()
     return dq, dk, dv, softmax_d
......
@@ -356,7 +356,8 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     # rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3)
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 32
+    # Set smaller batch size so it would trigger num_splits > 1
+    batch_size = 8
     nheads = 4
     x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
     Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
@@ -418,10 +419,11 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     if dropout_p == 0.0:
         assert dropout_mask.all()
     else:
-        assert 0.99 <= dropout_fraction / dropout_p <= 1.01
+        assert 0.98 <= dropout_fraction / dropout_p <= 1.02
     if is_sm80 or d < 128:  # Only run backward for d=128 on A100
-        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
+        # Error for dK and dV could be a bit higher if we're splitting along seqlen_q dimension
+        assert (dqkv - dqkv_ref).abs().max().item() <= 4 * (dqkv_pt - dqkv_ref).abs().max().item()
         # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol)
......