Commit 871db479 authored by Tri Dao

Don't need to run configure for the forward pass

parent 7fc39832
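In short, the forward path used to call run_fmha_fp16_sm80 twice: once with configure=true purely to fill in launch_params.elts_per_thread, then again to actually launch. Nothing in the forward path consumes that value anymore, so the configure pass and the extra bool parameter can be removed, as the hunks below show. A minimal, self-contained sketch of the before/after calling pattern (ToyLaunchParams, run_toy_old, and run_toy_new are illustrative stand-ins, not the real flash-attention API):

// Illustrative stand-ins only; they mirror the calling-convention change in this commit.
#include <cstdint>
#include <cstdio>

struct ToyLaunchParams {
    int64_t b = 2, h = 3;          // batch size and number of heads
    size_t elts_per_thread = 0;    // the old configure pass existed only to fill this in
};

// Old shape: the same entry point called twice, first as a "configure" pass, then for real.
void run_toy_old(ToyLaunchParams &p, const bool configure) {
    if (configure) { p.elts_per_thread = 4096; return; }   // placeholder bookkeeping
    std::printf("old: launch (elts_per_thread = %zu)\n", p.elts_per_thread);
}

// New shape: one call; the RNG counter offset is derived directly from b and h.
void run_toy_new(const ToyLaunchParams &p) {
    const int64_t counter_offset = p.b * p.h * 32;   // same formula the forward pass uses
    std::printf("new: launch (philox counter offset = %lld)\n", (long long)counter_offset);
}

int main() {
    ToyLaunchParams p;
    run_toy_old(p, /*configure=*/true);    // old: configure-only pass
    run_toy_old(p, /*configure=*/false);   // old: actual launch
    run_toy_new(p);                        // new: single call, no configure pass
    return 0;
}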
@@ -294,7 +294,6 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      is_causal,
                      num_splits);

-    run_fmha_fp16_sm80(launch_params, /*configure=*/ true);
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
@@ -307,7 +306,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
         launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
     }

-    run_fmha_fp16_sm80(launch_params, /*configure=*/false);
+    run_fmha_fp16_sm80(launch_params);

     std::vector<at::Tensor> result = {softmax_lse};
     if (return_softmax) {result.push_back(s);}
@@ -453,9 +452,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());

-    // We're gonna reset the rng state in Python after this kernel, so the counter offset
-    // here doesn't matter at all. We just choose an arbitrary number.
-    int64_t counter_offset = 4;
+    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+    int64_t counter_offset = params.b * params.h * 32;

     if( is_dropout ) {
         // See Note [Acquire lock when using random generators]
...
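The backward hunk above replaces the hardcoded counter offset of 4 with the same batch_size * nheads * 32 formula the forward pass already uses, so both passes reserve an identically sized slice of the Philox stream. A tiny sketch of that invariant (philox_counter_offset and the example sizes are assumptions for illustration, not flash-attention names):

// Hypothetical helper restating the formula from the comment "batch_size * nheads * 32".
#include <cassert>
#include <cstdint>
#include <cstdio>

int64_t philox_counter_offset(int64_t b, int64_t h) { return b * h * 32; }

int main() {
    const int64_t b = 4, h = 16;                            // example batch size / num heads
    const int64_t fwd_offset = philox_counter_offset(b, h); // what mha_fwd reserves
    const int64_t bwd_offset = philox_counter_offset(b, h); // what mha_bwd reserves after this change
    assert(fwd_offset == bwd_offset);                       // previously mha_bwd hardcoded 4
    std::printf("philox counter offset = %lld\n", (long long)fwd_offset);
    return 0;
}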
@@ -191,7 +191,7 @@ struct Launch_params{

////////////////////////////////////////////////////////////////////////////////////////////////////

-void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
+void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params);

void run_fmha_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream);
...
@@ -65,22 +65,10 @@ __global__ void fmha_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
 }

 template<typename Kernel_traits>
-void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
-                              const bool configure) {
+void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params) {
     constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
-    if (configure) {
-        using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
-        constexpr int M = Kernel_traits::Cta_tile_p::M;
-        size_t STEPS = (launch_params.params.seqlen_q + M - 1) / M;
-        constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
-        constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
-        size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
-        launch_params.elts_per_thread = elts_per_head;
-        return;
-    }

     constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
     // Don't need smem_size_softmax_lse if we're not looping
     const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
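For reference, the deleted configure branch existed only to fill launch_params.elts_per_thread, a bound on how many dropout random values each thread generates per head. The arithmetic can be rerun standalone; the tile constants below (M, blocksize_c, MMAS_M, MMAS_N) are assumed example values, not read from a real Kernel_traits:

// Standalone rerun of the removed elts_per_head computation with assumed constants.
#include <cstddef>
#include <cstdio>

int main() {
    const int seqlen_q = 512, seqlen_k = 512;                               // example sequence lengths
    const int M = 16, blocksize_c = 128;                                    // assumed Cta_tile_p::M / ::N
    const size_t MMAS_M = 1, MMAS_N = 4;                                    // assumed Mma_tile_p constants
    const size_t STEPS = (seqlen_q + M - 1) / M;                            // 32 query tiles
    const int loop_steps = (seqlen_k + blocksize_c - 1) / blocksize_c;      // 4 key/value tiles
    const size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;  // 32 * 1 * 4 * 8 * 4 = 4096
    std::printf("elts_per_head = %zu\n", elts_per_head);
    return 0;
}

Since the forward pass offsets the Philox counter by batch_size * nheads * 32 (the unchanged comment in the first hunk) rather than by this per-head count, nothing downstream reads elts_per_thread anymore, which is presumably why the configure pass can be dropped.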
@@ -123,38 +111,37 @@ void run_fmha_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
     });
 }

-void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
-                        const bool configure) {
+void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params) {
     FP16_SWITCH(launch_params.params.is_bf16, [&] {
         auto dprops = at::cuda::getCurrentDeviceProperties();
         if (launch_params.params.d == 16) {
             if( launch_params.params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else if( launch_params.params.seqlen_k == 256 ) {
                 using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else {
                 // TD [2022-05-15] 512 gives wrong results rn
                 // using Kernel_traits = FMHA_kernel_traits<512, 16, 16, 1, 4, 0x08u, elem_type>;
                 using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             }
         } else if (launch_params.params.d == 32) {
             if( launch_params.params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else if( launch_params.params.seqlen_k >= 256 ) {
                 using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             }
         } else if (launch_params.params.d == 64) {
             if( launch_params.params.seqlen_k == 128 ) {
                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             } else if( launch_params.params.seqlen_k >= 256 ) {
                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+                run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
             }
         } else if (launch_params.params.d == 128) {
             // TD [2022-10-21]: Previously for SM80 we use block size 256 and keep K in shared memory
@@ -166,30 +153,30 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
             // For causal=True, block size 128 seems always faster (for small & large batch size).
             // So we're just gonna use block size 128 for simplicity.
             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+            run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         }
         // if (launch_params.params.d == 64) {
         //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
         //     // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u, elem_type>;
         //     // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u, elem_type>;
         //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         // }
         // if (launch_params.params.d == 64) {
         //     if( launch_params.params.seqlen_k == 128 ) {
         //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         //     } else if( launch_params.params.seqlen_k >= 256 ) {
         //         if (dprops->major == 8 && dprops->minor >= 0) {
         //             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
        //         } else if (dprops->major == 7 && dprops->minor == 5) {
        //             if (launch_params.is_dropout) { // Need to use the same block size as backward
        //                 using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
-       //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+       //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
        //             } else {
        //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
-       //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+       //                 run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
        //             }
        //         }
        //     }
@@ -197,16 +184,16 @@ void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
         // if (launch_params.params.d == 128) {
         //     if( launch_params.params.seqlen_k == 128 ) {
         //         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
         //     } else {
         //         if (dprops->major == 8 && dprops->minor >= 0 && !launch_params.is_dropout) {
         //             // TD [2022-06-05] Keep K in registers to reduce register spilling
         //             // Gives about 6% speedup compared to using block size 128.
         //             using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
-        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+        //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
        //         } else { // Need to use the same block size as backward
        //             using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
-       //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+       //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params);
        //         }
        //     }
        // }
...
@@ -197,7 +197,7 @@ constexpr size_t get_dynamic_smem_size(){
 }

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
-inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int steps, int step_stride, Prng &ph, const int loop_step_idx) {
+inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int steps, Prng &ph, const int loop_step_idx) {

 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     using elem_type = typename Kernel_traits::elem_type;
@@ -250,6 +250,9 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     // The thread index.
     const int tidx = threadIdx.x;

+    // How many steps to jump per iteration, which is the same as params.num_splits.
+    const int step_stride = gridDim.z;
+
     const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
     // if( binfo.stop_early() ) return;
     if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
@@ -683,14 +686,14 @@ inline __device__ void device_1xN_loop(const Params &params) {
     constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
     if (params.seqlen_k == blocksize_c) {
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, STEPS, gridDim.z, ph, 0);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, STEPS, ph, 0);
     } else {
         const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, STEPS, gridDim.z, ph, 0);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, STEPS, ph, 0);
         for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
-            fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, STEPS, gridDim.z, ph, loop_step_idx);
+            fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, STEPS, ph, loop_step_idx);
         }
-        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, STEPS, gridDim.z, ph, max_loop_steps - 1);
+        fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, STEPS, ph, max_loop_steps - 1);
     }
 }
...
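The kernel-side hunks drop the explicit step_stride argument because the launcher already encodes it in the grid: the kernel is launched with gridDim.z equal to params.num_splits, so device_1xN_ can read the stride from gridDim.z (per the new comment above). A minimal standalone CUDA sketch of that pattern, with a toy kernel and made-up grid sizes rather than the actual fmha launch:

// Toy demonstration: a kernel launched with gridDim.z == num_splits can recover
// that value from gridDim.z instead of taking it as a parameter.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void step_stride_demo() {
    const int step_stride = gridDim.z;  // same value the host passed as grid.z
    if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0) {
        printf("step_stride (gridDim.z) = %d\n", step_stride);
    }
}

int main() {
    const int num_splits = 4;                        // hypothetical value
    const dim3 grid(2, 3, num_splits), block(32);    // the z dimension carries num_splits
    step_stride_demo<<<grid, block>>>();
    cudaDeviceSynchronize();
    return 0;
}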