Commit 55778193 authored by Tri Dao

Parallelize CUDA bwd along seqlen_k instead of seqlen_q

This is faster since we only need to do atomic adds on dq, instead of atomic
adds on both dk and dv.
parent ca81f32e
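For orientation before the diff: a rough serial sketch of the new decomposition, written as plain NumPy rather than the author's CUDA (softmax, scaling, masking and dropout are omitted, and the shapes are made up). It shows why splitting the backward along seqlen_k only needs atomics on dq: each key/value block owns its slice of dk and dv outright, while every block contributes a partial sum to all rows of dq.

import numpy as np

def bwd_seqk_parallel_sketch(q, k, v, do, blocksize_c=128):
    """Serial stand-in for the seqlen_k-parallel backward decomposition.
    q, do: (seqlen_q, d); k, v: (seqlen_k, d). Softmax/dropout/scaling omitted."""
    dq = np.zeros(q.shape, dtype=np.float32)   # shared accumulator -> atomicAdd in the CUDA kernel
    dk = np.zeros_like(k)
    dv = np.zeros_like(v)
    for start in range(0, k.shape[0], blocksize_c):   # one CUDA block per key/value block
        sl = slice(start, start + blocksize_c)
        p = q @ k[sl].T           # (seqlen_q, blocksize_c) scores against this key block
        dv[sl] = p.T @ do         # owned by this block: plain store, no atomics
        dp = do @ v[sl].T         # (seqlen_q, blocksize_c)
        dk[sl] = dp.T @ q         # also owned by this block
        dq += dp @ k[sl]          # crosses key blocks: the only place atomics are needed
    return dq, dk, dv

q = np.random.randn(256, 64).astype(np.float32)
do = np.random.randn(256, 64).astype(np.float32)
k = np.random.randn(512, 64).astype(np.float32)
v = np.random.randn(512, 64).astype(np.float32)
dq, dk, dv = bwd_seqk_parallel_sketch(q, k, v, do)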
@@ -454,17 +454,13 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     launch(params, stream, /*configure=*/true);

-    at::Tensor dk_accum, dv_accum;
     if (params.num_splits > 1) {
-        // dk_accum = torch::zeros({total_k, num_heads, head_size}, opts.dtype(at::kFloat));
-        // dv_accum = torch::zeros({total_k, num_heads, head_size}, opts.dtype(at::kFloat));
-        // params.dk_accum_ptr = dk_accum.data_ptr();
-        // params.dv_accum_ptr = dv_accum.data_ptr();
-        dk.zero_();
-        dv.zero_();
-    } else {
-        // params.dk_accum_ptr = nullptr;
-        // params.dv_accum_ptr = nullptr;
+        if (!dq_tmp.defined()) {
+            dq_tmp = torch::zeros({total_q, num_heads, head_size}, opts.dtype(at::kFloat));
+            params.o_tmp_ptr = dq_tmp.data_ptr();  // o_tmp stores dq_tmp in the backward pass
+        } else {
+            dq_tmp.zero_();
+        }
     }

     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
@@ -481,10 +477,10 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     launch(params, stream, /*configure=*/false);

-    // if (params.num_splits > 1) {
-    //     dk.copy_(dk_accum);
-    //     dv.copy_(dv_accum);
-    // }
+    if (params.num_splits > 1) {
+        dq.copy_(dq_tmp);
+    }

     return { dq, dk, dv, softmax_d };
 }
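The hunk above repurposes the o_tmp buffer as an fp32 dq accumulator: when num_splits > 1 the kernel atomically adds partial dq contributions in float32, and mha_bwd copies (and casts) the result back into dq once at the end. A minimal PyTorch sketch of that accumulate-then-copy pattern, with illustrative shapes that are assumptions rather than values from the diff:

import torch

total_q, num_heads, head_size = 1024, 8, 64   # assumed sizes
dq = torch.empty(total_q, num_heads, head_size, dtype=torch.float16)
dq_tmp = torch.zeros(total_q, num_heads, head_size, dtype=torch.float32)
# ... each seqlen_k block atomically accumulates its partial dq into dq_tmp ...
dq.copy_(dq_tmp)   # single cast/copy back to the output dtype, mirroring dq.copy_(dq_tmp) above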
@@ -34,20 +34,6 @@
 namespace fmha {

-// template <typename half2_t>
-// inline __device__ void atomic_add_CAS(half2_t *address, const half2_t val) {
-//     uint32_t *address_as_ui = (uint32_t *)address;
-//     uint32_t old = *address_as_ui;
-//     uint32_t assumed;
-//     do {
-//         assumed = old;
-//         half2_t sum = __hadd2(val, reinterpret_cast<half2_t(&)>(old));
-//         old = atomicCAS(address_as_ui, assumed, reinterpret_cast<uint32_t(&)>(sum));
-//     } while (assumed != old);
-// }
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 template<
     // The dimensions of the tile computed by the CTA.
     typename Cta_tile_,
@@ -148,43 +134,6 @@ struct Gmem_tile_qkv {
         }
     }

-    template <typename elem_type>
-    inline __device__ void atomic_add(const uint4 (&data)[LDGS]) {
-        int row_ = tidx_ / THREADS_PER_ROW;
-        #pragma unroll
-        for( int ii = 0; ii < LDGS; ++ii ) {
-            using elem2_type = typename std::conditional<std::is_same<elem_type, __half>::value, __half2, __nv_bfloat162>::type;
-            // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
-            elem2_type *ptr_ = reinterpret_cast<elem2_type *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
-            if (col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) {
-                #pragma unroll
-                for (int jj = 0; jj < 4; ++jj) {
-                    atomicAdd(ptr_ + jj, reinterpret_cast<const elem2_type(&)[4]>(data[ii])[jj]);
-                    // atomic_add_CAS(ptr_ + jj, reinterpret_cast<const elem2_type(&)[4]>(data[ii])[jj]);
-                }
-            }
-        }
-    }
-
-    // Not being used. This only supports converting from fp16 -> fp32 for now (not bf16 -> fp32).
-    inline __device__ void atomic_add_float(const uint4 (&data)[LDGS]) {
-        static_assert(BYTES_PER_ELEMENT == 4);  // Only support fp32
-        int row_ = tidx_ / THREADS_PER_ROW;
-        #pragma unroll
-        for( int ii = 0; ii < LDGS; ++ii ) {
-            // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
-            float *ptr_ = reinterpret_cast<float *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
-            if (col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) {
-                #pragma unroll
-                for (int jj = 0; jj < 4; ++jj) {
-                    const float2 data_f = fmha::half2_unpack<__half>(reinterpret_cast<const uint32_t(&)[4]>(data[ii])[jj]);
-                    atomicAdd(ptr_ + jj * 2, data_f.x);
-                    atomicAdd(ptr_ + jj * 2 + 1, data_f.y);
-                }
-            }
-        }
-    }
-
     inline __device__ void move(const int steps = 1) {
         // ptr += (int64_t)ROWS * row_stride_in_bytes * steps;
         ptr += (uint32_t)ROWS * row_stride_in_bytes * steps;
@@ -306,6 +255,27 @@ struct Gmem_tile_o {
         }
     }

+    // Store data to global memory with atomicAdd.
+    inline __device__ void atomic_add(const uint4 (&src)[STGS_PER_LOOP], int mi) {
+        static_assert(BYTES_PER_ELEMENT == 4);  // Only do atomic add on floats
+        int row_ = tidx_ / THREADS_PER_ROW;
+        #pragma unroll
+        for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
+            int jj = mi * STGS_PER_LOOP + ii;
+            if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) {
+                break;
+            }
+            if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
+                float *ptr_ = reinterpret_cast<float *>(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes);
+                #pragma unroll
+                for (int jj = 0; jj < 4; ++jj) {
+                    atomicAdd(ptr_ + jj, reinterpret_cast<const float(&)[4]>(src[ii])[jj]);
+                }
+            }
+        }
+    }
+
     // Load data from global memory.
     inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) {
         static_assert(BYTES_PER_ELEMENT == 4);
@@ -6,36 +6,31 @@
 #include "fmha.h"
 #include "fmha_dgrad_kernel_1xN_loop.h"

-// Find the number of splits that maximizes the occupancy. For example, if we have
-// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
-// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
-// splits as that would incur more HBM reads/writes.
-// Moreover, more than 1 split incurs extra cost of zeroing out dk/dv and doing atomic add
-// instead of just writing.
-// So for num_splits > 1, we divide the efficiency by some factor (e.g. 1.25, depending on seqlen)
-// to account for this. Moreover, more splits means atomic add will be slower.
-int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits,
-                             int seqlen, bool is_causal) {
-    float max_efficiency = 0.f;
-    int best_num_splits = 1;
-    std::vector<float> efficiency;
-    efficiency.reserve(max_splits);
-    float discount_factor = 1.f + 512.0 / seqlen;  // 1.25 for seqlen 2k, 1.125 for 4k.
-    discount_factor *= is_causal ? 1.1 : 1.f;  // causal makes it even slower.
-    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
-        float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
-        float eff_raw = n_waves / ceil(n_waves);
-        // Heuristic: each increase in num_splits results in 6% slowdown, up to maybe 8 splits.
-        float eff = num_splits == 1 ? eff_raw : (eff_raw - 0.07 * std::min(num_splits - 2, 6)) / discount_factor;
-        // printf("num_splits = %d, eff_raw = %f, eff = %f\n", num_splits, eff_raw, eff);
-        if (eff > max_efficiency) {
-            max_efficiency = eff;
-            best_num_splits = num_splits;
-        }
-        efficiency.push_back(eff);
-    }
-    // printf("num_splits chosen = %d\n", best_num_splits);
-    return best_num_splits;
+// Pick whether we should parallelize across seqlen_k (num_splits > 1) or not (num_splits=1).
+// Parallelizing will have better occupancy, but has some overhead due to having to zero out
+// dq_tmp and having to copy dq_tmp to dq.
+int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen,
+                             int blocksize, bool is_causal) {
+    float n_waves_1 = float(batch_nheads) / (num_SMs * ctas_per_sm);
+    float eff_1 = n_waves_1 / ceil(n_waves_1);
+    int num_splits_parallel = seqlen / blocksize;
+    float n_waves_parallel = float(batch_nheads * num_splits_parallel) / (num_SMs * ctas_per_sm);
+    float eff_parallel_raw = n_waves_parallel / ceil(n_waves_parallel);
+    float discount_factor;
+    if (!is_causal) {
+        discount_factor = 1.f + float(blocksize) / seqlen;
+    } else {  // For causal, parallelizing seems to help with load-balancing as well
+        // For example, if headdim=128, seqlen >= 1280 always prefers parallel
+        if (seqlen / blocksize >= 10) return num_splits_parallel;
+        discount_factor = 1.f + 0.5 * float(blocksize) / seqlen;
+    }
+    float eff_parallel = eff_parallel_raw / discount_factor;
+    return eff_1 >= eff_parallel ? 1 : num_splits_parallel;
+}
+
+template<typename Kernel_traits>
+__global__ void fmha_dgrad_dot_do_o_kernel(FMHA_dgrad_params params) {
+    fmha::compute_dot_do_o<Kernel_traits>(params);
 }

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
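To see how the rewritten heuristic trades occupancy against the parallelization overhead, here is a Python transcription of the function above plus one worked example; the device numbers (108 SMs, 1 CTA per SM) and problem sizes are assumptions for illustration, not measurements.

import math

def num_splits_heuristic_bwd(batch_nheads, num_SMs, ctas_per_sm, seqlen, blocksize, is_causal):
    # Python mirror of the C++ heuristic above, for experimentation.
    n_waves_1 = batch_nheads / (num_SMs * ctas_per_sm)
    eff_1 = n_waves_1 / math.ceil(n_waves_1)
    num_splits_parallel = seqlen // blocksize
    n_waves_parallel = batch_nheads * num_splits_parallel / (num_SMs * ctas_per_sm)
    eff_parallel_raw = n_waves_parallel / math.ceil(n_waves_parallel)
    if not is_causal:
        discount_factor = 1.0 + blocksize / seqlen
    else:
        if seqlen // blocksize >= 10:
            return num_splits_parallel
        discount_factor = 1.0 + 0.5 * blocksize / seqlen
    eff_parallel = eff_parallel_raw / discount_factor
    return 1 if eff_1 >= eff_parallel else num_splits_parallel

# batch * nheads = 48, 108 SMs, 1 CTA/SM, seqlen_k = 2048, blocksize_c = 128, non-causal:
# eff_1 = 48/108 ~= 0.44 (one partial wave), eff_parallel ~= 0.89 / 1.0625 ~= 0.84,
# so the heuristic picks seqlen_k / blocksize_c = 16 splits.
print(num_splits_heuristic_bwd(48, 108, 1, 2048, 128, is_causal=False))  # -> 16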
@@ -43,6 +38,11 @@ __global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params para
     fmha::compute_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
 }

+template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
+__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel(FMHA_dgrad_params params) {
+    fmha::compute_dq_dk_dv_seqparallel<Kernel_traits, Is_dropout, Is_causal>(params);
+}
+
 template<typename Kernel_traits>
 void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
     constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
@@ -74,9 +74,14 @@ void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stre
                 ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/2>
                 : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/2>;
         }
+        auto kernel_seqparallel = params.is_causal
+            ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, true>
+            : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, false>;
         if( smem_size_dq_dk_dv >= 48 * 1024 ) {
             FMHA_CHECK_CUDA(cudaFuncSetAttribute(
                 kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+            FMHA_CHECK_CUDA(cudaFuncSetAttribute(
+                kernel_seqparallel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
         }
         // Automatically set num_splits to maximize occupancy
         if (params.num_splits <= 0) {
@@ -90,13 +95,20 @@ void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stre
             // Numerical error on dk/dv scales as sqrt(num_splits).
             params.num_splits = num_splits_heuristic_bwd(
                 params.b * params.h, dprops->multiProcessorCount,
-                ctas_per_sm, /*max_splits=*/std::min(10, (params.seqlen_q + M - 1 / M)),
-                params.seqlen_k, params.is_causal
+                ctas_per_sm, params.seqlen_k, blocksize_c, params.is_causal
             );
         }
         if (configure) return;
-        dim3 grid(params.b, params.h, params.num_splits);
-        kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
+        if (params.num_splits == 1) {
+            dim3 grid(params.b, params.h, params.num_splits);
+            kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
+        } else {
+            dim3 grid_dot(params.b, params.h, (params.seqlen_q + 128 - 1) / 128);
+            fmha_dgrad_dot_do_o_kernel<Kernel_traits><<<grid_dot, Kernel_traits::THREADS, 0, stream>>>(params);
+            int num_splits = params.seqlen_k / blocksize_c;  // seqlen_k is divisible by blocksize_c
+            dim3 grid(params.b, params.h, num_splits);
+            kernel_seqparallel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
+        }
         FMHA_CHECK_CUDA(cudaPeekAtLastError());
     });
 }
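A quick sanity check of the launch geometry in the parallel branch above (the batch, head and sequence sizes are assumptions for illustration):

# Launch geometry of the seqlen_k-parallel branch in run_fmha_dgrad_fp16_sm80_loop_.
b, h, seqlen_q, seqlen_k, blocksize_c = 2, 16, 2048, 2048, 128   # assumed values
grid_dot = (b, h, (seqlen_q + 128 - 1) // 128)   # dot(dO, O) pre-pass: (2, 16, 16)
num_splits = seqlen_k // blocksize_c             # seqlen_k assumed divisible by blocksize_c
grid = (b, h, num_splits)                        # one z-block per key block: (2, 16, 16)
print(grid_dot, grid)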
@@ -31,7 +31,86 @@ inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M], cons

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_first, bool Is_last, typename Params, typename Prng>
+// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
+// This is used in the case where we want to parallelize the backward across seqlen_k.
+template<typename Kernel_traits, typename Params>
+inline __device__ void compute_dot_do_o(const Params &params) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    using elem_type = typename Kernel_traits::elem_type;
+#else
+    constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
+    assert(is_fp16_type);
+    using elem_type = __half;
+#endif
+
+    // The description of the CTA tile for the 1st batched GEMM.
+    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
+    // The description of the CTA tile for the 3rd batched GEMM.
+    using Cta_tile_dkv =
+        fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
+
+    static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64 || Cta_tile_dkv::N == 128);
+    static_assert(Cta_tile_dkv::K == 16);
+
+    // The global memory tile to load dO.
+    using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
+    // The global memory tile to load O. Loading O here is similar to loading dO.
+    using Gmem_tile_o = Gmem_tile_do;
+
+    using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
+
+    // The block index for the batch.
+    const int bidb = blockIdx.x;
+    // The block index for the head.
+    const int bidh = blockIdx.y;
+    // The thread index.
+    const int tidx = threadIdx.x;
+    // How many steps to jump per iteration.
+    const int step_stride = gridDim.z;
+
+    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
+    if( binfo.stop_early() ) return;
+
+    // Allocate the global memory tile loader for dO.
+    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                         params.d, binfo, tidx, true);
+    // Allocate the global memory tile loader for O.
+    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
+                       params.d, binfo, tidx, true);
+    Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);
+
+    static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
+    const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M;
+    // Wind gmem tiles to the correct position.
+    gmem_do.move(blockIdx.z);
+    gmem_o.move(blockIdx.z);
+    gmem_softmax_d.move(blockIdx.z);
+
+    // Load over the entire sequence length.
+    for (int l = blockIdx.z; l < steps; l += step_stride) {
+        if (l * Cta_tile_p::M >= binfo.actual_seqlen_q)
+            break;
+
+        gmem_do.load();
+        gmem_do.move(step_stride);
+        gmem_o.load();
+        gmem_o.move(step_stride);
+
+        dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
+            gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
+        );
+        gmem_softmax_d.move(step_stride);
+    }  // Outer loop over the sequence length.
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params, typename Prng>
 inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng &ph,
                                                      const int loop_step_idx) {
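The comment in the hunk above says the pre-pass only computes dot(dO, O) per row and writes it to softmax_d; that row-wise dot product is the D_i = sum_d dO[i, d] * O[i, d] term the main backward subtracts when turning dP into dS (dS_ij = P_ij * (dP_ij - D_i)). A minimal reference in PyTorch, ignoring the dropout rescaling that dot_do_o applies via params.p_dropout (the tensor layout here is an assumption):

import torch

def compute_dot_do_o_reference(do, o):
    # softmax_d[..., i] = sum_d dO[..., i, d] * O[..., i, d]
    return (do.float() * o.float()).sum(dim=-1)

do = torch.randn(2, 16, 2048, 64)   # (batch, heads, seqlen_q, head_dim), assumed layout
o = torch.randn(2, 16, 2048, 64)
softmax_d = compute_dot_do_o_reference(do, o)   # (2, 16, 2048)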
@@ -135,9 +214,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     // The thread index.
     const int tidx = threadIdx.x;

-    // How many steps to jump per iteration, which is the same as params.num_splits.
-    const int step_stride = gridDim.z;
-
     const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
     // if( binfo.stop_early() ) return;
     if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
@@ -195,23 +271,16 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
     int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
-    // We want begin to be a multiple of gridDim.z
-    // This is because the row indices processed by each threadblock must align between the
-    // loop steps, otherwise we have a dependency between the blocks.
-    // For example, threadblock with blockIdx.z == 1 must process row indices that are
-    // k * gridDim.z + 1 for integer k.
-    const int begin_mod_z = begin % gridDim.z;
-    begin = begin_mod_z <= blockIdx.z ? begin - begin_mod_z : begin + gridDim.z - begin_mod_z;
     const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M - begin;

     // Wind gmem tiles to the correct position.
-    gmem_q.move(begin + blockIdx.z);
-    gmem_do.move(begin + blockIdx.z);
-    gmem_o.move(begin + blockIdx.z);
-    gmem_dq.move(begin + blockIdx.z);
-    gmem_dq_tmp.move(begin + blockIdx.z);
+    gmem_q.move(begin);
+    gmem_do.move(begin);
+    gmem_o.move(begin);
+    if (!Seq_parallel) { gmem_dq.move(begin); }  // If Seq_parallel, we're not using gmem_dq at all
+    gmem_dq_tmp.move(begin);
     // TODO: need to move gmem_s if we want the intermediate result for debugging
-    gmem_softmax_lse.move(begin + blockIdx.z);
-    gmem_softmax_d.move(begin + blockIdx.z);
+    gmem_softmax_lse.move(begin);
+    gmem_softmax_d.move(begin);

     if (!Is_first) {
         gmem_k.move(loop_step_idx);
@@ -315,7 +384,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dk);

     // Load over the entire sequence length.
-    for (int l = blockIdx.z; l < steps; l += step_stride) {
+    for (int l = 0; l < steps; l++) {
         if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q)
             break;
@@ -365,9 +434,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         smem_s.store(frag_p);

         // Trigger the load for the next Q values.
-        if (l + step_stride < steps) {
+        if (l + 1 < steps) {
             gemm_q_k.smem_q.move_to_next_write_buffer();
-            gmem_q.move(step_stride);
+            gmem_q.move();
             gmem_q.load();
         }
@@ -440,12 +509,12 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         smem_kt.load(frag_kt[0], 0);

         // Trigger the load for the next dO values.
-        if (l + step_stride < steps) {
+        if (l + 1 < steps) {
             smem_do.move_to_next_write_buffer();
-            gmem_do.move(step_stride);
+            gmem_do.move();
             gmem_do.load();
             if (Is_first) {
-                gmem_o.move(step_stride);
+                gmem_o.move();
                 gmem_o.load();
             }
         }
@@ -456,7 +525,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         smem_dp.store(frag_p);

         // gmem_s.store(frag_p, mask);
-        // gmem_s.move(step_stride);
+        // gmem_s.move();

         // Declare the accumulators for the 2nd gemm.
         fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N];
@@ -533,24 +602,24 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         // }
         // __syncthreads();
         // Commit the values for Q and dO into shared memory.
-        if (l + step_stride < steps) {
+        if (l + 1 < steps) {
             gmem_q.commit(gemm_q_k.smem_q);
         }

         uint4 dq_out[Gmem_tile_dq::STGS_PER_LOOP];
-        if (!Is_first) { gmem_dq_tmp.load(dq_out, 0); }
+        if (!Is_first && !Seq_parallel) { gmem_dq_tmp.load(dq_out, 0); }

         // __syncthreads();
         // Commit the values for Q and dO into shared memory.
-        if (l + step_stride < steps) {
+        if (l + 1 < steps) {
             gmem_do.commit(smem_do);
-            gmem_softmax_d.move(step_stride);
+            gmem_softmax_d.move();
             if (Is_first) {
                 dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
                     gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
                 );
             }
-            gmem_softmax_lse.move(step_stride);
+            gmem_softmax_lse.move();
             gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
         }
@@ -581,42 +650,53 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         // Make sure dQ is in shared memory.
         __syncthreads();

-        if (l + step_stride < steps) {
+        if (l + 1 < steps) {
             gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
         }

         // Load from shared memory.
-        smem_dq.template load</*zero_init=*/Is_first>(dq_out);
+        smem_dq.template load</*zero_init=*/Is_first || Seq_parallel>(dq_out);

-        const bool is_final_write =
-            Is_last
-            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
-            || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
-        if (is_final_write) {
-            // if (Is_dropout) {
-            //     dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout);
-            // }
+        if (!Seq_parallel) {
+            const bool is_final_write =
+                Is_last
+                || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
+                || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
+            if (is_final_write) {
+                // if (Is_dropout) {
+                //     dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout);
+                // }
+                for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) {
+                    // dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f);
+                    dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout);
+                }
+                // Output the values.
+                gmem_dq.template store<elem_type>(dq_out, 0);
+                // Move to the next part of the output.
+                gmem_dq.move();
+                // TODO: for parallel, need to deal with the dropout scaling
+            } else {
+                // Output the values.
+                gmem_dq_tmp.store(dq_out, 0);
+            }
+        } else {
+            // We always scale dq_out before writing in this case, since we don't want to
+            // have to scale at the end when copying from dq_tmp to dq.
             for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) {
                 // dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f);
                 dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout);
             }
-            // Output the values.
-            gmem_dq.template store<elem_type>(dq_out, 0);
-            // Move to the next part of the output.
-            gmem_dq.move(step_stride);
-        } else {
-            // Output the values.
-            gmem_dq_tmp.store(dq_out, 0);
+            gmem_dq_tmp.atomic_add(dq_out, 0);
         }

         // Move to the next part of the output.
-        if (!(Is_first && Is_last)) { gmem_dq_tmp.move(step_stride); }
+        if (!(Is_first && Is_last)) { gmem_dq_tmp.move(); }

         // // Make sure the data is in shared memory.
         // __syncthreads();

         // Commit the values for Q and dO into shared memory.
-        if (l + step_stride < steps) {
+        if (l + 1 < steps) {
             gemm_q_k.smem_q.move_to_next_read_buffer();
             gemm_q_k.reload_q();
             smem_qt.move_to_next_read_buffer();
@@ -666,35 +746,19 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     smem_dv.load(dv_out);
     Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts,
                          params.d, binfo, tidx, false);
-    // using Gmem_tile_dkv_accum = typename Kernel_traits::Gmem_tile_dkv_accum;
-    // Gmem_tile_dkv_accum gmem_dv_accum(params.dv_accum_ptr, params.h * params.d, params.d, binfo, tidx, false);
-    // static_assert(Gmem_tile_dkv_accum::LDGS == Smem_tile_dv::NUM_LDS);
     if (!Is_first) {
         gmem_dv.move(loop_step_idx);
-        // gmem_dv_accum.move(loop_step_idx);
-    }
-    if (gridDim.z == 1) {
-        gmem_dv.store(dv_out);
-    } else {
-        gmem_dv.template atomic_add<elem_type>(dv_out);
-        // gmem_dv_accum.atomic_add_float(dv_out);
     }
+    gmem_dv.store(dv_out);

     uint4 dk_out[Smem_tile_dk::NUM_LDS];
     smem_dk.load(dk_out);
     Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts,
                          params.d, binfo, tidx, false);
-    // Gmem_tile_dkv_accum gmem_dk_accum(params.dk_accum_ptr, params.h * params.d, params.d, binfo, tidx, false);
     if (!Is_first) {
         gmem_dk.move(loop_step_idx);
-        // gmem_dk_accum.move(loop_step_idx);
-    }
-    if (gridDim.z == 1) {
-        gmem_dk.store(dk_out);
-    } else {
-        gmem_dk.template atomic_add<elem_type>(dk_out);
-        // gmem_dk_accum.atomic_add_float(dk_out);
     }
+    gmem_dk.store(dk_out);
 }
@@ -736,4 +800,22 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {

 ////////////////////////////////////////////////////////////////////////////////////////////////////

+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, typename Params>
+inline __device__ void compute_dq_dk_dv_seqparallel(const Params &params) {
+    // The block index for the batch.
+    const int bidb = blockIdx.x;
+    // The block index for the head.
+    const int bidh = blockIdx.y;
+    // The thread index.
+    const int tidx = threadIdx.x;
+
+    auto seeds = at::cuda::philox::unpack(params.philox_args);
+    Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
+
+    int loop_step_idx = blockIdx.z;
+    compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false, /*Seq_parallel=*/true>(params, ph, loop_step_idx);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 }  // namespace fmha
@@ -32,12 +32,13 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
                          max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, num_splits=0,
                          generator=None):
     """
-    num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
-    it will be set by an internal heuristic. Setting this too large (e.g. > 10) could make
-    numerical error of dK and dV larger (scaling as sqrt(num_splits)).
+    num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
+    not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
+    Any value above 1 will call the same kernel (i.e. num_splits=2 would call the same kernel
+    as num_splits=3), so effectively the choices are 0, 1, and 2.
     This hyperparameter can be tuned for performance, but default value (heuristic) should work fine.
     """
-    softmax_d = flash_attn_cuda.bwd(
+    _, _, _, softmax_d = flash_attn_cuda.bwd(
         dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, num_splits, generator)
     # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
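As a usage note on the docstring change above, a small self-contained helper spelling out the effective choices (0, 1, and 2); the wording paraphrases the docstring and the launch code and is not an official API:

def describe_num_splits(num_splits: int) -> str:
    if num_splits == 0:
        return "heuristic: picks 1 or seqlen_k // blocksize_c automatically"
    if num_splits == 1:
        return "single kernel: each block loops over all of seqlen_k, no atomics on dq"
    return "seqlen_k-parallel kernel: one block per key block, dq accumulated via atomicAdd"

for n in (0, 1, 2, 3):
    print(n, "->", describe_num_splits(n))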
@@ -26,10 +26,11 @@ that there are none left for other head dimensions.

 Differences between this Triton version and the CUDA version:
 - Triton version doesn't support dropout.
-- Triton forward is generally faster than CUDA forward.
-- Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64.
-It is slightly slower when headdim=128 and batch * nheads is large.
-- Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
+- Triton forward is generally faster than CUDA forward, while Triton backward is
+generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
+than CUDA forward + backward.
+- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
+- Triton version supports attention bias, while CUDA version doesn't.
 """

 import math
@@ -368,8 +368,8 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
     Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
-    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
-    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
+    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
+    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
     qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
         x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True