Commit 8c6609ae authored by Tri Dao's avatar Tri Dao
Browse files

[LayerNorm] Support all dimensions up to 6k (if divisible by 8)

parent 8a2ece89
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "ln_fwd_kernels.cuh"
#include "static_switch.h"
using namespace layer_norm;
template<
typename weight_t,
typename input_t,
typename residual_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int CTAS_PER_ROW,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG
>
void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
using Kernel_traits = Kernel_traits<weight_t,
input_t,
residual_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
CTAS_PER_ROW,
WARPS_M,
WARPS_N,
BYTES_PER_LDG
>;
bool has_residual = launch_params.params.x1 != nullptr;
bool has_rowscale = launch_params.params.rowscale != nullptr;
BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
BOOL_SWITCH(has_residual, HasResidualConst, [&] {
BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] {
auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, HasRowscaleConst>;
if( configure_params ) {
int ctas_per_sm;
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::Stats::stats_t)
* 2;
}
return;
}
if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
}
});
});
});
}
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1600, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// TD [2022-04-22] Disable most of these to speed up compile time
// REGISTER_FWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16);
...@@ -10,10 +10,13 @@ ...@@ -10,10 +10,13 @@
#include <curand_kernel.h> #include <curand_kernel.h>
#include "ln.h" #include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "static_switch.h"
namespace layer_norm { namespace layer_norm {
template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Has_rowscale> template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Is_even_cols>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_fwd_kernel(FwdParams params) { void ln_fwd_kernel(FwdParams params) {
...@@ -73,57 +76,70 @@ void ln_fwd_kernel(FwdParams params) { ...@@ -73,57 +76,70 @@ void ln_fwd_kernel(FwdParams params) {
curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state); curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
} }
const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;
Wvec gamma[LDGS]; Wvec gamma[LDGS];
Wvec beta[LDGS]; Wvec beta[LDGS];
index_t idx = c; index_t idx = c;
#pragma unroll #pragma unroll
for( int it = 0; it < LDGS; it++ ) { for( int it = 0; it < LDGS; it++ ) {
gamma[it].load_from(params.gamma, idx); if (Is_even_cols || (it < num_valid_ldgs)) {
beta[it].load_from(params.beta, idx); gamma[it].load_from(params.gamma, idx);
idx += VEC_COLS_PER_LDG; beta[it].load_from(params.beta, idx);
idx += VEC_COLS_PER_LDG;
}
} }
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
const compute_t rowscale_val = Has_rowscale ? compute_t(rowscale[row]) : 1.0f; const compute_t rowscale_val = params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row]);
index_t idx = row * Ktraits::VEC_COLS + c; index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
compute_t xf[LDGS * NUM_ELTS]; compute_t xf[LDGS * NUM_ELTS];
#pragma unroll #pragma unroll
for( int it = 0; it < LDGS; it++ ) { for( int it = 0; it < LDGS; it++ ) {
Ivec x0; if (Is_even_cols || (it < num_valid_ldgs)) {
Rvec x1; Ivec x0;
Rvec x; Rvec x1;
Mvec dmask; Rvec x;
x0.load_from(params.x0, idx); Mvec dmask;
if (Has_residual) { x1.load_from(params.x1, idx); } x0.load_from(params.x0, idx);
#pragma unroll if (Has_residual) { x1.load_from(params.x1, idx); }
for( int jt = 0; jt < NUM_ELTS; jt++ ) { #pragma unroll
// TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use for( int jt = 0; jt < NUM_ELTS; jt++ ) {
// the more efficient curand_uniform4. // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
mask_t keep = true; // the more efficient curand_uniform4.
if (Is_dropout) { mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
float rand = curand_uniform(&state); compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
keep = mask_t(rand <= params.dropout_keep_p); compute_t x_ij;
if (Has_residual) {
compute_t x1_ij = compute_t(x1.data.elt[jt]);
x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) + x1_ij : x1_ij;
} else {
x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.f;
}
if (save_x) { x.data.elt[jt] = x_ij; }
xf[it * NUM_ELTS + jt] = x_ij;
if (Is_dropout) { dmask.data.elt[jt] = keep; }
} }
compute_t x0_ij = Has_rowscale ? compute_t(x0.data.elt[jt]) * rowscale_val : compute_t(x0.data.elt[jt]); if (save_x) { x.store_to(params.x, idx); }
compute_t x_ij; if (Is_dropout) { dmask.store_to(params.dmask, idx); }
if (Has_residual) { idx += VEC_COLS_PER_LDG;
compute_t x1_ij = compute_t(x1.data.elt[jt]);
x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) + x1_ij : x1_ij;
} else {
x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.f;
}
if (save_x) { x.data.elt[jt] = x_ij; }
xf[it * NUM_ELTS + jt] = x_ij;
if (Is_dropout) { dmask.data.elt[jt] = keep; }
} }
if (save_x) { x.store_to(params.x, idx); }
if (Is_dropout) { dmask.store_to(params.dmask, idx); }
idx += VEC_COLS_PER_LDG;
} }
stats_t s = stats.compute(xf, rn); static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now");
const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;
const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;
const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;
// Need to convert to int, otherwise the subtraction will wrap around.
auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {
const index_t valid_partial_vecs_in_warp =
std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),
int(THREADS_PER_WARP));
return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS;
};
stats_t s = stats.template compute<Is_even_cols>(
xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS
);
compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s); compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s); compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);
...@@ -132,28 +148,104 @@ void ln_fwd_kernel(FwdParams params) { ...@@ -132,28 +148,104 @@ void ln_fwd_kernel(FwdParams params) {
mu_ptr[row] = mu; mu_ptr[row] = mu;
} }
compute_t rs = rsqrtf(rn * m2 + params.epsilon); compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon);
if( bidn == 0 && warp_n == 0 && lane == 0 ) { if( bidn == 0 && warp_n == 0 && lane == 0 ) {
rs_ptr[row] = rs; rs_ptr[row] = rs;
} }
idx = row * Ktraits::VEC_COLS + c; idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
#pragma unroll #pragma unroll
for( int it = 0; it < LDGS; it++ ) { for( int it = 0; it < LDGS; it++ ) {
Ovec z; if (Is_even_cols || (it < num_valid_ldgs)) {
#pragma unroll Ovec z;
for( int jt = 0; jt < NUM_ELTS; jt++ ) { #pragma unroll
output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu)); for( int jt = 0; jt < NUM_ELTS; jt++ ) {
output_t g_ij = gamma[it].data.elt[jt]; compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - mu));
output_t b_ij = beta[it].data.elt[jt]; compute_t g_ij = gamma[it].data.elt[jt];
z.data.elt[jt] = (g_ij * y_ij + b_ij); compute_t b_ij = beta[it].data.elt[jt];
z.data.elt[jt] = output_t(g_ij * y_ij + b_ij);
}
z.store_to(params.z, idx);
idx += VEC_COLS_PER_LDG;
} }
z.store_to(params.z, idx);
idx += VEC_COLS_PER_LDG;
} }
} }
} }
} // namespace layer_norm } // namespace layer_norm
using namespace layer_norm;
template<
typename weight_t,
typename input_t,
typename residual_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int CTAS_PER_ROW,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG
>
void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
using Kernel_traits = Kernel_traits<weight_t,
input_t,
residual_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
CTAS_PER_ROW,
WARPS_M,
WARPS_N,
BYTES_PER_LDG
>;
bool has_residual = launch_params.params.x1 != nullptr;
bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
BOOL_SWITCH(has_residual, HasResidualConst, [&] {
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, IsEvenColsConst>;
if( configure_params ) {
int ctas_per_sm;
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::Stats::stats_t)
* 2;
}
return;
}
if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
}
});
});
});
}
...@@ -530,20 +530,20 @@ struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> { ...@@ -530,20 +530,20 @@ struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T> template<typename T, typename int_t>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active){ inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, int_t &n_a, int num_active){
//Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise) //Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); const int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);
#pragma unroll #pragma unroll
for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) { for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) {
// Exchange // Exchange
T n_b = warp_shuffle_down(n_a, step); int_t n_b = warp_shuffle_down(n_a, step);
T m_b = warp_shuffle_down(m_a, step); T m_b = warp_shuffle_down(m_a, step);
T m2_b = warp_shuffle_down(m2_a, step); T m2_b = warp_shuffle_down(m2_a, step);
// Update // Update
const T n_ab = n_a + n_b; // We can handle one of them being 0, not both. const int_t n_ab = n_a + n_b; // We can handle one of them being 0, not both.
const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :( const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :(
const T delta = m_a - m_b; const T delta = m_a - m_b;
const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
...@@ -647,23 +647,26 @@ struct Stats<T, 1, WARPS_M, WARPS_N> { ...@@ -647,23 +647,26 @@ struct Stats<T, 1, WARPS_M, WARPS_N> {
smem1_ = smem0_ + WARPS_M * WARPS_N; smem1_ = smem0_ + WARPS_M * WARPS_N;
} }
template<uint32_t N> template<bool Is_even_cols, uint32_t N, typename function_t>
inline __device__ stats_t compute(const T (&elts)[N], const T rn) { inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor,
function_t valid_elts_in_warp_fn, const int num_valid_elts = N) {
stats_t * smem = use0_ ? smem0_ : smem1_; stats_t * smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_; use0_ = !use0_;
// Compute warp local for all WARPS_N // Compute warp local for all WARPS_N
constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP); const auto warp_n = warp_stats_.reducer_.warp_n_;
stats_t warp_stats = warp_stats_.compute(elts, warp_rn); const T warp_norm_factor = 1.f / T(Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(warp_n));
stats_t warp_stats = warp_stats_.template compute<Is_even_cols>(
elts, warp_norm_factor, valid_elts_in_warp_fn, num_valid_elts
);
//Each warp warp leader stores its stats //Each warp warp leader stores its stats
const auto warp_n = warp_stats_.reducer_.warp_n_;
const auto lane = warp_stats_.reducer_.lane_; const auto lane = warp_stats_.reducer_.lane_;
if( lane == 0 ) { if( lane == 0 ) {
smem[warp_n] = warp_stats; smem[warp_n] = warp_stats;
} }
__syncthreads(); __syncthreads();
T n = Zeros<T>::get(); int n = 0;;
T m = Zeros<T>::get(); T m = Zeros<T>::get();
T m2 = Zeros<T>::get(); T m2 = Zeros<T>::get();
...@@ -671,7 +674,7 @@ struct Stats<T, 1, WARPS_M, WARPS_N> { ...@@ -671,7 +674,7 @@ struct Stats<T, 1, WARPS_M, WARPS_N> {
static_assert(WARPS_N <= 32); static_assert(WARPS_N <= 32);
if(lane < WARPS_N){ if(lane < WARPS_N){
stats_t result = smem[lane]; stats_t result = smem[lane];
n = N * THREADS_PER_WARP; n = Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(lane);
m = layer_norm::Get<0>::of<stats_t, T>(result); m = layer_norm::Get<0>::of<stats_t, T>(result);
m2 = layer_norm::Get<1>::of<stats_t, T>(result); m2 = layer_norm::Get<1>::of<stats_t, T>(result);
} }
...@@ -703,23 +706,29 @@ struct Stats<T, 1, WARPS_M, 1> { ...@@ -703,23 +706,29 @@ struct Stats<T, 1, WARPS_M, 1> {
{ {
} }
template<uint32_t N> template<bool Is_even_cols, uint32_t N, typename function_t>
inline __device__ stats_t compute(const T (&elts)[N], const T rn) { inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor,
// const int valid_elts_in_warp_ignored_, const int num_valid_elts = N) {
function_t valid_elts_in_warp_fn, const int num_valid_elts = N) {
auto sum = Sum<T>(); auto sum = Sum<T>();
T m = Zeros<T>::get(); T m = Zeros<T>::get();
#pragma unroll #pragma unroll
for( int it = 0; it < N; it++ ) { for( int it = 0; it < N; it++ ) {
m += elts[it]; if (Is_even_cols || (it < num_valid_elts)) {
m += elts[it];
}
} }
m = reducer_.allreduce(m, sum) * rn; m = reducer_.allreduce(m, sum) * row_norm_factor;
T m2 = Zeros<T>::get(); T m2 = Zeros<T>::get();
#pragma unroll #pragma unroll
for( int it = 0; it < N; it++ ) { for( int it = 0; it < N; it++ ) {
T diff = (elts[it] - m); if (Is_even_cols || (it < num_valid_elts)) {
m2 += diff * diff; T diff = (elts[it] - m);
m2 += diff * diff;
}
} }
m2 = reducer_.allreduce(m2, sum); m2 = reducer_.allreduce(m2, sum);
......
...@@ -108,8 +108,30 @@ ext_modules.append( ...@@ -108,8 +108,30 @@ ext_modules.append(
name="dropout_layer_norm", name="dropout_layer_norm",
sources=[ sources=[
"ln_api.cpp", "ln_api.cpp",
"ln_fwd_cuda_kernel.cu", "ln_fwd_256.cu",
"ln_bwd_semi_cuda_kernel.cu", "ln_bwd_256.cu",
"ln_fwd_512.cu",
"ln_bwd_512.cu",
"ln_fwd_768.cu",
"ln_bwd_768.cu",
"ln_fwd_1024.cu",
"ln_bwd_1024.cu",
"ln_fwd_1280.cu",
"ln_bwd_1280.cu",
"ln_fwd_1536.cu",
"ln_bwd_1536.cu",
"ln_fwd_2048.cu",
"ln_bwd_2048.cu",
"ln_fwd_2560.cu",
"ln_bwd_2560.cu",
"ln_fwd_3072.cu",
"ln_bwd_3072.cu",
"ln_fwd_4096.cu",
"ln_bwd_4096.cu",
"ln_fwd_5120.cu",
"ln_bwd_5120.cu",
"ln_fwd_6144.cu",
"ln_bwd_6144.cu",
], ],
extra_compile_args={ extra_compile_args={
"cxx": ["-O3"] + generator_flag, "cxx": ["-O3"] + generator_flag,
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
import torch import torch
from torch.nn import init from torch.nn import init
# from apex._autocast_utils import _cast_if_autocast_enabled
import dropout_layer_norm import dropout_layer_norm
...@@ -145,7 +144,7 @@ def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=No ...@@ -145,7 +144,7 @@ def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=No
class DropoutAddLayerNorm(torch.nn.Module): class DropoutAddLayerNorm(torch.nn.Module):
def __init__(self, hidden_size, prenorm=False, p=0.5, eps=1e-5, residual_in_fp32=False, def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
device=None, dtype=None): device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype} factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__() super().__init__()
......
...@@ -24,8 +24,7 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 ...@@ -24,8 +24,7 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
(torch.float32, torch.float32)] (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else [])) + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)]) # @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120]) @pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
# @pytest.mark.parametrize('hidden_size', [768])
def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype, def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
dropout_p, has_residual, has_rowscale): dropout_p, has_residual, has_rowscale):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16: if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
...@@ -148,7 +147,13 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh ...@@ -148,7 +147,13 @@ def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weigh
[(torch.float16, torch.float16), (torch.float16, torch.float32), [(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)] (torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else [])) + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120]) # @pytest.mark.parametrize('has_rowscale', [False])
# @pytest.mark.parametrize('has_residual', [True])
# @pytest.mark.parametrize('dropout_p', [0.0])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float32, torch.float32)])
# @pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
@pytest.mark.parametrize('hidden_size', [192, 256, 384, 768, 1024, 1280, 1536, 1600, 2048, 2560, 3000, 3072, 4096, 5120, 6144])
def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype, def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
dropout_p, has_residual, has_rowscale): dropout_p, has_residual, has_rowscale):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16: if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment