"examples/vscode:/vscode.git/clone" did not exist on "83b112a145709a9dd06b9c1f2db69d95e4c7a3b9"
Commit a44f48df authored by Tri Dao

Split fwd on the seqlen_q dimension

parent 1aa6d7d9
@@ -54,7 +54,8 @@ void set_params_fprop(FMHA_fprop_params &params,
        void *softmax_lse_d,
        float p_dropout,
        float softmax_scale,
-       bool is_causal) {
+       bool is_causal,
+       int num_splits) {
     Data_type acc_type = DATA_TYPE_FP32;
     Data_type data_type = !(q.dtype() == torch::kBFloat16) ? DATA_TYPE_FP16 : DATA_TYPE_BF16;
@@ -117,6 +118,7 @@ void set_params_fprop(FMHA_fprop_params &params,
     set_alpha(params.scale_dropout, params.rp_dropout, data_type);
     params.is_causal = is_causal;
+    params.num_splits = num_splits;
 }
 void set_params_dgrad(FMHA_dgrad_params &params,
@@ -142,7 +144,8 @@ void set_params_dgrad(FMHA_dgrad_params &params,
        void *dsoftmax_sum_d,
        float p_dropout,
        float softmax_scale,
-       bool is_causal) {
+       bool is_causal,
+       int num_splits) {
     set_params_fprop(params,
        b, seqlen_q, seqlen_k, h, d,
@@ -154,7 +157,8 @@ void set_params_dgrad(FMHA_dgrad_params &params,
        softmax_lse_d,
        p_dropout,
        softmax_scale,
-       is_causal);
+       is_causal,
+       num_splits);
     // Set the pointers and strides.
     params.dq_ptr = dq.data_ptr();
@@ -186,6 +190,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
        const bool zero_tensors,
        const bool is_causal,
        const bool return_softmax,
+       const int num_splits,
        c10::optional<at::Generator> gen_) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -286,12 +291,14 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
        softmax_lse.data_ptr(),
        p_dropout,
        softmax_scale,
-       is_causal);
+       is_causal,
+       num_splits);
     run_fmha_fp16_sm80(launch_params, /*configure=*/ true);
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
-    int64_t counter_offset = launch_params.elts_per_thread;
+    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+    int64_t counter_offset = launch_params.params.b * launch_params.params.h * 32;
     at::PhiloxCudaState rng_engine_inputs;
     if( is_dropout ) {
@@ -440,7 +447,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
        softmax_d.data_ptr(),
        p_dropout,
        softmax_scale,
-       is_causal);
+       is_causal,
+       /*num_splits=*/1);
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());
@@ -560,7 +568,8 @@ mha_fwd_block(const at::Tensor &q, // total_q x num_heads x head_size, t
        softmax_lse.data_ptr(),
        p_dropout,
        softmax_scale,
-       is_causal);
+       is_causal,
+       /*num_splits=*/1);
     launch_params.params.blockmask = static_cast<int *>(blockmask.data_ptr());
     run_fmha_block_fp16_sm80(launch_params, /*configure=*/ true);
@@ -706,7 +715,8 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size
        softmax_d.data_ptr(),
        p_dropout,
        softmax_scale,
-       is_causal);
+       is_causal,
+       /*num_splits=*/1);
     params.blockmask = static_cast<int *>(blockmask.data_ptr());
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
...
@@ -127,6 +127,8 @@ struct FMHA_fprop_params : public Qkv_params {
     bool is_bf16;
     bool is_causal;
+
+    int num_splits;  // How many SMs per attention matrix.
 };
 ////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -33,6 +33,32 @@
 #include "fmha.h"
 #include "fmha_fprop_kernel_1xN.h"
+
+// Find the number of splits that maximizes the occupancy. For example, if we have
+// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
+// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
+// splits as that would incur more HBM reads/writes.
+// So we find the best efficiency, then find the smallest number of splits that gets 95%
+// of the best efficiency.
+int num_splits_heuristic_fwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) {
+    float max_efficiency = 0.f;
+    std::vector<float> efficiency;
+    efficiency.reserve(max_splits);
+    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+        float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
+        float eff = n_waves / ceil(n_waves);
+        // printf("num_splits = %d, eff = %f\n", num_splits, eff);
+        if (eff > max_efficiency) { max_efficiency = eff; }
+        efficiency.push_back(eff);
+    }
+    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+        if (efficiency[num_splits - 1] > 0.95 * max_efficiency) {
+            // printf("num_splits chosen = %d\n", num_splits);
+            return num_splits;
+        }
+    }
+    return 1;
+}
+
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
 __global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
     fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
@@ -75,7 +101,21 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     }
-    dim3 grid(launch_params.params.b, launch_params.params.h);
+    // Automatically set num_splits to maximize occupancy
+    if (launch_params.params.num_splits <= 0) {
+        int ctas_per_sm;
+        cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+            &ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size);
+        auto dprops = at::cuda::getCurrentDeviceProperties();
+        // printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount);
+        constexpr int M = Kernel_traits::Cta_tile_p::M;
+        launch_params.params.num_splits = num_splits_heuristic_fwd(
+            launch_params.params.b * launch_params.params.h, dprops->multiProcessorCount,
+            ctas_per_sm,
+            /*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1 / M))
+        );
+    }
+    dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
     kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
         launch_params.params);
     FMHA_CHECK_CUDA(cudaPeekAtLastError());
@@ -103,10 +143,7 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
     if( launch_params.params.seqlen_k == 128 ) {
         using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-    } else if( launch_params.params.seqlen_k == 256 ) {
-        using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
-        run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
-    } else {
+    } else if( launch_params.params.seqlen_k >= 256 ) {
         using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     }
...
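The occupancy heuristic added in the hunk above is easier to see with numbers. Below is a small standalone Python sketch of the same logic (a re-implementation for illustration only, not code from this repository); it reproduces the worked example in the comment: with batch * n_heads = 48 on 108 SMs and one CTA per SM, 2 splits reach about 0.89 efficiency while 3 splits drop to about 0.67, and the smallest split count within 95% of the best is chosen.

    import math

    def num_splits_heuristic_fwd(batch_nheads, num_sms, ctas_per_sm, max_splits):
        # Efficiency = fraction of the last wave of CTAs that actually has work.
        efficiency = []
        max_eff = 0.0
        for num_splits in range(1, max_splits + 1):
            n_waves = batch_nheads * num_splits / (num_sms * ctas_per_sm)
            eff = n_waves / math.ceil(n_waves)
            max_eff = max(max_eff, eff)
            efficiency.append(eff)
        # Prefer the smallest split count within 95% of the best efficiency,
        # since extra splits incur extra HBM reads/writes.
        for num_splits in range(1, max_splits + 1):
            if efficiency[num_splits - 1] > 0.95 * max_eff:
                return num_splits
        return 1

    # Worked example from the comment in the hunk above: 48 batch*heads, 108 SMs, 1 CTA/SM.
    print(num_splits_heuristic_fwd(48, 108, 1, max_splits=8))  # -> 2

In the launcher above, ctas_per_sm comes from cudaOccupancyMaxActiveBlocksPerMultiprocessor, max_splits is capped at 30, and any num_splits <= 0 (the value 0 that the Python wrapper passes by default) triggers this automatic choice.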
@@ -197,7 +197,7 @@ constexpr size_t get_dynamic_smem_size(){
 }
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
-inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int begin, int steps, Prng &ph, const int loop_step_idx) {
+inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int steps, int step_stride, Prng &ph, const int loop_step_idx) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     using elem_type = typename Kernel_traits::elem_type;
@@ -266,15 +266,23 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     // Wind gmem tiles to the correct position.
     static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
-    const int begin_og = begin;
-    begin = Is_causal ? std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin;
+    int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
+    // We want begin to be a multiple of gridDim.z
+    // This is because the row indices processed by each threadblock must align between the
+    // loop steps, otherwise we have a dependency between the blocks.
+    // For example, threadblock with blockIdx.z == 1 must process row indices that are
+    // k * gridDim.z + 1 for integer k.
+    const int begin_mod_z = begin % gridDim.z;
+    begin = begin_mod_z <= blockIdx.z ? begin - begin_mod_z : begin + gridDim.z - begin_mod_z;
     const int steps_og = steps;
-    steps -= begin - begin_og;
-    gmem_q.move(begin);
-    gmem_o.move(begin);
-    gmem_o_tmp.move(begin);
-    if (Return_softmax) { gmem_s.move(begin); }
-    gmem_softmax_lse.move(begin);
+    steps -= begin;
+    gmem_q.move(begin + blockIdx.z);
+    gmem_o.move(begin + blockIdx.z);
+    gmem_o_tmp.move(begin + blockIdx.z);
+    if (Return_softmax) {
+        gmem_s.move(begin + blockIdx.z);
+    }
+    gmem_softmax_lse.move(begin + blockIdx.z);
     // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
     //     printf("begin = %d, steps = %d\n", begin, steps);
     // }
@@ -362,8 +370,11 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     Smem_softmax_sum smem_softmax_lse(reinterpret_cast<float *>(&smem_[Gemm1::SMEM_BYTES]), tidx);
     // Load over the entire sequence length.
-    for( int l = 0; l < steps; l++ ) {
-        if((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q) break;
+    for (int l = blockIdx.z; l < steps; l += step_stride) {
+        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z <= 1)) {
+        //     printf("l = %d\n", l);
+        // }
+        if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q) break;
         // Declare the accumulators for the 1st gemm.
         fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
@@ -380,9 +391,9 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         if (!Is_first) { gmem_o_tmp.load(out, 0); }
         // Trigger the load for the next Q values.
-        if( l < steps - 1) {
+        if (l + step_stride < steps) {
             gemm_q_k.smem_q.move_to_next_write_buffer();
-            gmem_q.move();
+            gmem_q.move(step_stride);
             gmem_q.load();
         }
@@ -395,27 +406,28 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         // Apply the mask.
         softmax.apply_mask(mask);
-        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
+        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l < step_stride ) {
             // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
             __syncthreads();
         }
         // if (!Is_first) {
-        //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
+        //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l >= 0)) {
         //         printf("p_prev_lse=%.6f, %.6f\n", p_prev_lse[0], p_prev_lse[1]);
         //     }
         // }
         // Compute the max.
         float p_max[Mma_tile_p::MMAS_M * 2];
         if (!Is_first) {
-            smem_softmax_lse.store_pair(p_prev_lse, l % 2);
+            smem_softmax_lse.store_pair(p_prev_lse);
             // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi]; }
             for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; }
         }
         // Trigger the load for the next LSE values.
-        if( l < steps - 1) {
+        if (l + step_stride < steps) {
             if (!Is_first) {
-                gmem_softmax_lse.load_next(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse));
+                gmem_softmax_lse.load_next(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse),
+                                           step_stride);
             }
         }
@@ -490,11 +502,11 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         softmax.template pack<elem_type>(frag_p);
         if (Return_softmax) {
             gmem_s.store(frag_p, mask);
-            gmem_s.move();
+            gmem_s.move(step_stride);
         }
         // Commit the values for Q into shared memory.
-        if(l < steps - 1) {
+        if (l + step_stride < steps) {
             gmem_q.commit(gemm_q_k.smem_q);
         }
@@ -548,7 +560,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         }
         float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP];
         if ((!Is_first) && o_rows_are_valid) {
-            smem_softmax_lse.load(p_prev_scale_o, rows, l % 2);
+            smem_softmax_lse.load(p_prev_scale_o, rows);
         }
         // if (!Is_first) {
         //     if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
@@ -594,7 +606,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
                 reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M]>(p_sum_log[jj]), rows[jj]);
             }
         }
-        gmem_softmax_lse.move();
+        gmem_softmax_lse.move(step_stride);
         // Load from shared memory.
         if (!Is_first) {
@@ -627,22 +639,21 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         // Output the values.
         if (is_final_write) {
             gmem_o.template store<elem_type>(out, 0);
-            gmem_o.move();
+            gmem_o.move(step_stride);
         } else {
             gmem_o_tmp.store(out, 0);
         }
         // Move to the next part of the output.
-        if (!(Is_first && Is_last)) { gmem_o_tmp.move(); }
+        if (!(Is_first && Is_last)) { gmem_o_tmp.move(step_stride); }
         gemm_q_k.reload_k();
         // Make sure we are reading from the correct buffer.
         gemm_q_k.smem_q.move_to_next_read_buffer();
         // Trigger the load from shared memory for the next series of Q values.
-        if(l < steps - 1) {
+        if (l + step_stride < steps) {
             gemm_q_k.reload_q();
         }
     } // Outer loop over the sequence length.
 }
@@ -672,14 +683,14 @@ inline __device__ void device_1xN_loop(const Params &params) {
     constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     if (params.seqlen_k == blocksize_c) {
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, 0, STEPS, ph, 0);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, STEPS, gridDim.z, ph, 0);
     } else {
         const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, 0, STEPS, ph, 0);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, STEPS, gridDim.z, ph, 0);
         for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
-            fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, 0, STEPS, ph, loop_step_idx);
+            fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, STEPS, gridDim.z, ph, loop_step_idx);
         }
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, 0, STEPS, ph, max_loop_steps - 1);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, STEPS, gridDim.z, ph, max_loop_steps - 1);
     }
 }
...
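The begin/blockIdx.z arithmetic in the hunk above is easier to follow with a tiny simulation. The sketch below is illustrative Python that mirrors only the index math (the names n_over_m, steps, and num_splits stand in for Cta_tile_p::N / Cta_tile_p::M, the number of M-row query blocks, and gridDim.z): each z-block always owns the row blocks whose index is congruent to its blockIdx.z modulo the number of splits, regardless of which K/V chunk (loop_step_idx) is being processed, so no threadblock ever depends on another block's partial output or LSE.

    def row_blocks_for_split(block_idx_z, num_splits, steps, loop_step_idx, n_over_m, causal=True):
        # Under causal masking, this K/V chunk only touches query row blocks >= begin.
        begin = loop_step_idx * n_over_m if causal else 0
        # Round begin to a multiple of num_splits (down if this block's residue is already
        # reachable, up otherwise) so ownership by residue class stays fixed across loop steps.
        begin_mod_z = begin % num_splits
        begin = begin - begin_mod_z if begin_mod_z <= block_idx_z else begin + num_splits - begin_mod_z
        remaining = steps - begin
        # The kernel's loop: l = block_idx_z, block_idx_z + num_splits, ...; row block is begin + l.
        return [begin + l for l in range(block_idx_z, remaining, num_splits)]

    # 12 query row blocks, 3 splits, causal, second K/V chunk (loop_step_idx=1, N/M = 4):
    for z in range(3):
        print(z, row_blocks_for_split(z, num_splits=3, steps=12, loop_step_idx=1, n_over_m=4))
    # 0 [6, 9]
    # 1 [4, 7, 10]
    # 2 [5, 8, 11]

Every row block at or after the causal start (4) is covered exactly once, and block z only ever touches row blocks congruent to z modulo 3, which is the invariant the comment in the hunk calls out.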
@@ -14,10 +14,16 @@ def _get_block_size(device, head_dim, is_dropout):
 def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                        dropout_p, softmax_scale, causal, return_softmax, generator=None):
+                        dropout_p, softmax_scale, causal, return_softmax, num_splits=0,
+                        generator=None):
+    """
+    num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
+    it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking.
+    Don't change it unless you know what you're doing.
+    """
     softmax_lse, *rest = flash_attn_cuda.fwd(
         q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-        softmax_scale, False, causal, return_softmax, generator
+        softmax_scale, False, causal, return_softmax, num_splits, generator
     )
     # if out.isnan().any() or softmax_lse.isnan().any():
     #     breakpoint()
...
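Since the docstring says num_splits is exposed mostly for benchmarking, a sweep over split counts through this internal entry point might look like the sketch below. The tensor setup is an assumption made for illustration (fp16 CUDA tensors in the packed (total, nheads, head_dim) layout, int32 cumulative sequence lengths, and the flash_attn.flash_attn_interface module path); the only part this commit adds is the num_splits argument, with num_splits=0 deferring to the C++ heuristic.

    import torch
    from flash_attn.flash_attn_interface import _flash_attn_forward  # assumed module path

    batch, seqlen, nheads, head_dim = 4, 512, 12, 64
    q, k, v = [torch.randn(batch * seqlen, nheads, head_dim, device="cuda", dtype=torch.float16)
               for _ in range(3)]
    out = torch.empty_like(q)
    # Cumulative sequence lengths for the packed layout: [0, 512, 1024, 1536, 2048].
    cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen, dtype=torch.int32, device="cuda")

    for num_splits in (0, 1, 2, 4, 8):  # 0 = let the heuristic decide
        _flash_attn_forward(q, k, v, out, cu_seqlens, cu_seqlens, seqlen, seqlen,
                            dropout_p=0.0, softmax_scale=head_dim ** -0.5, causal=True,
                            return_softmax=False, num_splits=num_splits)
        torch.cuda.synchronize()  # time the region around this call when benchmarking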