Commit 8c6609ae authored by Tri Dao's avatar Tri Dao
Browse files

[LayerNorm] Support all dimensions up to 6k (if divisible by 8)

parent 8a2ece89
This CUDA extension implements fused dropout + residual + LayerNorm, based on This CUDA extension implements fused dropout + residual + LayerNorm, building on
Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
We add dropout and residual, and make it work for both pre-norm and post-norm architecture. We add dropout and residual, and make it work for both pre-norm and post-norm architecture.
We also make it work for more hidden dimensions (all dimensions divisible by 8, up to 6144).
This only supports a limited set of dimensions, see `csrc/layer_norm/ln_fwd_cuda_kernel.cu`. If you want to use it for dimensions larger than 6k, please file an issue.
It has only been tested on A100s. This extension has only been tested on A100s.
```sh ```sh
cd csrc/layer_norm && pip install . cd csrc/layer_norm && pip install .
......
...@@ -64,6 +64,8 @@ struct ParamsBase { ...@@ -64,6 +64,8 @@ struct ParamsBase {
void *gamma; void *gamma;
void *rowscale; void *rowscale;
float inverse_cols;
float dropout_keep_p; float dropout_keep_p;
float dropout_scale; float dropout_scale;
......
...@@ -129,6 +129,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input: ...@@ -129,6 +129,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
TORCH_CHECK(gamma.sizes() == beta.sizes()); TORCH_CHECK(gamma.sizes() == beta.sizes());
TORCH_CHECK(hidden_size == cols); TORCH_CHECK(hidden_size == cols);
TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
TORCH_CHECK(epsilon >= 0.f); TORCH_CHECK(epsilon >= 0.f);
...@@ -156,8 +157,10 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input: ...@@ -156,8 +157,10 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>( auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator()); gen_, at::cuda::detail::getDefaultCUDAGenerator());
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
// Request the kernel launcher. // Request the kernel launcher.
auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size); auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
// Query the kernel-specific launch parameters. // Query the kernel-specific launch parameters.
launcher(launch_params, true); launcher(launch_params, true);
...@@ -178,6 +181,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input: ...@@ -178,6 +181,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
params.z = z.data_ptr(); params.z = z.data_ptr();
params.epsilon = epsilon; params.epsilon = epsilon;
params.dropout_scale = 1.f / (1.f - dropout_p); params.dropout_scale = 1.f / (1.f - dropout_p);
params.inverse_cols = 1.f / float(params.cols);
if (dropout_p > 0.f) { if (dropout_p > 0.f) {
// number of times random will be generated per thread, to offset philox counter in thc random // number of times random will be generated per thread, to offset philox counter in thc random
...@@ -263,6 +267,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd ...@@ -263,6 +267,8 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
} }
auto hidden_size = gamma.numel(); auto hidden_size = gamma.numel();
TORCH_CHECK(hidden_size == cols);
TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
TORCH_CHECK(mu.numel() == rows); TORCH_CHECK(mu.numel() == rows);
TORCH_CHECK(mu.sizes() == rsigma.sizes()); TORCH_CHECK(mu.sizes() == rsigma.sizes());
...@@ -285,7 +291,9 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd ...@@ -285,7 +291,9 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr; launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr; launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size); auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
launcher(launch_params, true, /*prenorm=*/false); launcher(launch_params, true, /*prenorm=*/false);
...@@ -308,6 +316,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd ...@@ -308,6 +316,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
params.dbeta_part = dbeta_part.data_ptr(); params.dbeta_part = dbeta_part.data_ptr();
params.dgamma_part = dgamma_part.data_ptr(); params.dgamma_part = dgamma_part.data_ptr();
params.dropout_scale = 1.f / (1.f - dropout_p); params.dropout_scale = 1.f / (1.f - dropout_p);
params.inverse_cols = 1.f / float(params.cols);
if( launch_params.barrier_size > 0 ) { if( launch_params.barrier_size > 0 ) {
// TODO Any way to avoid this? // TODO Any way to avoid this?
...@@ -385,6 +394,8 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, // ...@@ -385,6 +394,8 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, //
} }
auto hidden_size = gamma.numel(); auto hidden_size = gamma.numel();
TORCH_CHECK(hidden_size == cols);
TORCH_CHECK((hidden_size % 8 == 0) && (hidden_size <= 6144));
TORCH_CHECK(mu.numel() == rows); TORCH_CHECK(mu.numel() == rows);
TORCH_CHECK(mu.sizes() == rsigma.sizes()); TORCH_CHECK(mu.sizes() == rsigma.sizes());
...@@ -407,8 +418,9 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, // ...@@ -407,8 +418,9 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, //
launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr; launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr; launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
// TODO: how to set template param for launcher auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size); const int multiple = hidden_size <= 1536 ? 256 : (hidden_size <= 3072 ? 512 : 1024);
auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
launcher(launch_params, true, /*prenorm=*/true); launcher(launch_params, true, /*prenorm=*/true);
...@@ -432,6 +444,7 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, // ...@@ -432,6 +444,7 @@ std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, //
params.dbeta_part = dbeta_part.data_ptr(); params.dbeta_part = dbeta_part.data_ptr();
params.dgamma_part = dgamma_part.data_ptr(); params.dgamma_part = dgamma_part.data_ptr();
params.dropout_scale = 1.f / (1.f - dropout_p); params.dropout_scale = 1.f / (1.f - dropout_p);
params.inverse_cols = 1.f / float(params.cols);
if( launch_params.barrier_size > 0 ) { if( launch_params.barrier_size > 0 ) {
// TODO Any way to avoid this? // TODO Any way to avoid this?
......
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
\ No newline at end of file
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
\ No newline at end of file
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
\ No newline at end of file
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER( 512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
\ No newline at end of file
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
\ No newline at end of file
#include "ln_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
#pragma once #pragma once
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "static_switch.h"
namespace layer_norm { namespace layer_norm {
template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Has_rowscale> template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Is_even_cols>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_bwd_kernel(layer_norm::BwdParams params) { void ln_bwd_kernel(layer_norm::BwdParams params) {
...@@ -59,14 +64,18 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { ...@@ -59,14 +64,18 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
Sum<reduce_t> sum; Sum<reduce_t> sum;
constexpr float rn = 1.f / float(COLS); const index_t num_valid_ldgs =
((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + Ktraits::VEC_COLS_PER_LDG) / Ktraits::VEC_COLS_PER_LDG;
Wvec gamma[LDGS]; Wvec gamma[LDGS];
index_t idx = c; index_t idx = c;
#pragma unroll #pragma unroll
for( int it = 0; it < LDGS; it++ ) { for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
gamma[it].load_from(params.gamma, idx); gamma[it].load_from(params.gamma, idx);
idx += Ktraits::VEC_COLS_PER_LDG; idx += Ktraits::VEC_COLS_PER_LDG;
} }
}
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
// last blocks with syncthreads! // last blocks with syncthreads!
// grid stride over rows // grid stride over rows
...@@ -74,16 +83,18 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { ...@@ -74,16 +83,18 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row]; const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row]; const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
const compute_t rowscale_val = Has_rowscale ? compute_t(static_cast<const input_t *>(params.rowscale)[row]) : 1.0f; const compute_t rowscale_val =
params.rowscale == nullptr ? 1.0f : compute_t(static_cast<const input_t *>(params.rowscale)[row]);
Mvec dmask[LDGS]; Mvec dmask[LDGS];
Rvec dx[LDGS]; Rvec dx[LDGS];
compute_t dy[LDGS * NUM_ELTS]; compute_t dy[LDGS * NUM_ELTS];
compute_t y[LDGS * NUM_ELTS]; compute_t y[LDGS * NUM_ELTS];
compute_t mdy_local = 0.f; compute_t mdy_local = 0.f;
compute_t mdyy_local = 0.f; compute_t mdyy_local = 0.f;
index_t idx = row * Ktraits::VEC_COLS + c; index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
#pragma unroll #pragma unroll
for( int it = 0; it < LDGS; it++ ) { for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
Rvec x; Rvec x;
Ovec dz; Ovec dz;
dz.load_from(params.dz, idx); dz.load_from(params.dz, idx);
...@@ -95,8 +106,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { ...@@ -95,8 +106,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
for( int jt = 0; jt < NUM_ELTS; jt++ ) { for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_tmp = x.data.elt[jt]; compute_t x_tmp = x.data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp - mu_r); compute_t y_tmp = rs_r * (x_tmp - mu_r);
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]); compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]) * compute_t(dz.data.elt[jt]);
dy_tmp *= compute_t(dz.data.elt[jt]);
compute_t dz_tmp = dz.data.elt[jt]; compute_t dz_tmp = dz.data.elt[jt];
mdy_local += dy_tmp; mdy_local += dy_tmp;
...@@ -109,14 +119,16 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { ...@@ -109,14 +119,16 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
dz_sum[it].data.elt[jt] += dz_tmp; dz_sum[it].data.elt[jt] += dz_tmp;
} }
} }
}
reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn; mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * params.inverse_cols;
mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn; mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * params.inverse_cols;
idx = row * Ktraits::VEC_COLS + c; idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
#pragma unroll #pragma unroll
for( int it = 0; it < LDGS; it++ ) { for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
Ivec dx0; Ivec dx0;
Rvec dx1; Rvec dx1;
#pragma unroll #pragma unroll
...@@ -126,7 +138,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { ...@@ -126,7 +138,7 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local)); compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
compute_t dx_tmp_res = Prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp; compute_t dx_tmp_res = Prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; } if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
compute_t dx0_tmp_res = Has_rowscale ? dx_tmp_res * rowscale_val : dx_tmp_res; compute_t dx0_tmp_res = dx_tmp_res * rowscale_val;
if (Is_dropout) { if (Is_dropout) {
dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * params.dropout_scale : 0.f; dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * params.dropout_scale : 0.f;
} else { } else {
...@@ -137,17 +149,20 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { ...@@ -137,17 +149,20 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
dx0.store_to(params.dx0, idx); dx0.store_to(params.dx0, idx);
idx += Ktraits::VEC_COLS_PER_LDG; idx += Ktraits::VEC_COLS_PER_LDG;
} }
}
} // end: grid stride loop } // end: grid stride loop
if( WARPS_M == 1 ) { if( WARPS_M == 1 ) {
idx = r * Ktraits::VEC_COLS + c; idx = r * params.cols / Ktraits::ELTS_PER_LDG + c;
#pragma unroll #pragma unroll
for( int it = 0; it < LDGS; it++ ) { for( int it = 0; it < LDGS; it++ ) {
if (Is_even_cols || (it < num_valid_ldgs)) {
dz_sum[it].store_to(params.dbeta_part, idx); dz_sum[it].store_to(params.dbeta_part, idx);
dzy_sum[it].store_to(params.dgamma_part, idx); dzy_sum[it].store_to(params.dgamma_part, idx);
idx += Ktraits::VEC_COLS_PER_LDG; idx += Ktraits::VEC_COLS_PER_LDG;
} }
}
} else { } else {
static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA."); static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA.");
// Finalize reduction of part dgamma and dbeta for this CTA // Finalize reduction of part dgamma and dbeta for this CTA
...@@ -188,21 +203,23 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { ...@@ -188,21 +203,23 @@ void ln_bwd_kernel(layer_norm::BwdParams params) {
} }
} }
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx; const index_t num_valid_writes
= (params.cols - 1 - tidx + Ktraits::THREADS_PER_CTA) / Ktraits::THREADS_PER_CTA;
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * params.cols + tidx;
compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * params.cols + tidx;
for( int jt = 0; jt < NUM_RES; jt++ ) { for( int jt = 0; jt < NUM_RES; jt++ ) {
if (Is_even_cols || (jt < num_valid_writes)) {
*dgamma_part = cta_dzy_sum[jt]; *dgamma_part = cta_dzy_sum[jt];
dgamma_part += Ktraits::THREADS_PER_CTA; dgamma_part += Ktraits::THREADS_PER_CTA;
}
compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * COLS + tidx;
for( int jt = 0; jt < NUM_RES; jt++ ) {
*dbeta_part = cta_dz_sum[jt]; *dbeta_part = cta_dz_sum[jt];
dbeta_part += Ktraits::THREADS_PER_CTA; dbeta_part += Ktraits::THREADS_PER_CTA;
} }
} }
}
} }
template<typename Kernel_traits> template<typename Kernel_traits, bool Is_even_cols>
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
void ln_bwd_finalize_kernel(BwdParams params) void ln_bwd_finalize_kernel(BwdParams params)
{ {
...@@ -236,8 +253,10 @@ void ln_bwd_finalize_kernel(BwdParams params) ...@@ -236,8 +253,10 @@ void ln_bwd_finalize_kernel(BwdParams params)
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local; Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
memset(&dgamma_local, 0, sizeof(dgamma_local)); memset(&dgamma_local, 0, sizeof(dgamma_local));
memset(&dbeta_local, 0, sizeof(dbeta_local)); memset(&dbeta_local, 0, sizeof(dbeta_local));
if (Is_even_cols || col < params.cols) {
for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) { for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
index_t idx = row * Kernel_traits::COLS + col; // index_t idx = row * Kernel_traits::COLS + col;
index_t idx = row * params.cols + col;
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part; Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
dbeta_part.load_from(params.dbeta_part, idx); dbeta_part.load_from(params.dbeta_part, idx);
...@@ -248,7 +267,7 @@ void ln_bwd_finalize_kernel(BwdParams params) ...@@ -248,7 +267,7 @@ void ln_bwd_finalize_kernel(BwdParams params)
dbeta_local.data.elt[it] += dbeta_part.data.elt[it]; dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
} }
} }
}
void * smem_gamma = smem_; void * smem_gamma = smem_;
void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
...@@ -305,6 +324,7 @@ void ln_bwd_finalize_kernel(BwdParams params) ...@@ -305,6 +324,7 @@ void ln_bwd_finalize_kernel(BwdParams params)
__syncthreads(); __syncthreads();
// Pack and store: 2-wide stores with half the threads. // Pack and store: 2-wide stores with half the threads.
if (Is_even_cols || col_out * 2 < params.cols) {
if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) { if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
using src_t = typename TypeToVec2<compute_t>::Type; using src_t = typename TypeToVec2<compute_t>::Type;
...@@ -324,5 +344,95 @@ void ln_bwd_finalize_kernel(BwdParams params) ...@@ -324,5 +344,95 @@ void ln_bwd_finalize_kernel(BwdParams params)
} }
} }
}
} }
} // namespace layer_norm } // namespace layer_norm
using namespace layer_norm;
template<
typename weight_t,
typename input_t,
typename residual_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int CTAS_PER_ROW,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL
>
void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params, const bool prenorm){
using Kernel_traits = Kernel_traits<weight_t,
input_t,
residual_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
CTAS_PER_ROW,
WARPS_M,
WARPS_N,
BYTES_PER_LDG_MAIN
>;
bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
bool has_residual = launch_params.params.dx1 != nullptr;
bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
BOOL_SWITCH(prenorm, PrenormConst, [&] {
BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
BOOL_SWITCH(has_residual, HasResidualConst, [&] {
BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, IsEvenColsConst>;
if( configure_params ) {
int ctas_per_sm;
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::reduce_t)
* 2;
}
return;
}
if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
}
using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
weight_t,
input_t,
residual_t,
output_t,
compute_t,
index_t,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f, IsEvenColsConst>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
});
});
});
});
}
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "ln_bwd_kernels.cuh"
#include "static_switch.h"
using namespace layer_norm;
template<
typename weight_t,
typename input_t,
typename residual_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int CTAS_PER_ROW,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL
>
void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params, const bool prenorm){
using Kernel_traits = Kernel_traits<weight_t,
input_t,
residual_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
CTAS_PER_ROW,
WARPS_M,
WARPS_N,
BYTES_PER_LDG_MAIN
>;
bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
bool has_residual = launch_params.params.dx1 != nullptr;
bool has_rowscale = launch_params.params.rowscale != nullptr;
BOOL_SWITCH(prenorm, PrenormConst, [&] {
BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
BOOL_SWITCH(has_residual, HasResidualConst, [&] {
BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] {
auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasRowscaleConst>;
if( configure_params ) {
int ctas_per_sm;
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::reduce_t)
* 2;
}
return;
}
if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
}
using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
weight_t,
input_t,
residual_t,
output_t,
compute_t,
index_t,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
});
});
});
});
}
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1600, fp32, fp32, fp32, fp32, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp16, fp32, fp32, fp32, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp32, fp16, fp32, fp16, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp16, fp16, fp32, fp16, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp32, fp16, fp16, fp16, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp32, bf16, fp32, bf16, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, bf16, bf16, fp32, bf16, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp32, bf16, bf16, bf16, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp16, fp16, fp16, fp16, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, bf16, bf16, bf16, bf16, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// TD [2022-04-22] Disable most of these to speed up compile time
// REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4);
// REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4);
// REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4);
// REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4);
// REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4);
// REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
#include "ln_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment