Commit 9765f725 authored by flyingdown's avatar flyingdown
Browse files

Develop

parent b5d7745d
#pragma once #pragma once
#include <unordered_map> #include <unordered_map>
#include <functional>
#if defined(__HIP_PLATFORM_HCC__)
#include "hip/hip_fp16.h"
#include "hip/hip_bfloat16.h"
#else
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
#endif
namespace layer_norm { namespace layer_norm {
...@@ -121,7 +127,11 @@ extern BwdRegistry BWD_FUNCS; ...@@ -121,7 +127,11 @@ extern BwdRegistry BWD_FUNCS;
using fp32 = float; using fp32 = float;
using fp16 = half; using fp16 = half;
#if defined(__HIP_PLATFORM_HCC__)
using bf16 = hip_bfloat16;
#else
using bf16 = nv_bfloat16; using bf16 = nv_bfloat16;
#endif
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
......
...@@ -25,8 +25,8 @@ namespace layer_norm { ...@@ -25,8 +25,8 @@ namespace layer_norm {
// Create registries and provide runtime versions of config hash functions. // Create registries and provide runtime versions of config hash functions.
FwdRegistry FWD_FUNCS; // FwdRegistry FWD_FUNCS;
BwdRegistry BWD_FUNCS; // BwdRegistry BWD_FUNCS;
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
using namespace layer_norm; using namespace layer_norm;
BwdRegistry layer_norm::BWD_FUNCS;
template< template<
typename weight_t, typename weight_t,
typename input_t, typename input_t,
...@@ -35,7 +37,7 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params ...@@ -35,7 +37,7 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
if( configure_params ) { if( configure_params ) {
int ctas_per_sm; int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( cudaError_t status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
launch_params.barrier_size = 0; launch_params.barrier_size = 0;
...@@ -52,7 +54,11 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params ...@@ -52,7 +54,11 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
} }
if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) { if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
#if defined(__HIP_PLATFORM_HCC__)
CHECK_CUDA(hipFuncSetAttribute((const void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
#else
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
#endif
} }
auto stream = launch_params.stream; auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col; auto ctas_per_col = launch_params.params.ctas_per_col;
...@@ -63,7 +69,11 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params ...@@ -63,7 +69,11 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA); dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params; void *params_ = (void *)&launch_params.params;
#if defined(__HIP_PLATFORM_HCC__)
hipLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
#else
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream); cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
#endif
} }
using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE, using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
...@@ -85,156 +95,156 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params ...@@ -85,156 +95,156 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); // REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); // REGISTER_BWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); // REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); // REGISTER_BWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); // REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); REGISTER_BWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); // REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); // REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); REGISTER_BWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); // REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); // REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); // REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); // REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); // REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); // REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); // REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4); REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4); // REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4);
REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); // REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); // REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); // REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); // REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); // REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4); REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4);
REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4); REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4);
REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4); REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4);
REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4); // REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4);
REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4); // REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4);
REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); // REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); // REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); // REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); // REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); // REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); // REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); // REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); // REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
using namespace layer_norm; using namespace layer_norm;
FwdRegistry layer_norm::FWD_FUNCS;
template< template<
typename weight_t, typename weight_t,
typename input_t, typename input_t,
...@@ -34,7 +36,7 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params ...@@ -34,7 +36,7 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
if( configure_params ) { if( configure_params ) {
int ctas_per_sm; int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( cudaError_t status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
launch_params.barrier_size = 0; launch_params.barrier_size = 0;
...@@ -51,7 +53,11 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params ...@@ -51,7 +53,11 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
} }
if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
#if defined(__HIP_PLATFORM_HCC__)
CHECK_CUDA(hipFuncSetAttribute((const void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
#else
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
#endif
} }
auto stream = launch_params.stream; auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col; auto ctas_per_col = launch_params.params.ctas_per_col;
...@@ -62,7 +68,11 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params ...@@ -62,7 +68,11 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA); dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params; void *params_ = (void *)&launch_params.params;
#if defined(__HIP_PLATFORM_HCC__)
hipLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
#else
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream); cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
#endif
} }
} }
...@@ -71,155 +81,155 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params ...@@ -71,155 +81,155 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); // REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16); // REGISTER_FWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); // REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16); // REGISTER_FWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16); // REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 4, 1, 16); // REGISTER_FWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); // REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 4, 1, 16); // REGISTER_FWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_FWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16); // REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16); // REGISTER_FWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16); // REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16); // REGISTER_FWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4); REGISTER_FWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4); REGISTER_FWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 4); REGISTER_FWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4); // REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4); // REGISTER_FWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); // REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16); // REGISTER_FWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16); // REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16); // REGISTER_FWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16); // REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16); // REGISTER_FWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); // REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16); // REGISTER_FWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16); REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16); // REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16); // REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16); // REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16); // REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4); REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4); REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4); REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4); // REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4); // REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4);
REGISTER_FWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 2, 1, 4, 16); // REGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 2, 1, 4, 8); // REGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8); REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8); REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8); // REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8); // REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16); // REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16); // REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16); REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16); // REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16); // REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16); // REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16); // REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16); REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16); REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16); // REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16); // REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4); REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8); REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4); REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8); // REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4); // REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4); REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4); REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4); REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4); // REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4); // REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16); REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16); REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16); REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16); // REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16); // REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16); REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16); REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16); REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16); // REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16); // REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16); REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16); REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16); REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16); // REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16); // REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16); REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16); REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16); REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16); // REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16);
REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16); // REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16);
#pragma once #pragma once
#if defined(__HIP_PLATFORM_HCC__)
#include "ln_utils.cuh"
#else
#include "ln.h" #include "ln.h"
#endif
namespace layer_norm { namespace layer_norm {
......
...@@ -2,8 +2,13 @@ ...@@ -2,8 +2,13 @@
#include <cassert> #include <cassert>
#include <cuda_bf16.h> #if defined(__HIP_PLATFORM_HCC__)
#include "hip/hip_fp16.h"
#include "hip/hip_bfloat16.h"
#else
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_bf16.h>
#endif
#include "ln.h" #include "ln.h"
...@@ -88,7 +93,11 @@ struct Sum { ...@@ -88,7 +93,11 @@ struct Sum {
template<typename T> template<typename T>
inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){ inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){
#if defined(__HIP_PLATFORM_HCC__)
return __shfl_xor(x, idx);
#else
return __shfl_xor_sync(uint32_t(-1), x, idx); return __shfl_xor_sync(uint32_t(-1), x, idx);
#endif
} }
template<> template<>
...@@ -98,7 +107,11 @@ inline __device__ float2 warp_shuffle_xor<float2>(const float2 & x, uint32_t idx ...@@ -98,7 +107,11 @@ inline __device__ float2 warp_shuffle_xor<float2>(const float2 & x, uint32_t idx
template<typename T> template<typename T>
inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){ inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){
#if defined(__HIP_PLATFORM_HCC__)
return __shfl_down(x, idx);
#else
return __shfl_down_sync(uint32_t(-1), x, idx); return __shfl_down_sync(uint32_t(-1), x, idx);
#endif
} }
template<> template<>
...@@ -188,10 +201,12 @@ struct TypeToVec2<half> { ...@@ -188,10 +201,12 @@ struct TypeToVec2<half> {
using Type = half2; using Type = half2;
}; };
#if 0
template<> template<>
struct TypeToVec2<nv_bfloat16> { struct TypeToVec2<nv_bfloat16> {
using Type = nv_bfloat162; using Type = nv_bfloat162;
}; };
#endif
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
...@@ -241,6 +256,37 @@ struct Converter<float2, half2>{ ...@@ -241,6 +256,37 @@ struct Converter<float2, half2>{
} }
}; };
#if defined(__HIP_PLATFORM_HCC__)
template<>
struct Converter<float, half>{
static inline __device__ half convert(const float &x) {
return __float2half(x);
}
};
template<>
struct Converter<half, float>{
static inline __device__ float convert(const half &x) {
return __half2float(x);
}
};
template<>
struct Converter<float, hip_bfloat16>{
static inline __device__ hip_bfloat16 convert(const float &x) {
return hip_bfloat16::round_to_bfloat16(x);
}
};
template<>
struct Converter<hip_bfloat16, float>{
static inline __device__ float convert(const hip_bfloat16 &x) {
return float(x);
}
};
#endif
#if 0
template<> template<>
struct Converter<float2, nv_bfloat162>{ struct Converter<float2, nv_bfloat162>{
static inline __device__ nv_bfloat162 convert(const float2 &x) { static inline __device__ nv_bfloat162 convert(const float2 &x) {
...@@ -258,6 +304,7 @@ struct Converter<float2, nv_bfloat162>{ ...@@ -258,6 +304,7 @@ struct Converter<float2, nv_bfloat162>{
#endif #endif
} }
}; };
#endif
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
...@@ -331,10 +378,18 @@ struct InterCTASync { ...@@ -331,10 +378,18 @@ struct InterCTASync {
} }
inline __device__ void spin_wait_(int *barrier, int step, int expected) { inline __device__ void spin_wait_(int *barrier, int step, int expected) {
#if defined(__HIP_PLATFORM_HCC__)
atomicAdd(barrier, step);
for( int found = -1; found != expected; ) {
// asm volatile("global_load_dword %0, %1, off;" : "=v"(found) : "v"(barrier));
found = atomicCAS(barrier, expected, expected);
}
#else
asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
for( int found = -1; found != expected; ) { for( int found = -1; found != expected; ) {
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
} }
#endif
} }
inline __device__ void sync(){ inline __device__ void sync(){
...@@ -553,8 +608,13 @@ inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_ac ...@@ -553,8 +608,13 @@ inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_ac
m2_a = m2_ab; m2_a = m2_ab;
} }
// Intra-warp broadcast (only lane 0 has valid stats). // Intra-warp broadcast (only lane 0 has valid stats).
#if defined(__HIP_PLATFORM_HCC__)
m_a = __shfl(m_a, 0);
m2_a = __shfl(m2_a, 0);
#else
m_a = __shfl_sync(uint32_t(-1), m_a, 0); m_a = __shfl_sync(uint32_t(-1), m_a, 0);
m2_a = __shfl_sync(uint32_t(-1), m2_a, 0); m2_a = __shfl_sync(uint32_t(-1), m2_a, 0);
#endif
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
......
...@@ -235,8 +235,8 @@ class TestFastLayerNorm(unittest.TestCase): ...@@ -235,8 +235,8 @@ class TestFastLayerNorm(unittest.TestCase):
self.assertAll(test_(256, 2, h, fp32, fp32)) self.assertAll(test_(256, 2, h, fp32, fp32))
self.assertAll(test_(256, 2, h, fp16, fp16)) self.assertAll(test_(256, 2, h, fp16, fp16))
self.assertAll(test_(256, 2, h, fp32, fp16)) self.assertAll(test_(256, 2, h, fp32, fp16))
self.assertAll(test_(256, 2, h, bf16, bf16)) # self.assertAll(test_(256, 2, h, bf16, bf16))
self.assertAll(test_(256, 2, h, fp32, bf16)) # self.assertAll(test_(256, 2, h, fp32, bf16))
def test_run_benchmark(self): def test_run_benchmark(self):
for (S, B, hidden_size, runs) in ( for (S, B, hidden_size, runs) in (
...@@ -255,7 +255,8 @@ class TestFastLayerNorm(unittest.TestCase): ...@@ -255,7 +255,8 @@ class TestFastLayerNorm(unittest.TestCase):
def test_compat_with_autocast(self): def test_compat_with_autocast(self):
autocast_dtypes = ( autocast_dtypes = (
(torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,) # (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
(torch.half,)
) )
input_shape = (512, 32, 768) input_shape = (512, 32, 768)
layer_norm = FastLayerNorm(input_shape[-1]).cuda() layer_norm = FastLayerNorm(input_shape[-1]).cuda()
......
...@@ -438,23 +438,46 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl. ...@@ -438,23 +438,46 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.
if "--fast_layer_norm" in sys.argv: if "--fast_layer_norm" in sys.argv:
sys.argv.remove("--fast_layer_norm") sys.argv.remove("--fast_layer_norm")
raise_if_cuda_home_none("--fast_layer_norm") # raise_if_cuda_home_none("--fast_layer_norm")
# Check, if CUDA11 is installed for compute capability 8.0 # Check, if CUDA11 is installed for compute capability 8.0
cc_flag = [] # cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) # _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11: # if int(bare_metal_major) >= 11:
cc_flag.append("-gencode") # cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80") # cc_flag.append("arch=compute_80,code=sm_80")
if CUDA_HOME is None and not IS_ROCM_PYTORCH: if CUDA_HOME is None and not IS_ROCM_PYTORCH:
raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
# Check, if CUDA11 is installed for compute capability 8.0 if CUDA_HOME is not None:
cc_flag = [] # Check, if CUDA11 is installed for compute capability 8.0
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) cc_flag = []
if int(bare_metal_major) >= 11: _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
cc_flag.append('-gencode') if int(bare_metal_major) >= 11:
cc_flag.append('arch=compute_80,code=sm_80') cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
else:
hipcc_args_fused_lamb = ['-O3'] + version_dependent_macros
ext_modules.append(
CUDAExtension(
name="fast_layer_norm",
sources=[
"apex/contrib/csrc/layer_norm/ln_api.cpp",
"apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu",
"apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu",
],
extra_compile_args={
"cxx": ["-O3"] + version_dependent_macros + generator_flag,
"nvcc": [
"-O3",
'-U__HIP_NO_HALF_OPERATORS__',
'-U__HIP_NO_HALF_CONVERSIONS__',
"-I./apex/contrib/csrc/layer_norm/",
] + version_dependent_macros + generator_flag,
},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/layer_norm")],
)
)
if "--fmha" in sys.argv: if "--fmha" in sys.argv:
sys.argv.remove("--fmha") sys.argv.remove("--fmha")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment