Commit 55778193 authored by Tri Dao

Parallelize CUDA bwd along seqlen_k instead of seqlen_q

This is faster since we only need to do atomic adds on dq, instead of atomic
adds on both dk and dv.
parent ca81f32e
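Conceptually, each CTA in the new scheme owns one disjoint chunk of K/V columns: its dk/dv writes never collide and can be plain stores, while every CTA contributes to every dq row, so only dq needs atomic accumulation. (The previous scheme split seqlen_q, the mirror image: dq was exclusive per CTA but dk and dv needed atomic adds.) A toy CPU sketch of that access pattern, illustrative only, with a dummy contrib term standing in for the real dS/Q/K/dO products:

#include <cstdio>
#include <vector>

// Toy sketch of parallelizing the attention backward along seqlen_k.
// One outer-loop iteration plays the role of one CTA: it owns one chunk of
// K/V columns, so its dk/dv writes are disjoint from other "CTAs" and can be
// plain stores, while dq rows are shared and would need atomicAdd on the GPU.
int main() {
    const int seqlen_q = 4, seqlen_k = 8, blocksize_c = 2;
    std::vector<float> dq(seqlen_q, 0.f), dk(seqlen_k, 0.f), dv(seqlen_k, 0.f);
    for (int chunk = 0; chunk < seqlen_k / blocksize_c; ++chunk) {    // one "CTA" per K/V chunk
        for (int n = chunk * blocksize_c; n < (chunk + 1) * blocksize_c; ++n) {
            for (int m = 0; m < seqlen_q; ++m) {
                float contrib = 1.f;  // placeholder for the real dS/Q/K/dO products
                dk[n] += contrib;     // exclusive to this chunk -> plain store
                dv[n] += contrib;     // exclusive to this chunk -> plain store
                dq[m] += contrib;     // shared across chunks    -> atomicAdd
            }
        }
    }
    std::printf("dq[0]=%.0f dk[0]=%.0f dv[0]=%.0f\n", dq[0], dk[0], dv[0]);
    return 0;
}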
@@ -454,17 +454,13 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
launch(params, stream, /*configure=*/true);
at::Tensor dk_accum, dv_accum;
if (params.num_splits > 1) {
// dk_accum = torch::zeros({total_k, num_heads, head_size}, opts.dtype(at::kFloat));
// dv_accum = torch::zeros({total_k, num_heads, head_size}, opts.dtype(at::kFloat));
// params.dk_accum_ptr = dk_accum.data_ptr();
// params.dv_accum_ptr = dv_accum.data_ptr();
dk.zero_();
dv.zero_();
if (!dq_tmp.defined()) {
dq_tmp = torch::zeros({total_q, num_heads, head_size}, opts.dtype(at::kFloat));
params.o_tmp_ptr = dq_tmp.data_ptr(); // o_tmp stores dq_tmp in the backward pass
} else {
// params.dk_accum_ptr = nullptr;
// params.dv_accum_ptr = nullptr;
dq_tmp.zero_();
}
}
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
@@ -481,10 +477,10 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
launch(params, stream, /*configure=*/false);
// if (params.num_splits > 1) {
// dk.copy_(dk_accum);
// dv.copy_(dv_accum);
// }
if (params.num_splits > 1) {
dq.copy_(dq_tmp);
}
return { dq, dk, dv, softmax_d };
}
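Host side, the num_splits > 1 path above accumulates dq in an fp32 scratch buffer (dq_tmp, which reuses the o_tmp slot) and casts it back into dq with a single copy at the end. A reduced libtorch sketch of that accumulate-in-fp32-then-cast idiom, with hypothetical names and shapes, not the actual mha_bwd code:

#include <torch/torch.h>

// Sketch only: dq is accumulated in fp32 (so atomicAdds from many CTAs don't
// lose precision in fp16), then cast back to the output dtype in one copy.
torch::Tensor accumulate_then_cast(int64_t total_q, int64_t num_heads, int64_t head_size) {
    auto opts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
    auto dq = torch::empty({total_q, num_heads, head_size}, opts);
    // fp32 scratch that the per-chunk kernels atomically add into; must start at zero.
    auto dq_tmp = torch::zeros({total_q, num_heads, head_size}, opts.dtype(torch::kFloat));
    // ... launch one backward kernel per seqlen_k chunk, each atomicAdd-ing into dq_tmp ...
    dq.copy_(dq_tmp);  // single fp32 -> fp16 cast at the end
    return dq;
}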
@@ -34,20 +34,6 @@
namespace fmha {
// template <typename half2_t>
// inline __device__ void atomic_add_CAS(half2_t *address, const half2_t val) {
// uint32_t *address_as_ui = (uint32_t *)address;
// uint32_t old = *address_as_ui;
// uint32_t assumed;
// do {
// assumed = old;
// half2_t sum = __hadd2(val, reinterpret_cast<half2_t(&)>(old));
// old = atomicCAS(address_as_ui, assumed, reinterpret_cast<uint32_t(&)>(sum));
// } while (assumed != old);
// }
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile_,
@@ -148,43 +134,6 @@ struct Gmem_tile_qkv {
}
}
template <typename elem_type>
inline __device__ void atomic_add(const uint4 (&data)[LDGS]) {
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
using elem2_type = typename std::conditional<std::is_same<elem_type, __half>::value, __half2, __nv_bfloat162>::type;
// char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
elem2_type *ptr_ = reinterpret_cast<elem2_type *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
if (col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) {
#pragma unroll
for (int jj = 0; jj < 4; ++jj) {
atomicAdd(ptr_ + jj, reinterpret_cast<const elem2_type(&)[4]>(data[ii])[jj]);
// atomic_add_CAS(ptr_ + jj, reinterpret_cast<const elem2_type(&)[4]>(data[ii])[jj]);
}
}
}
}
// Not being used. This only supports converting from fp16 -> fp32 for now (not bf16 -> fp32).
inline __device__ void atomic_add_float(const uint4 (&data)[LDGS]) {
static_assert(BYTES_PER_ELEMENT == 4); // Only support fp32
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
// char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
float *ptr_ = reinterpret_cast<float *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
if (col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) {
#pragma unroll
for (int jj = 0; jj < 4; ++jj) {
const float2 data_f = fmha::half2_unpack<__half>(reinterpret_cast<const uint32_t(&)[4]>(data[ii])[jj]);
atomicAdd(ptr_ + jj * 2, data_f.x);
atomicAdd(ptr_ + jj * 2 + 1, data_f.y);
}
}
}
}
inline __device__ void move(const int steps = 1) {
// ptr += (int64_t)ROWS * row_stride_in_bytes * steps;
ptr += (uint32_t)ROWS * row_stride_in_bytes * steps;
@@ -306,6 +255,27 @@ struct Gmem_tile_o {
}
}
// Store data to global memory with atomicAdd.
inline __device__ void atomic_add(const uint4 (&src)[STGS_PER_LOOP], int mi) {
static_assert(BYTES_PER_ELEMENT == 4); // Only do atomic add on floats
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
int jj = mi * STGS_PER_LOOP + ii;
if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) {
break;
}
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
float *ptr_ = reinterpret_cast<float *>(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes);
#pragma unroll
for (int jj = 0; jj < 4; ++jj) {
atomicAdd(ptr_ + jj, reinterpret_cast<const float(&)[4]>(src[ii])[jj]);
}
}
}
}
// Load data from global memory.
inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) {
static_assert(BYTES_PER_ELEMENT == 4);
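The new Gmem_tile_o::atomic_add above accumulates each uint4 of output registers as four packed fp32 values into the float dq_tmp buffer. A minimal standalone CUDA sketch of that accumulation pattern, with made-up kernel and buffer names:

#include <cuda_runtime.h>

// Illustrative sketch only: accumulate 4-float vectors into a global fp32
// accumulator with element-wise atomicAdd, the pattern Gmem_tile_o::atomic_add
// uses for each STG when several CTAs add into the same dq rows.
__global__ void accumulate_dq_tmp(float *dq_tmp, const float4 *partial, int n_vec) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n_vec) return;
    float4 v = partial[i];               // one thread's 16-byte chunk of partial dq
    atomicAdd(dq_tmp + 4 * i + 0, v.x);  // element-wise atomics, since other CTAs
    atomicAdd(dq_tmp + 4 * i + 1, v.y);  // (handling other K/V chunks) accumulate
    atomicAdd(dq_tmp + 4 * i + 2, v.z);  // into the same dq rows concurrently
    atomicAdd(dq_tmp + 4 * i + 3, v.w);
}
// e.g. accumulate_dq_tmp<<<(n_vec + 255) / 256, 256>>>(dq_tmp, partial, n_vec);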
@@ -6,36 +6,31 @@
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_loop.h"
// Find the number of splits that maximizes the occupancy. For example, if we have
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
// splits as that would incur more HBM reads/writes.
// Moreover, more than 1 split incurs extra cost of zeroing out dk/dv and doing atomic add
// instead of just writing.
// So for num_splits > 1, we divide the efficiency by some factor (e.g. 1.25, depending on seqlen)
// to account for this. Moreover, more splits means atomic add will be slower.
int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits,
int seqlen, bool is_causal) {
float max_efficiency = 0.f;
int best_num_splits = 1;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
float discount_factor = 1.f + 512.0 / seqlen; // 1.25 for seqlen 2k, 1.125 for 4k.
discount_factor *= is_causal ? 1.1 : 1.f; // causal makes it even slower.
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
float eff_raw = n_waves / ceil(n_waves);
// Heuristic: each increase in num_splits results in 6% slowdown, up to maybe 8 splits.
float eff = num_splits == 1 ? eff_raw : (eff_raw - 0.07 * std::min(num_splits - 2, 6)) / discount_factor;
// printf("num_splits = %d, eff_raw = %f, eff = %f\n", num_splits, eff_raw, eff);
if (eff > max_efficiency) {
max_efficiency = eff;
best_num_splits = num_splits;
// Pick whether we should parallelize across seqlen_k (num_splits > 1) or not (num_splits=1).
// Parallelizing will have better occupancy, but has some overhead due to having to zero out
// dq_tmp and having to copy dq_tmp to dq.
int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen,
int blocksize, bool is_causal) {
float n_waves_1 = float(batch_nheads) / (num_SMs * ctas_per_sm);
float eff_1 = n_waves_1 / ceil(n_waves_1);
int num_splits_parallel = seqlen / blocksize;
float n_waves_parallel = float(batch_nheads * num_splits_parallel) / (num_SMs * ctas_per_sm);
float eff_parallel_raw = n_waves_parallel / ceil(n_waves_parallel);
float discount_factor;
if (!is_causal) {
discount_factor = 1.f + float(blocksize) / seqlen;
} else { // For causal, parallelizing seems to help with load-balancing as well
// For example, if headdim=128, seqlen >= 1280 always prefers parallel
if (seqlen / blocksize >= 10) return num_splits_parallel;
discount_factor = 1.f + 0.5 * float(blocksize) / seqlen;
}
efficiency.push_back(eff);
}
// printf("num_splits chosen = %d\n", best_num_splits);
return best_num_splits;
float eff_parallel = eff_parallel_raw / discount_factor;
return eff_1 >= eff_parallel ? 1 : num_splits_parallel;
}
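To make the trade-off concrete, here is a worked example of the heuristic above, reusing the batch*n_heads = 48 and 108 SMs figures from the comment this commit removes (the seqlen and blocksize values are just for illustration):

// Illustrative numbers: batch_nheads = 48, 108 SMs, 1 CTA per SM,
// seqlen_k = 2048, blocksize_c = 128, non-causal.
//   num_splits = 1:  n_waves = 48 / 108 = 0.44,    eff_1 = 0.44 / ceil(0.44) = 0.44
//   parallel:        num_splits = 2048 / 128 = 16
//                    n_waves = 48 * 16 / 108 = 7.11,  eff_raw = 7.11 / 8 = 0.89
//                    discount = 1 + 128 / 2048 = 1.0625,  eff_parallel ~= 0.84
//   0.84 > 0.44, so the seqlen_k-parallel path wins and 16 is returned:
int example_splits = num_splits_heuristic_bwd(/*batch_nheads=*/48, /*num_SMs=*/108,
                                              /*ctas_per_sm=*/1, /*seqlen=*/2048,
                                              /*blocksize=*/128, /*is_causal=*/false);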
template<typename Kernel_traits>
__global__ void fmha_dgrad_dot_do_o_kernel(FMHA_dgrad_params params) {
fmha::compute_dot_do_o<Kernel_traits>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
@@ -43,6 +38,11 @@ __global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params para
fmha::compute_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel(FMHA_dgrad_params params) {
fmha::compute_dq_dk_dv_seqparallel<Kernel_traits, Is_dropout, Is_causal>(params);
}
template<typename Kernel_traits>
void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
@@ -74,9 +74,14 @@ void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stre
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/2>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/2>;
}
auto kernel_seqparallel = params.is_causal
? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, true>
: &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, false>;
if( smem_size_dq_dk_dv >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel_seqparallel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
// Automatically set num_splits to maximize occupancy
if (params.num_splits <= 0) {
@@ -90,13 +95,20 @@ void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stre
// Numerical error on dk/dv scales as sqrt(num_splits).
params.num_splits = num_splits_heuristic_bwd(
params.b * params.h, dprops->multiProcessorCount,
ctas_per_sm, /*max_splits=*/std::min(10, (params.seqlen_q + M - 1 / M)),
params.seqlen_k, params.is_causal
ctas_per_sm, params.seqlen_k, blocksize_c, params.is_causal
);
}
if (configure) return;
if (params.num_splits == 1) {
dim3 grid(params.b, params.h, params.num_splits);
kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
} else {
dim3 grid_dot(params.b, params.h, (params.seqlen_q + 128 - 1) / 128);
fmha_dgrad_dot_do_o_kernel<Kernel_traits><<<grid_dot, Kernel_traits::THREADS, 0, stream>>>(params);
int num_splits = params.seqlen_k / blocksize_c; // seqlen_k is divisible by blocksize_c
dim3 grid(params.b, params.h, num_splits);
kernel_seqparallel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
}
FMHA_CHECK_CUDA(cudaPeekAtLastError());
});
}
@@ -31,7 +31,86 @@ inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M], cons
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_first, bool Is_last, typename Params, typename Prng>
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<typename Kernel_traits, typename Params>
inline __device__ void compute_dot_do_o(const Params &params) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using elem_type = typename Kernel_traits::elem_type;
#else
constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
assert(is_fp16_type);
using elem_type = __half;
#endif
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 3rd batched GEMM.
using Cta_tile_dkv =
fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64 || Cta_tile_dkv::N == 128);
static_assert(Cta_tile_dkv::K == 16);
// The global memory tile to load dO.
using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
// The global memory tile to load O. Loading O here is similar to loading dO.
using Gmem_tile_o = Gmem_tile_do;
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
// The block index for the batch.
const int bidb = blockIdx.x;
// The block index for the head.
const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;
// How many steps to jump per iteration.
const int step_stride = gridDim.z;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() ) return;
// Allocate the global memory tile loader for dO.
Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
params.d, binfo, tidx, true);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
params.d, binfo, tidx, true);
Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M;
// Wind gmem tiles to the correct position.
gmem_do.move(blockIdx.z);
gmem_o.move(blockIdx.z);
gmem_softmax_d.move(blockIdx.z);
// Load over the entire sequence length.
for (int l = blockIdx.z; l < steps; l += step_stride) {
if (l * Cta_tile_p::M >= binfo.actual_seqlen_q)
break;
gmem_do.load();
gmem_do.move(step_stride);
gmem_o.load();
gmem_o.move(step_stride);
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
);
gmem_softmax_d.move(step_stride);
} // Outer loop over the sequence length.
}
////////////////////////////////////////////////////////////////////////////////////////////////////
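For reference, the row-wise dot product written out by compute_dot_do_o is the D term of the softmax backward (as in the FlashAttention derivation): with P = \mathrm{softmax}(QK^T \cdot \mathrm{scale}), O = PV and dP = dO\,V^T,

    D_i = \sum_j dO_{ij}\,O_{ij}, \qquad dS_{ij} = P_{ij}\,(dP_{ij} - D_i).

Precomputing D in its own kernel lets every seqlen_k chunk reuse it without re-reading O (the code also folds in a dropout rescaling via params.p_dropout).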
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params, typename Prng>
inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng &ph,
const int loop_step_idx) {
@@ -135,9 +214,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// The thread index.
const int tidx = threadIdx.x;
// How many steps to jump per iteration, which is the same as params.num_splits.
const int step_stride = gridDim.z;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
// if( binfo.stop_early() ) return;
if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
@@ -195,23 +271,16 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
// We want begin to be a multiple of gridDim.z
// This is because the row indices processed by each threadblock must align between the
// loop steps, otherwise we have a dependency between the blocks.
// For example, threadblock with blockIdx.z == 1 must process row indices that are
// k * gridDim.z + 1 for integer k.
const int begin_mod_z = begin % gridDim.z;
begin = begin_mod_z <= blockIdx.z ? begin - begin_mod_z : begin + gridDim.z - begin_mod_z;
const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M - begin;
// Wind gmem tiles to the correct position.
gmem_q.move(begin + blockIdx.z);
gmem_do.move(begin + blockIdx.z);
gmem_o.move(begin + blockIdx.z);
gmem_dq.move(begin + blockIdx.z);
gmem_dq_tmp.move(begin + blockIdx.z);
gmem_q.move(begin);
gmem_do.move(begin);
gmem_o.move(begin);
if (!Seq_parallel) { gmem_dq.move(begin); } // If Seq_parallel, we're not using gmem_dq at all
gmem_dq_tmp.move(begin);
// TODO: need to move gmem_s if we want the intermediate result for debugging
gmem_softmax_lse.move(begin + blockIdx.z);
gmem_softmax_d.move(begin + blockIdx.z);
gmem_softmax_lse.move(begin);
gmem_softmax_d.move(begin);
if (!Is_first) {
gmem_k.move(loop_step_idx);
@@ -315,7 +384,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dk);
// Load over the entire sequence length.
for (int l = blockIdx.z; l < steps; l += step_stride) {
for (int l = 0; l < steps; l++) {
if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q)
break;
@@ -365,9 +434,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
smem_s.store(frag_p);
// Trigger the load for the next Q values.
if (l + step_stride < steps) {
if (l + 1 < steps) {
gemm_q_k.smem_q.move_to_next_write_buffer();
gmem_q.move(step_stride);
gmem_q.move();
gmem_q.load();
}
@@ -440,12 +509,12 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
smem_kt.load(frag_kt[0], 0);
// Trigger the load for the next dO values.
if (l + step_stride < steps) {
if (l + 1 < steps) {
smem_do.move_to_next_write_buffer();
gmem_do.move(step_stride);
gmem_do.move();
gmem_do.load();
if (Is_first) {
gmem_o.move(step_stride);
gmem_o.move();
gmem_o.load();
}
}
@@ -456,7 +525,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
smem_dp.store(frag_p);
// gmem_s.store(frag_p, mask);
// gmem_s.move(step_stride);
// gmem_s.move();
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N];
@@ -533,24 +602,24 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// }
// __syncthreads();
// Commit the values for Q and dO into shared memory.
if (l + step_stride < steps) {
if (l + 1 < steps) {
gmem_q.commit(gemm_q_k.smem_q);
}
uint4 dq_out[Gmem_tile_dq::STGS_PER_LOOP];
if (!Is_first) { gmem_dq_tmp.load(dq_out, 0); }
if (!Is_first && !Seq_parallel) { gmem_dq_tmp.load(dq_out, 0); }
// __syncthreads();
// Commit the values for Q and dO into shared memory.
if (l + step_stride < steps) {
if (l + 1 < steps) {
gmem_do.commit(smem_do);
gmem_softmax_d.move(step_stride);
gmem_softmax_d.move();
if (Is_first) {
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
);
}
gmem_softmax_lse.move(step_stride);
gmem_softmax_lse.move();
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
}
@@ -581,13 +650,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// Make sure dQ is in shared memory.
__syncthreads();
if (l + step_stride < steps) {
if (l + 1 < steps) {
gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
}
// Load from shared memory.
smem_dq.template load</*zero_init=*/Is_first>(dq_out);
smem_dq.template load</*zero_init=*/Is_first || Seq_parallel>(dq_out);
if (!Seq_parallel) {
const bool is_final_write =
Is_last
|| ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
@@ -603,20 +673,30 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// Output the values.
gmem_dq.template store<elem_type>(dq_out, 0);
// Move to the next part of the output.
gmem_dq.move(step_stride);
gmem_dq.move();
// TODO: for parallel, need to deal with the dropout scaling
} else {
// Output the values.
gmem_dq_tmp.store(dq_out, 0);
}
} else {
// We always scale dq_out before writing in this case, since we don't want to
// have to scale at the end when copying from dq_tmp to dq.
for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) {
// dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f);
dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout);
}
gmem_dq_tmp.atomic_add(dq_out, 0);
}
// Move to the next part of the output.
if (!(Is_first && Is_last)) { gmem_dq_tmp.move(step_stride); }
if (!(Is_first && Is_last)) { gmem_dq_tmp.move(); }
// // Make sure the data is in shared memory.
// __syncthreads();
// Commit the values for Q and dO into shared memory.
if (l + step_stride < steps) {
if (l + 1 < steps) {
gemm_q_k.smem_q.move_to_next_read_buffer();
gemm_q_k.reload_q();
smem_qt.move_to_next_read_buffer();
@@ -666,35 +746,19 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
smem_dv.load(dv_out);
Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts,
params.d, binfo, tidx, false);
// using Gmem_tile_dkv_accum = typename Kernel_traits::Gmem_tile_dkv_accum;
// Gmem_tile_dkv_accum gmem_dv_accum(params.dv_accum_ptr, params.h * params.d, params.d, binfo, tidx, false);
// static_assert(Gmem_tile_dkv_accum::LDGS == Smem_tile_dv::NUM_LDS);
if (!Is_first) {
gmem_dv.move(loop_step_idx);
// gmem_dv_accum.move(loop_step_idx);
}
if (gridDim.z == 1) {
gmem_dv.store(dv_out);
} else {
gmem_dv.template atomic_add<elem_type>(dv_out);
// gmem_dv_accum.atomic_add_float(dv_out);
}
uint4 dk_out[Smem_tile_dk::NUM_LDS];
smem_dk.load(dk_out);
Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts,
params.d, binfo, tidx, false);
// Gmem_tile_dkv_accum gmem_dk_accum(params.dk_accum_ptr, params.h * params.d, params.d, binfo, tidx, false);
if (!Is_first) {
gmem_dk.move(loop_step_idx);
// gmem_dk_accum.move(loop_step_idx);
}
if (gridDim.z == 1) {
gmem_dk.store(dk_out);
} else {
gmem_dk.template atomic_add<elem_type>(dk_out);
// gmem_dk_accum.atomic_add_float(dk_out);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -736,4 +800,22 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, typename Params>
inline __device__ void compute_dq_dk_dv_seqparallel(const Params &params) {
// The block index for the batch.
const int bidb = blockIdx.x;
// The block index for the head.
const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;
auto seeds = at::cuda::philox::unpack(params.philox_args);
Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
int loop_step_idx = blockIdx.z;
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false, /*Seq_parallel=*/true>(params, ph, loop_step_idx);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
@@ -32,12 +32,13 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, num_splits=0,
generator=None):
"""
num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
it will be set by an internal heuristic. Setting this too large (e.g. > 10) could make
numerical error of dK and dV larger (scaling as sqrt(num_splits)).
num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
Any value above 1 will call the same kernel (i.e. num_splits=2 would call the same kernel
as num_splits=3), so effectively the choices are 0, 1, and 2.
This hyperparameter can be tuned for performance, but the default value (heuristic) should work fine.
"""
softmax_d = flash_attn_cuda.bwd(
_, _, _, softmax_d = flash_attn_cuda.bwd(
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, num_splits, generator)
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
@@ -26,10 +26,11 @@ that there are none left for other head dimensions.
Differences between this Triton version and the CUDA version:
- Triton version doesn't support dropout.
- Triton forward is generally faster than CUDA forward.
- Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64.
It is slightly slower when headdim=128 and batch * nheads is large.
- Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
- Triton forward is generally faster than CUDA forward, while Triton backward is
generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
than CUDA forward + backward.
- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
- Triton version supports attention bias, while CUDA version doesn't.
"""
import math
@@ -368,8 +368,8 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
# key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='random')
# key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True