Commit a5a8806d authored by Tri Dao

Split bwd on the seqlen_q dimension

parent 871db479
......@@ -241,7 +241,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
int blocksize_c = (head_size == 128 && (!is_sm80)) ? 128 : 256;
int blocksize_c = head_size == 128 ? 128 : 256;
// Need to round max_seqlen_k to multiples of blocksize_c
int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
if( max_seqlen_k_ <= 128 ) {
......@@ -332,6 +332,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const int num_splits,
c10::optional<at::Generator> gen_
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
......@@ -447,7 +448,22 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
p_dropout,
softmax_scale,
is_causal,
/*num_splits=*/1);
num_splits);
launch(params, stream, /*configure=*/true);
at::Tensor dk_accum, dv_accum;
if (params.num_splits > 1) {
// dk_accum = torch::zeros({total_k, num_heads, head_size}, opts.dtype(at::kFloat));
// dv_accum = torch::zeros({total_k, num_heads, head_size}, opts.dtype(at::kFloat));
// params.dk_accum_ptr = dk_accum.data_ptr();
// params.dv_accum_ptr = dv_accum.data_ptr();
dk.zero_();
dv.zero_();
} else {
// params.dk_accum_ptr = nullptr;
// params.dv_accum_ptr = nullptr;
}
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
......@@ -461,7 +477,12 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
params.philox_args = gen->philox_cuda_state(counter_offset);
}
launch(params, stream);
launch(params, stream, /*configure=*/false);
// if (params.num_splits > 1) {
// dk.copy_(dk_accum);
// dv.copy_(dv_accum);
// }
return { dq, dk, dv, softmax_d };
}
......
......@@ -140,6 +140,10 @@ struct FMHA_dgrad_params : public FMHA_fprop_params {
void *__restrict__ dk_ptr;
void *__restrict__ dv_ptr;
// // To accumulate dK and dV in case we're splitting the bwd along the seqlen_q dimension
// void *__restrict__ dk_accum_ptr;
// void *__restrict__ dv_accum_ptr;
// The stride between rows of the dQ, dK and dV matrices.
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
......@@ -193,7 +197,7 @@ struct Launch_params{
void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params);
void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
......
......@@ -28,6 +28,9 @@
#pragma once
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <fmha/utils.h>
namespace fmha {
......@@ -41,7 +44,8 @@ template<
// The number of rows of Q, K or V loaded by this tile.
int ROWS_,
// The number of columns.
int COLS
int COLS,
int BYTES_PER_LDGS_ = 16
>
struct Gmem_tile_qkv {
......@@ -49,7 +53,7 @@ struct Gmem_tile_qkv {
static constexpr int BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8;
// The size of each LDG.
static constexpr int BYTES_PER_LDG = 16;
static constexpr int BYTES_PER_LDG = BYTES_PER_LDGS_;
// The size of a row in bytes.
static constexpr int BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8;
......@@ -130,6 +134,42 @@ struct Gmem_tile_qkv {
}
}
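// Atomically accumulate the register tile `data` into global memory, two elements at a
// time via atomicAdd on __half2 / __nv_bfloat162. Used when gridDim.z > 1, so several
// blocks can add their partial dK / dV contributions into the same tile.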
template <typename elem_type>
inline __device__ void atomic_add(const uint4 (&data)[LDGS]) {
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
using elem2_type = typename std::conditional<std::is_same<elem_type, __half>::value, __half2, __nv_bfloat162>::type;
// char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
elem2_type *ptr_ = reinterpret_cast<elem2_type *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
#pragma unroll
for (int jj = 0; jj < 4; ++jj) {
atomicAdd(ptr_ + jj, reinterpret_cast<const elem2_type(&)[4]>(data[ii])[jj]);
}
}
}
}
// Not being used. This only supports converting from fp16 -> fp32 for now (not bf16 -> fp32).
inline __device__ void atomic_add_float(const uint4 (&data)[LDGS]) {
static_assert(BYTES_PER_ELEMENT == 4); // Only support fp32
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
// char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
float *ptr_ = reinterpret_cast<float *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
#pragma unroll
for (int jj = 0; jj < 4; ++jj) {
const float2 data_f = fmha::half2_unpack<__half>(reinterpret_cast<const uint32_t(&)[4]>(data[ii])[jj]);
atomicAdd(ptr_ + jj * 2, data_f.x);
atomicAdd(ptr_ + jj * 2 + 1, data_f.y);
}
}
}
}
inline __device__ void move(const int steps = 1) {
// ptr += (int64_t)ROWS * row_stride_in_bytes * steps;
ptr += (uint32_t)ROWS * row_stride_in_bytes * steps;
......
......@@ -76,6 +76,12 @@ struct FMHA_kernel_traits {
using Gmem_tile_do = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
// // The global memory tile to store the accumulated dK and dV
// // Hack: we set BYTES_PER_LDGS=32 to emulate the access pattern of dK and dV
// // where there are 16 bits per element and 16 bytes per load. In reality we won't
// // be issuing any load or store of size 32 bytes.
// using Gmem_tile_dkv_accum = fmha::Gmem_tile_qkv<Cta_tile_o, 32, S, D, 32>;
// The global memory tile to store the softmax sum.
using Gmem_softmax_sum = fmha::Gmem_summary_stats<Cta_tile_p>;
......
......@@ -6,13 +6,45 @@
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_loop.h"
// Find the number of splits that maximizes the occupancy. For example, if we have
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
// splits as that would incur more HBM reads/writes.
// Moreover, more than 1 split incurs the extra cost of zeroing out dk/dv and doing atomic adds
// instead of just writing.
// So for num_splits > 1, we divide the efficiency by some factor (e.g. 1.25, depending on seqlen)
// to account for this. In addition, more splits mean slower atomic adds.
int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits,
int seqlen, bool is_causal) {
float max_efficiency = 0.f;
int best_num_splits = 1;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
float discount_factor = 1.f + 512.0 / seqlen; // 1.25 for seqlen 2k, 1.125 for 4k.
discount_factor *= is_causal ? 1.1 : 1.f; // causal makes it even slower.
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
float eff_raw = n_waves / ceil(n_waves);
// Heuristic: each additional split beyond 2 costs roughly a 7% slowdown, up to about 8 splits.
float eff = num_splits == 1 ? eff_raw : (eff_raw - 0.07 * std::min(num_splits - 2, 6)) / discount_factor;
// printf("num_splits = %d, eff_raw = %f, eff = %f\n", num_splits, eff_raw, eff);
if (eff > max_efficiency) {
max_efficiency = eff;
best_num_splits = num_splits;
}
efficiency.push_back(eff);
}
// printf("num_splits chosen = %d\n", best_num_splits);
return best_num_splits;
}
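For reference, the efficiency numbers quoted in the comment above can be reproduced with a small standalone sketch. This is not part of the commit; it assumes ctas_per_sm = 1 and ignores the discount factor.
#include <cmath>
#include <cstdio>
// Minimal sketch of the wave-efficiency arithmetic in num_splits_heuristic_bwd,
// for the example from the comment: batch * n_heads = 48 on 108 SMs.
int main() {
    const int batch_nheads = 48, num_SMs = 108, ctas_per_sm = 1;
    for (int num_splits = 1; num_splits <= 4; ++num_splits) {
        float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
        float eff = n_waves / std::ceil(n_waves);
        // Prints eff = 0.44, 0.89, 0.67, 0.89 for num_splits = 1, 2, 3, 4.
        std::printf("num_splits = %d, efficiency = %.2f\n", num_splits, eff);
    }
    return 0;
}
With the discount factor and the per-split penalty applied, 2 splits win over 3 here, matching the comment.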
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
__global__ void fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
fmha::compute_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
}
template<typename Kernel_traits>
void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_t stream) {
void run_fmha_dgrad_fp16_sm80_loop_(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
......@@ -46,41 +78,58 @@ void run_fmha_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params &params, cudaStream_
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
dim3 grid(params.b, params.h);
// Automatically set num_splits to maximize occupancy
if (params.num_splits <= 0) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size_dq_dk_dv);
auto dprops = at::cuda::getCurrentDeviceProperties();
// printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount);
constexpr int M = Kernel_traits::Cta_tile_p::M;
// We don't want more than 10 splits due to numerical error.
// Numerical error on dk/dv scales as sqrt(num_splits).
params.num_splits = num_splits_heuristic_bwd(
params.b * params.h, dprops->multiProcessorCount,
ctas_per_sm, /*max_splits=*/std::min(10, (params.seqlen_q + M - 1) / M),
params.seqlen_k, params.is_causal
);
}
if (configure) return;
dim3 grid(params.b, params.h, params.num_splits);
kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
});
}
void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream) {
void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
// workaround for MSVC issue
FP16_SWITCH(params.is_bf16, [&] {
auto dprops = at::cuda::getCurrentDeviceProperties();
if (params.d == 16) {
if( params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
} else if( params.seqlen_k == 256 ) {
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
} else {
// TD [2022-05-15] 512 gives wrong results rn
// using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 8, 0x08u, elem_type>;
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
}
} else if (params.d == 32) {
if( params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
} else if( params.seqlen_k >= 256 ) {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
}
} else if (params.d == 64) {
if( params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
} else if( params.seqlen_k >= 256 ) {
if (dprops->major == 8 && dprops->minor == 0) {
// Don't share smem for K & V, and don't keep V in registers
......@@ -88,45 +137,18 @@ void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stre
// uses more shared memory, which is fine on A100 but not other GPUs.
// For other GPUs, we keep V in registers.
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
} else if (dprops->major == 8 && dprops->minor > 0) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
} else if (dprops->major == 7 && dprops->minor == 5) {
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
}
}
} else if (params.d == 128) {
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream, configure);
}
// if (params.d == 64) {
// if (dprops->major == 7 && dprops->minor == 5) {
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
// run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
// } else {
// if( params.seqlen_k == 128 ) {
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
// run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
// } else if( params.seqlen_k >= 256 ) {
// if (dprops->major == 8 && dprops->minor == 0) {
// // Don't share smem for K & V, and don't keep V in registers
// // This speeds things up by 2-3% by avoiding register spills, but it
// // uses more shared memory, which is fine on A100 but not other GPUs.
// // For other GPUs, we keep V in registers.
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
// run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
// } else if (dprops->major == 8 && dprops->minor > 0) {
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
// run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
// }
// }
// }
// }
// if (params.d == 128) {
// using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
// run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
// }
});
}
\ No newline at end of file
......@@ -135,6 +135,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// The thread index.
const int tidx = threadIdx.x;
// How many steps to jump per iteration, which is the same as params.num_splits.
const int step_stride = gridDim.z;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
// if( binfo.stop_early() ) return;
if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
......@@ -184,18 +187,24 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
const int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
// We want begin to be a multiple of gridDim.z
// This is because the row indices processed by each threadblock must align between the
// loop steps, otherwise we have a dependency between the blocks.
// For example, the threadblock with blockIdx.z == 1 must process row indices that are
// k * gridDim.z + 1 for integer k.
const int begin_mod_z = begin % gridDim.z;
begin = begin_mod_z <= blockIdx.z ? begin - begin_mod_z : begin + gridDim.z - begin_mod_z;
const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M - begin;
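A small host-side sketch (illustrative only; begin = 5 and gridDim.z = 3 are made-up values, not from the commit) shows what this rounding does: each block's first step stays at or above the causal begin, and every step it touches is congruent to its blockIdx.z modulo gridDim.z.
#include <cstdio>
int main() {
    const int begin_orig = 5;  // assumed causal starting step (loop_step_idx * N / M)
    const int grid_z = 3;      // assumed gridDim.z, i.e. params.num_splits
    for (int z = 0; z < grid_z; ++z) {
        const int begin_mod_z = begin_orig % grid_z;
        const int begin = begin_mod_z <= z ? begin_orig - begin_mod_z
                                           : begin_orig + grid_z - begin_mod_z;
        // Block z then processes steps begin + z, begin + z + grid_z, ...
        std::printf("blockIdx.z = %d: first step = %d (mod %d = %d)\n",
                    z, begin + z, grid_z, (begin + z) % grid_z);
    }
    return 0;
}
With these values the three blocks start at steps 6, 7, and 5 respectively, so together they cover every step from the original begin onward without overlapping.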
// Wind gmem tiles to the correct position.
gmem_q.move(begin);
gmem_do.move(begin);
gmem_o.move(begin);
gmem_dq.move(begin);
gmem_dq_tmp.move(begin);
gmem_q.move(begin + blockIdx.z);
gmem_do.move(begin + blockIdx.z);
gmem_o.move(begin + blockIdx.z);
gmem_dq.move(begin + blockIdx.z);
gmem_dq_tmp.move(begin + blockIdx.z);
// TODO: need to move gmem_s if we want the intermediate result for debugging
gmem_softmax_lse.move(begin);
gmem_softmax_d.move(begin);
gmem_softmax_lse.move(begin + blockIdx.z);
gmem_softmax_d.move(begin + blockIdx.z);
if (!Is_first) {
gmem_k.move(loop_step_idx);
......@@ -215,7 +224,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
float p_lse[Mma_tile_p::MMAS_M * 2];
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
gmem_softmax_lse.move();
if (!Is_first) { __syncthreads(); }
// Commit the data for Q, dO, and V to shared memory.
......@@ -265,7 +273,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
float dp_sum[Mma_tile_p::MMAS_M * 2];
gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
gmem_softmax_d.move();
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
......@@ -301,9 +308,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dk);
// Load over the entire sequence length.
for( int l = 0; l < steps; l++ ) {
const int loop = (begin + l) * Cta_tile_p::M;
if( loop >= binfo.actual_seqlen_q )
for (int l = blockIdx.z; l < steps; l += step_stride) {
if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q)
break;
// Load the fragments for V.
......@@ -352,9 +358,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
smem_s.store(frag_p);
// Trigger the load for the next Q values.
if( l < steps - 1) {
if (l + step_stride < steps) {
gemm_q_k.smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.move(step_stride);
gmem_q.load();
}
......@@ -427,12 +433,12 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
smem_kt.load(frag_kt[0], 0);
// Trigger the load for the next dO values.
if( l < steps - 1) {
if (l + step_stride < steps) {
smem_do.move_to_next_write_buffer();
gmem_do.move();
gmem_do.move(step_stride);
gmem_do.load();
if (Is_first) {
gmem_o.move();
gmem_o.move(step_stride);
gmem_o.load();
}
}
......@@ -443,7 +449,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
smem_dp.store(frag_p);
// gmem_s.store(frag_p, mask);
// gmem_s.move();
// gmem_s.move(step_stride);
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N];
......@@ -520,7 +526,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// }
// __syncthreads();
// Commit the values for Q and dO into shared memory.
if(l < steps - 1) {
if (l + step_stride < steps) {
gmem_q.commit(gemm_q_k.smem_q);
}
......@@ -529,15 +535,16 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// __syncthreads();
// Commit the values for Q and dO into shared memory.
if(l < steps - 1) {
if (l + step_stride < steps) {
gmem_do.commit(smem_do);
gmem_softmax_d.move(step_stride);
if (Is_first) {
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
);
}
gmem_softmax_lse.move(step_stride);
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
gmem_softmax_lse.move();
}
typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
......@@ -567,9 +574,8 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// Make sure dQ is in shared memory.
__syncthreads();
if (l < steps - 1) {
if (l + step_stride < steps) {
gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
gmem_softmax_d.move();
}
// Load from shared memory.
......@@ -590,20 +596,20 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// Output the values.
gmem_dq.template store<elem_type>(dq_out, 0);
// Move to the next part of the output.
gmem_dq.move();
gmem_dq.move(step_stride);
} else {
// Output the values.
gmem_dq_tmp.store(dq_out, 0);
}
// Move to the next part of the output.
if (!(Is_first && Is_last)) { gmem_dq_tmp.move(); }
if (!(Is_first && Is_last)) { gmem_dq_tmp.move(step_stride); }
// // Make sure the data is in shared memory.
// __syncthreads();
// Commit the values for Q and dO into shared memory.
if(l < steps - 1) {
if (l + step_stride < steps) {
gemm_q_k.smem_q.move_to_next_read_buffer();
gemm_q_k.reload_q();
smem_qt.move_to_next_read_buffer();
......@@ -652,18 +658,34 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
uint4 dv_out[Smem_tile_dv::NUM_LDS];
smem_dv.load(dv_out);
Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts, binfo, tidx, false);
// using Gmem_tile_dkv_accum = typename Kernel_traits::Gmem_tile_dkv_accum;
// Gmem_tile_dkv_accum gmem_dv_accum(params.dv_accum_ptr, params.h * params.d, params.d, binfo, tidx, false);
// static_assert(Gmem_tile_dkv_accum::LDGS == Smem_tile_dv::NUM_LDS);
if (!Is_first) {
gmem_dv.move(loop_step_idx);
// gmem_dv_accum.move(loop_step_idx);
}
if (gridDim.z == 1) {
gmem_dv.store(dv_out);
} else {
gmem_dv.template atomic_add<elem_type>(dv_out);
// gmem_dv_accum.atomic_add_float(dv_out);
}
gmem_dv.store(dv_out);
uint4 dk_out[Smem_tile_dk::NUM_LDS];
smem_dk.load(dk_out);
Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, binfo, tidx, false);
// Gmem_tile_dkv_accum gmem_dk_accum(params.dk_accum_ptr, params.h * params.d, params.d, binfo, tidx, false);
if (!Is_first) {
gmem_dk.move(loop_step_idx);
// gmem_dk_accum.move(loop_step_idx);
}
if (gridDim.z == 1) {
gmem_dk.store(dk_out);
} else {
gmem_dk.template atomic_add<elem_type>(dk_out);
// gmem_dk_accum.atomic_add_float(dk_out);
}
gmem_dk.store(dk_out);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
......
......@@ -162,40 +162,5 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params) {
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
// }
// if (launch_params.params.d == 64) {
// if( launch_params.params.seqlen_k == 128 ) {
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
// } else if( launch_params.params.seqlen_k >= 256 ) {
// if (dprops->major == 8 && dprops->minor >= 0) {
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
// } else if (dprops->major == 7 && dprops->minor == 5) {
// if (launch_params.is_dropout) { // Need to use the same block size as backward
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
// } else {
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
// }
// }
// }
// }
// if (launch_params.params.d == 128) {
// if( launch_params.params.seqlen_k == 128 ) {
// using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
// } else {
// if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
// // TD [2022-06-05] Keep K in registers to reduce register spilling
// // Gives about 6% speedup compared to using block size 128.
// using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
// } else { // Need to use the same block size as backward
// using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
// }
// }
// }
});
}
\ No newline at end of file
......@@ -7,10 +7,7 @@ import flash_attn_cuda
def _get_block_size(device, head_dim, is_dropout):
assert head_dim in [16, 32, 64, 128]
if head_dim in [16, 32, 64]:
return 256
elif head_dim == 128:
return 256 if (torch.cuda.get_device_capability(device) == (8, 0)) else 128
return 256 if head_dim in [16, 32, 64] else 128
def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
......@@ -32,11 +29,17 @@ def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, num_splits=0,
generator=None):
"""
num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
it will be set by an internal heuristic. Setting this too large (e.g. > 10) could make
the numerical error of dK and dV larger (it scales as sqrt(num_splits)).
This hyperparameter can be tuned for performance, but the default (heuristic) should work fine.
"""
softmax_d = flash_attn_cuda.bwd(
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, generator)
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, num_splits, generator)
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
return dq, dk, dv, softmax_d
......
......@@ -356,7 +356,8 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
# rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 32
# Set a smaller batch size so that it triggers num_splits > 1
batch_size = 8
nheads = 4
x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
......@@ -418,10 +419,11 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
if dropout_p == 0.0:
assert dropout_mask.all()
else:
assert 0.99 <= dropout_fraction / dropout_p <= 1.01
assert 0.98 <= dropout_fraction / dropout_p <= 1.02
if is_sm80 or d < 128: # Only run backward for d=128 on A100
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
# Error for dK and dV could be a bit higher if we're splitting along the seqlen_q dimension
assert (dqkv - dqkv_ref).abs().max().item() <= 4 * (dqkv_pt - dqkv_ref).abs().max().item()
# assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol)
......