Commit f66603cb authored by Tri Dao

Support batch size > 64K by swapping grid.x and grid.y

parent 450b64fe
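Why the swap works: CUDA allows up to 2^31 - 1 blocks along grid.x but only 65535 along grid.y and grid.z, so a batch dimension larger than 64K has to be launched on grid.x. The sketch below is illustrative only and is not code from this patch; the kernel name demo_kernel and the batch/heads values are made up to show the new launch shape and the matching index recovery inside the kernel, mirroring the kernels changed below.

// Minimal standalone sketch of the new launch layout (illustrative, not the FMHA kernel).
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void demo_kernel(int heads) {
    const int bidb = blockIdx.x;                 // batch index: grid.x allows up to 2^31 - 1 blocks
    const int bidh = blockIdx.y;                 // head index: grid.y is capped at 65535
    const uint32_t bidx = bidb * heads + bidh;   // flat (batch, head) index, as in the kernels below
    if (bidx == 0 && threadIdx.x == 0) {
        printf("grid = (%d, %d)\n", gridDim.x, gridDim.y);
    }
}

int main() {
    const int batch = 70000;   // > 64K sequences: would exceed the old grid.y limit
    const int heads = 16;
    dim3 grid(batch, heads);   // batch on x, heads on y: the ordering this commit switches to
    demo_kernel<<<grid, 128>>>(heads);
    cudaDeviceSynchronize();
    return 0;
}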
@@ -456,9 +456,9 @@ struct Gmem_summary_stats {
         : ptr_(reinterpret_cast<char *>(ptr)), tidx_(tidx) {
         // The block index for the batch.
-        const int bidb = blockIdx.y;
+        const int bidb = blockIdx.x;
         // The block index for the head.
-        const int bidh = blockIdx.x;
+        const int bidh = blockIdx.y;
         // The block index.
         // size_t bidx = bidb * params.h + bidh;
         uint32_t bidx = bidb * params.h + bidh;

@@ -45,7 +45,7 @@ void run_fmha_block_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
     }
-    dim3 grid(params.h, params.b);
+    dim3 grid(params.b, params.h);
     kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
     FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }

@@ -118,9 +118,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;
@@ -729,9 +729,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN(const Params &params) {
     constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;

@@ -68,7 +68,7 @@ void run_fmha_block_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fpro
         return;
     }
-    dim3 grid(launch_params.params.h, launch_params.params.b);
+    dim3 grid(launch_params.params.b, launch_params.params.h);
     kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
         launch_params.params);

@@ -497,9 +497,9 @@ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_so
 inline __device__ void device_block_1xN_loop(const Params &params) {
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;

@@ -44,7 +44,7 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
     }
-    dim3 grid(params.h, params.b);
+    dim3 grid(params.b, params.h);
     kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
     FMHA_CHECK_CUDA(cudaPeekAtLastError());
 }

@@ -119,9 +119,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;
@@ -683,9 +683,9 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
     constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;

@@ -68,7 +68,7 @@ void run_fmha_fp16_sm80_loop_(Launch_params<Fused_multihead_attention_fprop_para
         return;
     }
-    dim3 grid(launch_params.params.h, launch_params.params.b);
+    dim3 grid(launch_params.params.b, launch_params.params.h);
     kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
         launch_params.params);

@@ -621,9 +621,9 @@ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_so
 inline __device__ void device_1xN_loop(const Params &params) {
     // The block index for the batch.
-    const int bidb = blockIdx.y;
+    const int bidb = blockIdx.x;
     // The block index for the head.
-    const int bidh = blockIdx.x;
+    const int bidh = blockIdx.y;
     // The thread index.
     const int tidx = threadIdx.x;

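For reference, the per-dimension grid limits can be queried at runtime through the standard CUDA device-properties API. The short check below is not part of the commit; it just confirms the x-dimension headroom on a given GPU.

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    // maxGridSize[0] is typically 2147483647; [1] and [2] are typically 65535.
    printf("max grid: x=%d y=%d z=%d\n",
           prop.maxGridSize[0], prop.maxGridSize[1], prop.maxGridSize[2]);
    return 0;
}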