Commit a44f48df authored by Tri Dao

Split fwd on the seqlen_q dimension

parent 1aa6d7d9
......@@ -54,7 +54,8 @@ void set_params_fprop(FMHA_fprop_params &params,
void *softmax_lse_d,
float p_dropout,
float softmax_scale,
bool is_causal) {
bool is_causal,
int num_splits) {
Data_type acc_type = DATA_TYPE_FP32;
Data_type data_type = !(q.dtype() == torch::kBFloat16) ? DATA_TYPE_FP16 : DATA_TYPE_BF16;
......@@ -117,6 +118,7 @@ void set_params_fprop(FMHA_fprop_params &params,
set_alpha(params.scale_dropout, params.rp_dropout, data_type);
params.is_causal = is_causal;
params.num_splits = num_splits;
}
void set_params_dgrad(FMHA_dgrad_params &params,
......@@ -142,7 +144,8 @@ void set_params_dgrad(FMHA_dgrad_params &params,
void *dsoftmax_sum_d,
float p_dropout,
float softmax_scale,
bool is_causal) {
bool is_causal,
int num_splits) {
set_params_fprop(params,
b, seqlen_q, seqlen_k, h, d,
......@@ -154,7 +157,8 @@ void set_params_dgrad(FMHA_dgrad_params &params,
softmax_lse_d,
p_dropout,
softmax_scale,
is_causal);
is_causal,
num_splits);
// Set the pointers and strides.
params.dq_ptr = dq.data_ptr();
......@@ -186,6 +190,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const bool zero_tensors,
const bool is_causal,
const bool return_softmax,
const int num_splits,
c10::optional<at::Generator> gen_) {
auto dprops = at::cuda::getCurrentDeviceProperties();
......@@ -286,12 +291,14 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
softmax_lse.data_ptr(),
p_dropout,
softmax_scale,
is_causal);
is_causal,
num_splits);
run_fmha_fp16_sm80(launch_params, /*configure=*/ true);
// number of times random will be generated per thread, to offset philox counter in thc random
// state
int64_t counter_offset = launch_params.elts_per_thread;
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
int64_t counter_offset = launch_params.params.b * launch_params.params.h * 32;
at::PhiloxCudaState rng_engine_inputs;
if( is_dropout ) {
......@@ -440,7 +447,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
softmax_d.data_ptr(),
p_dropout,
softmax_scale,
is_causal);
is_causal,
/*num_splits=*/1);
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
......@@ -560,7 +568,8 @@ mha_fwd_block(const at::Tensor &q, // total_q x num_heads x head_size, t
softmax_lse.data_ptr(),
p_dropout,
softmax_scale,
is_causal);
is_causal,
/*num_splits=*/1);
launch_params.params.blockmask = static_cast<int *>(blockmask.data_ptr());
run_fmha_block_fp16_sm80(launch_params, /*configure=*/ true);
......@@ -706,7 +715,8 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size
softmax_d.data_ptr(),
p_dropout,
softmax_scale,
is_causal);
is_causal,
/*num_splits=*/1);
params.blockmask = static_cast<int *>(blockmask.data_ptr());
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
......
......@@ -127,6 +127,8 @@ struct FMHA_fprop_params : public Qkv_params {
bool is_bf16;
bool is_causal;
int num_splits; // How many SMs per attention matrix.
};
////////////////////////////////////////////////////////////////////////////////////////////////////
......
......@@ -33,6 +33,32 @@
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
// Find the number of splits that maximizes the occupancy. For example, if we have
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
// splits as that would incur more HBM reads/writes.
// So we find the best efficiency, then find the smallest number of splits that gets 95%
// of the best efficiency.
int num_splits_heuristic_fwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) {
float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
float eff = n_waves / ceil(n_waves);
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
if (eff > max_efficiency) { max_efficiency = eff; }
efficiency.push_back(eff);
}
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
if (efficiency[num_splits - 1] > 0.95 * max_efficiency) {
// printf("num_splits chosen = %d\n", num_splits);
return num_splits;
}
}
return 1;
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
__global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
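
Worked example (not part of the diff): the standalone host-only sketch below reproduces the batch * n_heads = 48, 108-SM case from the comment above, assuming ctas_per_sm = 1 and max_splits = 30. The function body is copied from the hunk; the main() driver is added purely for illustration.

#include <cmath>
#include <cstdio>
#include <vector>

int num_splits_heuristic_fwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) {
    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        // Fraction of the last wave of CTAs that actually has work.
        float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
        float eff = n_waves / std::ceil(n_waves);
        if (eff > max_efficiency) { max_efficiency = eff; }
        efficiency.push_back(eff);
    }
    // Smallest split count that gets within 95% of the best efficiency.
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (efficiency[num_splits - 1] > 0.95 * max_efficiency) { return num_splits; }
    }
    return 1;
}

int main() {
    // batch * n_heads = 48 on 108 SMs: 2 splits -> 96/108 = 0.89, 3 splits -> 144 CTAs over
    // 2 waves = 0.67. With these exact inputs 9 splits fill 4 waves exactly (efficiency 1.0),
    // which is what the heuristic returns here since nothing smaller clears the 95% bar.
    printf("chosen num_splits = %d\n",
           num_splits_heuristic_fwd(48, 108, /*ctas_per_sm=*/1, /*max_splits=*/30));
    return 0;
}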
......@@ -75,7 +101,21 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(launch_params.params.b, launch_params.params.h);
// Automatically set num_splits to maximize occupancy
if (launch_params.params.num_splits <= 0) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size);
auto dprops = at::cuda::getCurrentDeviceProperties();
// printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount);
constexpr int M = Kernel_traits::Cta_tile_p::M;
launch_params.params.num_splits = num_splits_heuristic_fwd(
launch_params.params.b * launch_params.params.h, dprops->multiProcessorCount,
ctas_per_sm,
/*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1) / M)
);
}
dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
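
Illustrative CUDA sketch (not the actual FMHA kernel; toy_split_kernel and its buffer layout are made up): it mirrors the launch shape above, where the grid gains a z dimension of size num_splits and each z-block walks the Q row tiles of one (batch, head) pair with stride gridDim.z, and it shows the cudaOccupancyMaxActiveBlocksPerMultiprocessor query that feeds the heuristic.

#include <cuda_runtime.h>
#include <cstdio>

__global__ void toy_split_kernel(float *out, int steps) {
    // blockIdx.x = batch index, blockIdx.y = head index, blockIdx.z = split index.
    for (int l = blockIdx.z; l < steps; l += gridDim.z) {
        // Each iteration stands in for one M-row tile of Q / O.
        if (threadIdx.x == 0) {
            out[(blockIdx.x * gridDim.y + blockIdx.y) * steps + l] = float(blockIdx.z);
        }
    }
}

int main() {
    const int batch = 2, nheads = 3, steps = 8, num_splits = 4;
    float *out;
    cudaMallocManaged(&out, batch * nheads * steps * sizeof(float));

    // Same occupancy query the launcher above uses as input to the num_splits heuristic.
    int ctas_per_sm = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, toy_split_kernel,
                                                  /*blockSize=*/128, /*dynamicSMemSize=*/0);
    printf("ctas_per_sm = %d\n", ctas_per_sm);

    dim3 grid(batch, nheads, num_splits);  // before this commit the grid was just (b, h)
    toy_split_kernel<<<grid, 128>>>(out, steps);
    cudaDeviceSynchronize();

    // Row tile l is owned by split l % num_splits, e.g. tile 5 by blockIdx.z == 1.
    printf("tile 5 of (b=0, h=0) was written by split %.0f\n", out[5]);
    cudaFree(out);
    return 0;
}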
......@@ -103,10 +143,7 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
if( launch_params.params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else if( launch_params.params.seqlen_k == 256 ) {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
} else {
} else if( launch_params.params.seqlen_k >= 256 ) {
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
}
......
......@@ -197,7 +197,7 @@ constexpr size_t get_dynamic_smem_size(){
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int begin, int steps, Prng &ph, const int loop_step_idx) {
inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int steps, int step_stride, Prng &ph, const int loop_step_idx) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using elem_type = typename Kernel_traits::elem_type;
......@@ -266,15 +266,23 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Wind gmem tiles to the correct position.
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
const int begin_og = begin;
begin = Is_causal ? std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin;
int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
// We want begin to be a multiple of gridDim.z
// This is because the row indices processed by each threadblock must align between the
// loop steps, otherwise we have a dependency between the blocks.
// For example, threadblock with blockIdx.z == 1 must process row indices that are
// k * gridDim.z + 1 for integer k.
const int begin_mod_z = begin % gridDim.z;
begin = begin_mod_z <= blockIdx.z ? begin - begin_mod_z : begin + gridDim.z - begin_mod_z;
const int steps_og = steps;
steps -= begin - begin_og;
gmem_q.move(begin);
gmem_o.move(begin);
gmem_o_tmp.move(begin);
if (Return_softmax) { gmem_s.move(begin); }
gmem_softmax_lse.move(begin);
steps -= begin;
gmem_q.move(begin + blockIdx.z);
gmem_o.move(begin + blockIdx.z);
gmem_o_tmp.move(begin + blockIdx.z);
if (Return_softmax) {
gmem_s.move(begin + blockIdx.z);
}
gmem_softmax_lse.move(begin + blockIdx.z);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("begin = %d, steps = %d\n", begin, steps);
// }
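
Worked example (not part of the diff) of the begin-alignment rule above: with gridDim.z splits, block z must own exactly the Q row tiles t with t % gridDim.z == z, regardless of which K chunk (and hence which causal begin) it is processing. Each block therefore snaps begin to a multiple of gridDim.z, rounding down when that still keeps its first tile at or past the causal bound, and up otherwise.

#include <cstdio>

int main() {
    const int grid_z = 4;        // gridDim.z (number of splits)
    const int begin_causal = 6;  // loop_step_idx * Cta_tile_p::N / Cta_tile_p::M for some K chunk

    for (int z = 0; z < grid_z; z++) {           // blockIdx.z
        int begin = begin_causal;
        const int begin_mod_z = begin % grid_z;  // 6 % 4 == 2
        begin = begin_mod_z <= z ? begin - begin_mod_z : begin + grid_z - begin_mod_z;
        // First row tile this block touches: always >= 6 and congruent to z mod 4.
        printf("blockIdx.z = %d: begin = %d, first tile = %d\n", z, begin, begin + z);
    }
    return 0;
}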
......@@ -362,8 +370,11 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
Smem_softmax_sum smem_softmax_lse(reinterpret_cast<float *>(&smem_[Gemm1::SMEM_BYTES]), tidx);
// Load over the entire sequence length.
for( int l = 0; l < steps; l++ ) {
if((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q) break;
for (int l = blockIdx.z; l < steps; l += step_stride) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z <= 1)) {
// printf("l = %d\n", l);
// }
if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q) break;
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
......@@ -380,9 +391,9 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
if (!Is_first) { gmem_o_tmp.load(out, 0); }
// Trigger the load for the next Q values.
if( l < steps - 1) {
if (l + step_stride < steps) {
gemm_q_k.smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.move(step_stride);
gmem_q.load();
}
......@@ -395,27 +406,28 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Apply the mask.
softmax.apply_mask(mask);
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l < step_stride ) {
// if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
__syncthreads();
}
// if (!Is_first) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l >= 0)) {
// printf("p_prev_lse=%.6f, %.6f\n", p_prev_lse[0], p_prev_lse[1]);
// }
// }
// Compute the max.
float p_max[Mma_tile_p::MMAS_M * 2];
if (!Is_first) {
smem_softmax_lse.store_pair(p_prev_lse, l % 2);
smem_softmax_lse.store_pair(p_prev_lse);
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi]; }
for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; }
}
// Trigger the load for the next LSE values.
if( l < steps - 1) {
if (l + step_stride < steps) {
if (!Is_first) {
gmem_softmax_lse.load_next(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse));
gmem_softmax_lse.load_next(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse),
step_stride);
}
}
......@@ -490,11 +502,11 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
softmax.template pack<elem_type>(frag_p);
if (Return_softmax) {
gmem_s.store(frag_p, mask);
gmem_s.move();
gmem_s.move(step_stride);
}
// Commit the values for Q into shared memory.
if(l < steps - 1) {
if (l + step_stride < steps) {
gmem_q.commit(gemm_q_k.smem_q);
}
......@@ -548,7 +560,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
}
float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP];
if ((!Is_first) && o_rows_are_valid) {
smem_softmax_lse.load(p_prev_scale_o, rows, l % 2);
smem_softmax_lse.load(p_prev_scale_o, rows);
}
// if (!Is_first) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
......@@ -594,7 +606,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M]>(p_sum_log[jj]), rows[jj]);
}
}
gmem_softmax_lse.move();
gmem_softmax_lse.move(step_stride);
// Load from shared memory.
if (!Is_first) {
......@@ -627,22 +639,21 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Output the values.
if (is_final_write) {
gmem_o.template store<elem_type>(out, 0);
gmem_o.move();
gmem_o.move(step_stride);
} else {
gmem_o_tmp.store(out, 0);
}
// Move to the next part of the output.
if (!(Is_first && Is_last)) { gmem_o_tmp.move(); }
if (!(Is_first && Is_last)) { gmem_o_tmp.move(step_stride); }
gemm_q_k.reload_k();
// Make sure we are reading from the correct buffer.
gemm_q_k.smem_q.move_to_next_read_buffer();
// Trigger the load from shared memory for the next series of Q values.
if(l < steps - 1) {
if (l + step_stride < steps) {
gemm_q_k.reload_q();
}
} // Outer loop over the sequence length.
}
......@@ -672,14 +683,14 @@ inline __device__ void device_1xN_loop(const Params &params) {
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
if (params.seqlen_k == blocksize_c) {
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, 0, STEPS, ph, 0);
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, STEPS, gridDim.z, ph, 0);
} else {
const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, 0, STEPS, ph, 0);
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, STEPS, gridDim.z, ph, 0);
for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, 0, STEPS, ph, loop_step_idx);
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, STEPS, gridDim.z, ph, loop_step_idx);
}
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, 0, STEPS, ph, max_loop_steps - 1);
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, STEPS, gridDim.z, ph, max_loop_steps - 1);
}
}
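
Host-side sketch (not part of the diff; an illustration of the dispatch above): every split still visits every K chunk in order, carrying the Is_first/Is_last roles, and within each chunk it only processes the Q row tiles assigned to its blockIdx.z.

#include <cstdio>

int main() {
    const int num_splits = 2;      // gridDim.z
    const int steps = 4;           // Q row tiles (STEPS) per (batch, head)
    const int max_loop_steps = 3;  // ceil(seqlen_k / blocksize_c)

    for (int split = 0; split < num_splits; split++) {          // one CTA per split in reality
        for (int loop_step_idx = 0; loop_step_idx < max_loop_steps; loop_step_idx++) {
            const bool is_first = loop_step_idx == 0;                  // no previous partial O / LSE to rescale
            const bool is_last = loop_step_idx == max_loop_steps - 1;  // final write of O and softmax LSE
            for (int l = split; l < steps; l += num_splits) {          // same stride-gridDim.z walk as device_1xN_
                printf("split %d: K chunk %d, Q tile %d (Is_first=%d, Is_last=%d)\n",
                       split, loop_step_idx, l, is_first, is_last);
            }
        }
    }
    return 0;
}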
......
......@@ -14,10 +14,16 @@ def _get_block_size(device, head_dim, is_dropout):
def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal, return_softmax, generator=None):
dropout_p, softmax_scale, causal, return_softmax, num_splits=0,
generator=None):
"""
num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking.
Don't change it unless you know what you're doing.
"""
softmax_lse, *rest = flash_attn_cuda.fwd(
q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, False, causal, return_softmax, generator
softmax_scale, False, causal, return_softmax, num_splits, generator
)
# if out.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
......