Unverified Commit 9416519d authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Apply formatting (#929)



* Apply formatting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
...@@ -13,167 +13,147 @@ using namespace transformer_engine::rmsnorm; ...@@ -13,167 +13,147 @@ using namespace transformer_engine::rmsnorm;
template <typename weight_t, typename input_t, typename output_t, typename compute_t, template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N, typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL> int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL>
void launch_tuned_(LaunchParams<BwdParams> &launch_params, const bool configure_params) { // NOLINT(*) void launch_tuned_(LaunchParams<BwdParams> &launch_params,
using Kernel_traits = const bool configure_params) { // NOLINT(*)
rmsnorm::Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE, using Kernel_traits =
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>; rmsnorm::Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
auto kernel = &rmsnorm_bwd_tuned_kernel<Kernel_traits>; CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
auto kernel = &rmsnorm_bwd_tuned_kernel<Kernel_traits>;
if (configure_params) {
int ctas_per_sm; if (configure_params) {
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( int ctas_per_sm;
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
launch_params.params.ctas_per_row = CTAS_PER_ROW; &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
launch_params.params.ctas_per_col = launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; launch_params.params.ctas_per_col =
launch_params.barrier_size = 0; launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
launch_params.workspace_bytes = 0; launch_params.barrier_size = 0;
if (Kernel_traits::CTAS_PER_ROW > 1) { launch_params.workspace_bytes = 0;
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; if (Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.workspace_bytes = launch_params.params.ctas_per_col * launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
Kernel_traits::WARPS_M * Kernel_traits::CTAS_PER_ROW * launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M *
sizeof(typename Kernel_traits::reduce_t) * 2; Kernel_traits::CTAS_PER_ROW *
} sizeof(typename Kernel_traits::reduce_t) * 2;
return;
} }
return;
if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { }
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES)); if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
} NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
auto stream = launch_params.stream; Kernel_traits::SMEM_BYTES));
auto ctas_per_col = launch_params.params.ctas_per_col; }
auto ctas_per_row = launch_params.params.ctas_per_row; auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if (ctas_per_row == 1) { auto ctas_per_row = launch_params.params.ctas_per_row;
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(
launch_params.params); if (ctas_per_row == 1) {
} else { kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), Kernel_traits::SMEM_BYTES,
stream);
}
using Kernel_traits_f =
Kernel_traits_finalize<HIDDEN_SIZE, weight_t, input_t, output_t, compute_t, index_t,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
auto kernel_f = &rmsnorm::rmsnorm_bwd_finalize_tuned_kernel<Kernel_traits_f>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(
launch_params.params); launch_params.params);
} else {
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), Kernel_traits::SMEM_BYTES,
stream);
}
using Kernel_traits_f =
Kernel_traits_finalize<HIDDEN_SIZE, weight_t, input_t, output_t, compute_t, index_t,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
auto kernel_f = &rmsnorm::rmsnorm_bwd_finalize_tuned_kernel<Kernel_traits_f>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(
launch_params.params);
} }
template <typename weight_t, typename input_t, typename output_t, typename compute_t, template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN, typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL> int BYTES_PER_LDG_FINAL>
void launch_general_(LaunchParams<BwdParams> &launch_params, const bool configure_params) { // NOLINT(*) void launch_general_(LaunchParams<BwdParams> &launch_params,
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; const bool configure_params) { // NOLINT(*)
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
// Instantiate kernel
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, // Instantiate kernel
HIDDEN_SIZE, 1, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>; using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
auto kernel = &rmsnorm_bwd_general_kernel<Kernel_traits>; 1, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
auto kernel = &rmsnorm_bwd_general_kernel<Kernel_traits>;
// Configure kernel params
const int rows = launch_params.params.rows; // Configure kernel params
const int cols = launch_params.params.cols; const int rows = launch_params.params.rows;
int ctas_per_col = launch_params.params.ctas_per_col; const int cols = launch_params.params.cols;
int ctas_per_row = launch_params.params.ctas_per_row; int ctas_per_col = launch_params.params.ctas_per_col;
if (configure_params) { int ctas_per_row = launch_params.params.ctas_per_row;
int ctas_per_sm; if (configure_params) {
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, int ctas_per_sm;
Kernel_traits::THREADS_PER_CTA, 0); cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel,
const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; Kernel_traits::THREADS_PER_CTA, 0);
ctas_per_row = ceil_div(cols, HIDDEN_SIZE); const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
launch_params.params.ctas_per_row = ctas_per_row; ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
launch_params.params.ctas_per_col = ctas_per_col; launch_params.params.ctas_per_row = ctas_per_row;
launch_params.params.ctas_per_col = ctas_per_col;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0; launch_params.barrier_size = 0;
if (launch_params.params.ctas_per_row > 1) { launch_params.workspace_bytes = 0;
launch_params.barrier_size = 2 * ctas_per_col; if (launch_params.params.ctas_per_row > 1) {
launch_params.workspace_bytes = (ctas_per_col * WARPS_M * ctas_per_row * launch_params.barrier_size = 2 * ctas_per_col;
sizeof(typename Kernel_traits::reduce_t) * 2); launch_params.workspace_bytes =
} (ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2);
return;
}
// Launch kernel
auto stream = launch_params.stream;
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
if (ctas_per_row == 1) {
kernel<<<grid, block, 0, stream>>>(launch_params.params);
} else {
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream);
} }
return;
// Launch finalization kernel }
constexpr uint32_t WARPS_M_FINAL = 4;
constexpr uint32_t WARPS_N_FINAL = 1; // Launch kernel
constexpr uint32_t ELTS_N_PER_CTA_FINAL = auto stream = launch_params.stream;
(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); dim3 grid(ctas_per_row * ctas_per_col);
auto kernel_final = dim3 block(Kernel_traits::THREADS_PER_CTA);
&rmsnorm_bwd_finalize_general_kernel<weight_t, compute_t, WARPS_M_FINAL, WARPS_N_FINAL, if (ctas_per_row == 1) {
BYTES_PER_LDG_FINAL, Kernel_traits::THREADS_PER_WARP>; kernel<<<grid, block, 0, stream>>>(launch_params.params);
dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); } else {
dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); void *params_ = reinterpret_cast<void *>(&launch_params.params);
kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params); cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream);
}
// Launch finalization kernel
constexpr uint32_t WARPS_M_FINAL = 4;
constexpr uint32_t WARPS_N_FINAL = 1;
constexpr uint32_t ELTS_N_PER_CTA_FINAL =
(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t));
auto kernel_final =
&rmsnorm_bwd_finalize_general_kernel<weight_t, compute_t, WARPS_M_FINAL, WARPS_N_FINAL,
BYTES_PER_LDG_FINAL, Kernel_traits::THREADS_PER_WARP>;
dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL);
dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1);
kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_BWD_TUNED_LAUNCHER( \ #define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \
HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, \ WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
BYTES_PER_LDG_FINALIZE) \ void rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
void rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ LaunchParams<BwdParams> &launch_params, const bool configure_params) { \
LaunchParams<BwdParams> \ launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, \
&launch_params, \ WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE>(launch_params, \
const bool configure_params) { \ configure_params); \
launch_tuned_<WTYPE, \ } \
ITYPE, \ static BwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
OTYPE, \ reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
CTYPE, \ rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
uint32_t, \
HIDDEN_SIZE, \ #define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \
CTAS_PER_ROW, \ BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
WARPS_M, \ void rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
WARPS_N, \ LaunchParams<BwdParams> &launch_params, const bool configure_params) { \
BYTES_PER_LDG, \ launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, WARPS_M, WARPS_N, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \ BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \ } \
static BwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \ static BwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
#define REGISTER_BWD_GENERAL_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE) \
void rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<BwdParams> \
&launch_params, \
const bool configure_params) { \
launch_general_<WTYPE, \
ITYPE, \
OTYPE, \
CTYPE, \
uint32_t, \
HIDDEN_SIZE, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
......
...@@ -13,122 +13,120 @@ using namespace transformer_engine::rmsnorm; ...@@ -13,122 +13,120 @@ using namespace transformer_engine::rmsnorm;
template <typename weight_t, typename input_t, typename output_t, typename compute_t, template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N, typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG> int BYTES_PER_LDG>
void launch_tuned_(LaunchParams<FwdParams> &launch_params, const bool configure_params) { // NOLINT(*) void launch_tuned_(LaunchParams<FwdParams> &launch_params,
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, const bool configure_params) { // NOLINT(*)
HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>; using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
auto kernel = &rmsnorm_fwd_tuned_kernel<Kernel_traits>; CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>;
auto kernel = &rmsnorm_fwd_tuned_kernel<Kernel_traits>;
if (configure_params) {
int ctas_per_sm; if (configure_params) {
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( int ctas_per_sm;
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
launch_params.params.ctas_per_row = CTAS_PER_ROW; &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);
launch_params.params.ctas_per_col = launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; launch_params.params.ctas_per_col =
launch_params.barrier_size = 0; launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
launch_params.workspace_bytes = 0; launch_params.barrier_size = 0;
if (Kernel_traits::CTAS_PER_ROW > 1) { launch_params.workspace_bytes = 0;
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; if (Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.workspace_bytes = launch_params.params.ctas_per_col * launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
Kernel_traits::WARPS_M * Kernel_traits::CTAS_PER_ROW * launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M *
sizeof(typename Kernel_traits::Stats::stats_t) * 2; Kernel_traits::CTAS_PER_ROW *
} sizeof(typename Kernel_traits::Stats::stats_t) * 2;
return;
}
if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES_FWD));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row;
if (ctas_per_row == 1) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD,
stream>>>(launch_params.params);
} else {
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, // NOLINT(*)
Kernel_traits::SMEM_BYTES_FWD, stream);
} }
return;
}
if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES_FWD));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row;
if (ctas_per_row == 1) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(
launch_params.params);
} else {
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, // NOLINT(*)
Kernel_traits::SMEM_BYTES_FWD, stream);
}
} }
template <typename weight_t, typename input_t, typename output_t, typename compute_t, template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG> typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG>
void launch_general_(LaunchParams<FwdParams> &launch_params, const bool configure_params) { // NOLINT(*) void launch_general_(LaunchParams<FwdParams> &launch_params,
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, const bool configure_params) { // NOLINT(*)
HIDDEN_SIZE, 1, WARPS_M, WARPS_N, BYTES_PER_LDG>; using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
auto kernel = &rmsnorm_fwd_general_kernel<Kernel_traits>; 1, WARPS_M, WARPS_N, BYTES_PER_LDG>;
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; auto kernel = &rmsnorm_fwd_general_kernel<Kernel_traits>;
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
// Configure kernel params
const int rows = launch_params.params.rows; // Configure kernel params
const int cols = launch_params.params.cols; const int rows = launch_params.params.rows;
int ctas_per_col = launch_params.params.ctas_per_col; const int cols = launch_params.params.cols;
int ctas_per_row = launch_params.params.ctas_per_row; int ctas_per_col = launch_params.params.ctas_per_col;
if (configure_params) { int ctas_per_row = launch_params.params.ctas_per_row;
int ctas_per_sm; if (configure_params) {
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( int ctas_per_sm;
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0);
ctas_per_row = ceil_div(cols, HIDDEN_SIZE); const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
launch_params.params.ctas_per_row = ctas_per_row; ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
launch_params.params.ctas_per_col = ctas_per_col; launch_params.params.ctas_per_row = ctas_per_row;
launch_params.params.ctas_per_col = ctas_per_col;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0; launch_params.barrier_size = 0;
if (launch_params.params.ctas_per_row > 1) { launch_params.workspace_bytes = 0;
launch_params.barrier_size = 2 * ctas_per_col; if (launch_params.params.ctas_per_row > 1) {
launch_params.workspace_bytes = launch_params.barrier_size = 2 * ctas_per_col;
(ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2); launch_params.workspace_bytes =
} (ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2);
return;
}
// Launch kernel
auto stream = launch_params.stream;
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
if (ctas_per_row == 1) {
kernel<<<grid, block, 0, stream>>>(launch_params.params);
} else {
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream);
} }
return;
}
// Launch kernel
auto stream = launch_params.stream;
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
if (ctas_per_row == 1) {
kernel<<<grid, block, 0, stream>>>(launch_params.params);
} else {
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream);
}
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ #define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ WARPS_M, WARPS_N, BYTES_PER_LDG) \
void rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ void rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<FwdParams> &launch_params, \ LaunchParams<FwdParams> &launch_params, const bool configure_params) { \
const bool configure_params) { \ launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, \
launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, \ WARPS_N, BYTES_PER_LDG>(launch_params, configure_params); \
WARPS_M, WARPS_N, BYTES_PER_LDG>( \ } \
launch_params, configure_params); \ static FwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
} \ reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
static FwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \ rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) #define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \
BYTES_PER_LDG) \
#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ void rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
WARPS_M, WARPS_N, BYTES_PER_LDG) \ LaunchParams<FwdParams> &launch_params, const bool configure_params) { \
void rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, WARPS_M, WARPS_N, \
LaunchParams<FwdParams> &launch_params, \ BYTES_PER_LDG>(launch_params, configure_params); \
const bool configure_params) { \ } \
launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, \ static FwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
WARPS_M, WARPS_N, BYTES_PER_LDG>( \ reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
launch_params, configure_params); \ rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
} \
static FwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <cfloat> #include <cfloat>
#include <cstdio> #include <cstdio>
#include "../utils.cuh" #include "../utils.cuh"
namespace transformer_engine { namespace transformer_engine {
...@@ -18,261 +19,260 @@ using namespace transformer_engine; ...@@ -18,261 +19,260 @@ using namespace transformer_engine;
template <typename Ktraits> template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_kernel( __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_kernel(
FwdParams params) { FwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_N = Ktraits::WARPS_N }; enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M }; enum { WARPS_M = Ktraits::WARPS_M };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS }; enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS }; enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using output_t = typename Ktraits::output_t; using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t; using index_t = typename Ktraits::index_t;
using compute_t = typename Ktraits::compute_t; using compute_t = typename Ktraits::compute_t;
using Ivec = typename Ktraits::Ivec; using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec; using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec; using Wvec = typename Ktraits::Wvec;
using Stats = typename Ktraits::Stats; using Stats = typename Ktraits::Stats;
using stats_t = typename Stats::stats_t; using stats_t = typename Stats::stats_t;
extern __shared__ char smem_[]; extern __shared__ char smem_[];
const index_t tidx = threadIdx.x; const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW; const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW; const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP; const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP; const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N; const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N; const index_t warp_n = warp % WARPS_N;
const index_t r = bidm * ROWS_PER_CTA + warp_m; const index_t r = bidm * ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_); Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);
compute_t *rs_ptr = static_cast<compute_t *>(params.rs); compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
Wvec gamma[LDGS]; Wvec gamma[LDGS];
index_t idx = c; index_t idx = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
gamma[it].load_from(params.gamma, idx);
idx += VEC_COLS_PER_LDG;
}
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
compute_t scale;
if (params.fp8_out) {
scale = *reinterpret_cast<compute_t *>(params.scale);
}
compute_t amax = 0;
for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
Ivec x[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll #pragma unroll
for (int it = 0; it < LDGS; it++) { for (int it = 0; it < LDGS; it++) {
gamma[it].load_from(params.gamma, idx); x[it].load_from(params.x, idx);
idx += VEC_COLS_PER_LDG; #pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t x_ij = compute_t(x[it].data.elt[jt]);
xf[it * NUM_ELTS + jt] = x_ij;
}
idx += VEC_COLS_PER_LDG;
} }
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS); stats_t s = stats.compute(xf, rn);
compute_t mu = Get<0>::of<stats_t, compute_t>(s);
compute_t m2 = Get<1>::of<stats_t, compute_t>(s);
// reciprocal of root mean square
// we could optimize here to count mean square directly
compute_t rs = rsqrtf(rn * m2 + mu * mu + params.epsilon);
compute_t scale; if (bidn == 0 && warp_n == 0 && lane == 0) {
if (params.fp8_out) { rs_ptr[row] = rs;
scale = *reinterpret_cast<compute_t *>(params.scale);
} }
compute_t amax = 0;
for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) { Ovec z[LDGS];
Ivec x[LDGS]; idx = row * Ktraits::VEC_COLS + c;
index_t idx = row * Ktraits::VEC_COLS + c;
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll #pragma unroll
for (int it = 0; it < LDGS; it++) { for (int it = 0; it < LDGS; it++) {
x[it].load_from(params.x, idx);
#pragma unroll #pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) { for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t x_ij = compute_t(x[it].data.elt[jt]); compute_t y_ij = rs * (xf[it * NUM_ELTS + jt]);
xf[it * NUM_ELTS + jt] = x_ij; compute_t g_ij = gamma[it].data.elt[jt];
} if (params.zero_centered_gamma) {
idx += VEC_COLS_PER_LDG; g_ij += 1;
} }
compute_t temp_output = g_ij * y_ij;
stats_t s = stats.compute(xf, rn); if (params.fp8_out) {
__builtin_assume(amax >= 0);
compute_t mu = Get<0>::of<stats_t, compute_t>(s); amax = fmaxf(amax, fabsf(temp_output));
compute_t m2 = Get<1>::of<stats_t, compute_t>(s); temp_output = temp_output * scale;
// reciprocal of root mean square
// we could optimize here to count mean square directly
compute_t rs = rsqrtf(rn * m2 + mu * mu + params.epsilon);
if (bidn == 0 && warp_n == 0 && lane == 0) {
rs_ptr[row] = rs;
} }
Ovec z[LDGS]; z[it].data.elt[jt] = output_t(temp_output);
idx = row * Ktraits::VEC_COLS + c; }
#pragma unroll z[it].store_to(params.z, idx);
for (int it = 0; it < LDGS; it++) { idx += VEC_COLS_PER_LDG;
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t y_ij = rs * (xf[it * NUM_ELTS + jt]);
compute_t g_ij = gamma[it].data.elt[jt];
if (params.zero_centered_gamma) {
g_ij += 1;
}
compute_t temp_output = g_ij * y_ij;
if (params.fp8_out) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(temp_output));
temp_output = temp_output * scale;
}
z[it].data.elt[jt] = output_t(temp_output);
}
z[it].store_to(params.z, idx);
idx += VEC_COLS_PER_LDG;
}
} }
if (params.fp8_out) { }
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp); if (params.fp8_out) {
if (threadIdx.x == 0 && threadIdx.y == 0) { amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
static_assert(std::is_same<compute_t, float>::value); if (threadIdx.x == 0 && threadIdx.y == 0) {
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax); static_assert(std::is_same<compute_t, float>::value);
} atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
} }
}
} }
template <typename Ktraits> template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_kernel( __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_kernel(
FwdParams params) { FwdParams params) {
enum { LDGS = Ktraits::LDGS }; enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS }; enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { WARPS_M = Ktraits::WARPS_M }; enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N }; enum { WARPS_N = Ktraits::WARPS_N };
using input_t = typename Ktraits::input_t; using input_t = typename Ktraits::input_t;
using weight_t = typename Ktraits::weight_t; using weight_t = typename Ktraits::weight_t;
using output_t = typename Ktraits::output_t; using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t; using index_t = typename Ktraits::index_t;
using compute_t = typename Ktraits::compute_t; using compute_t = typename Ktraits::compute_t;
using Ivec = typename Ktraits::Ivec; using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec; using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec; using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec; using Cvec = typename Ktraits::Cvec;
const index_t tidx = threadIdx.x; const index_t tidx = threadIdx.x;
const index_t lane = tidx % THREADS_PER_WARP; const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP; const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N; const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N; const index_t warp_n = warp % WARPS_N;
const index_t bdimm = WARPS_M; const index_t bdimm = WARPS_M;
const index_t bdimn = WARPS_N * THREADS_PER_WARP; const index_t bdimn = WARPS_N * THREADS_PER_WARP;
const index_t bidm = blockIdx.x / params.ctas_per_row; const index_t bidm = blockIdx.x / params.ctas_per_row;
const index_t bidn = blockIdx.x % params.ctas_per_row; const index_t bidn = blockIdx.x % params.ctas_per_row;
const index_t gdimm = bdimm * params.ctas_per_col; const index_t gdimm = bdimm * params.ctas_per_col;
const index_t gdimn = bdimn * params.ctas_per_row; const index_t gdimn = bdimn * params.ctas_per_row;
const index_t gidm = bidm * bdimm + warp_m; const index_t gidm = bidm * bdimm + warp_m;
const index_t gidn = const index_t gidn = (bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP +
(bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP + lane); // Order threads by warp x cta x lane
lane); // Order threads by warp x cta x lane
// Objects for stats reductions
// Objects for stats reductions using Reducer = DynamicReducer<compute_t, WARPS_M, WARPS_N>;
using Reducer = DynamicReducer<compute_t, WARPS_M, WARPS_N>; constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1;
constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1; __shared__ char smem_[SMEM_BYTES];
__shared__ char smem_[SMEM_BYTES]; Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_);
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_); Sum<compute_t> sum;
Sum<compute_t> sum; const compute_t rn = 1.f / static_cast<compute_t>(params.cols);
const compute_t rn = 1.f / static_cast<compute_t>(params.cols);
// Load weights
// Load weights Cvec gamma[LDGS];
Cvec gamma[LDGS];
#pragma unroll #pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Wvec gamma_in;
gamma_in.load_from_elts(params.gamma, col, params.cols - col);
gamma_in.to(gamma[it]);
}
// fp8 factors
compute_t scale;
if (params.fp8_out) {
scale = *reinterpret_cast<compute_t *>(params.scale);
}
compute_t amax = 0;
for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
const int row = cta_row + warp_m;
// Load input
Cvec x[LDGS];
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) { it++, col += gdimn * NUM_ELTS) {
Wvec gamma_in; Ivec x_in;
gamma_in.load_from_elts(params.gamma, col, params.cols - col); x_in.load_from_elts(params.x, row * params.cols + col, params.cols - col);
gamma_in.to(gamma[it]); x_in.to(x[it]);
}
// fp8 factors
compute_t scale;
if (params.fp8_out) {
scale = *reinterpret_cast<compute_t *>(params.scale);
} }
compute_t amax = 0;
for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) { // Compute variance
const int row = cta_row + warp_m; compute_t sqsigma = 0.f;
// Load input
Cvec x[LDGS];
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Ivec x_in;
x_in.load_from_elts(params.x, row * params.cols + col, params.cols - col);
x_in.to(x[it]);
}
// Compute variance
compute_t sqsigma = 0.f;
#pragma unroll #pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) { it++, col += gdimn * NUM_ELTS) {
#pragma unroll #pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) { for (int jt = 0; jt < NUM_ELTS; jt++) {
if (col + jt < params.cols) { if (col + jt < params.cols) {
compute_t diff = x[it].data.elt[jt]; compute_t diff = x[it].data.elt[jt];
sqsigma += diff * diff; sqsigma += diff * diff;
}
}
} }
sqsigma = reducer.allreduce(sqsigma, sum) * rn; }
compute_t rs = rsqrtf(sqsigma + params.epsilon); }
sqsigma = reducer.allreduce(sqsigma, sum) * rn;
compute_t rs = rsqrtf(sqsigma + params.epsilon);
// Write statistics // Write statistics
if (gidn == 0 && row < params.rows) { if (gidn == 0 && row < params.rows) {
compute_t *rs_ptr = static_cast<compute_t *>(params.rs); compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
rs_ptr[row] = rs; rs_ptr[row] = rs;
} }
// Compute output // Compute output
#pragma unroll #pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) { it++, col += gdimn * NUM_ELTS) {
// Compute output values // Compute output values
Cvec z; Cvec z;
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t y_ij = rs * (x[it].data.elt[jt]);
compute_t g_ij = gamma[it].data.elt[jt];
if (params.zero_centered_gamma) {
g_ij += 1;
}
z.data.elt[jt] = g_ij * y_ij;
}
// Apply fp8 factors
if (params.fp8_out) {
#pragma unroll #pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) { for (int jt = 0; jt < NUM_ELTS; jt++) {
if (col + jt < params.cols) { compute_t y_ij = rs * (x[it].data.elt[jt]);
compute_t z_ij = z.data.elt[jt]; compute_t g_ij = gamma[it].data.elt[jt];
__builtin_assume(amax >= 0); if (params.zero_centered_gamma) {
amax = fmaxf(amax, fabsf(z_ij)); g_ij += 1;
z.data.elt[jt] = z_ij * scale;
}
}
}
// Store output
Ovec z_out;
z.to(z_out);
z_out.store_to_elts(params.z, row * params.cols + col, params.cols - col);
} }
} z.data.elt[jt] = g_ij * y_ij;
}
// Finalize fp8 factors // Apply fp8 factors
if (params.fp8_out) { if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp); #pragma unroll
if (threadIdx.x == 0) { for (int jt = 0; jt < NUM_ELTS; jt++) {
static_assert(std::is_same<compute_t, float>::value); if (col + jt < params.cols) {
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax); compute_t z_ij = z.data.elt[jt];
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(z_ij));
z.data.elt[jt] = z_ij * scale;
}
} }
}
// Store output
Ovec z_out;
z.to(z_out);
z_out.store_to_elts(params.z, row * params.cols + col, params.cols - col);
}
}
// Finalize fp8 factors
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
} }
}
} }
} // namespace rmsnorm } // namespace rmsnorm
......
...@@ -5,80 +5,65 @@ ...@@ -5,80 +5,65 @@
************************************************************************/ ************************************************************************/
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
#include "common.h" #include "common.h"
namespace transformer_engine { namespace transformer_engine {
size_t typeToSize(const transformer_engine::DType type) { size_t typeToSize(const transformer_engine::DType type) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
return TypeInfo<T>::size; return TypeInfo<T>::size;); // NOLINT(*)
); // NOLINT(*)
} }
bool is_fp8_dtype(const transformer_engine::DType t) { bool is_fp8_dtype(const transformer_engine::DType t) {
return t == transformer_engine::DType::kFloat8E4M3 || return t == transformer_engine::DType::kFloat8E4M3 || t == transformer_engine::DType::kFloat8E5M2;
t == transformer_engine::DType::kFloat8E5M2;
} }
void CheckInputTensor(const Tensor &t, const std::string &name) { void CheckInputTensor(const Tensor &t, const std::string &name) {
const DType type = t.data.dtype; const DType type = t.data.dtype;
if (is_fp8_dtype(type)) { if (is_fp8_dtype(type)) {
// FP8 input needs to have scale_inv // FP8 input needs to have scale_inv
NVTE_CHECK(t.scale_inv.dptr != nullptr, NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 input " + name + " must have inverse of scale.");
"FP8 input " + name + " must have inverse of scale.");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32); NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{ 1 }); NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{1});
} else { } else {
NVTE_CHECK(t.scale.dptr == nullptr, NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input " + name + ".");
"Scale is not supported for non-FP8 input " + name + "."); NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input " + name + ".");
NVTE_CHECK(t.amax.dptr == nullptr,
"Amax is not supported for non-FP8 input " + name + ".");
NVTE_CHECK(t.scale_inv.dptr == nullptr, NVTE_CHECK(t.scale_inv.dptr == nullptr,
"Scale_inv is not supported for non-FP8 input " + name + "."); "Scale_inv is not supported for non-FP8 input " + name + ".");
} }
NVTE_CHECK(t.data.dptr != nullptr, NVTE_CHECK(t.data.dptr != nullptr, "Input " + name + " is not allocated!");
"Input " + name + " is not allocated!");
} }
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) { void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
const DType type = t.data.dtype; const DType type = t.data.dtype;
if (is_fp8_dtype(type)) { if (is_fp8_dtype(type)) {
// FP8 output needs to have scale, amax and scale_inv // FP8 output needs to have scale, amax and scale_inv
NVTE_CHECK(t.amax.dptr != nullptr, NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output " + name + " must have amax tensor.");
"FP8 output " + name + " must have amax tensor.");
NVTE_CHECK(t.amax.dtype == DType::kFloat32); NVTE_CHECK(t.amax.dtype == DType::kFloat32);
NVTE_CHECK(t.amax.shape == std::vector<size_t>{ 1 }); NVTE_CHECK(t.amax.shape == std::vector<size_t>{1});
NVTE_CHECK(t.scale_inv.dptr != nullptr, NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 output " + name + " must have scale.");
"FP8 output " + name + " must have scale.");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32); NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{ 1 }); NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{1});
NVTE_CHECK(t.scale.dptr != nullptr, NVTE_CHECK(t.scale.dptr != nullptr, "FP8 output " + name + " must have inverse of scale.");
"FP8 output " + name + " must have inverse of scale.");
NVTE_CHECK(t.scale.dtype == DType::kFloat32); NVTE_CHECK(t.scale.dtype == DType::kFloat32);
NVTE_CHECK(t.scale.shape == std::vector<size_t>{ 1 }); NVTE_CHECK(t.scale.shape == std::vector<size_t>{1});
} else { } else {
NVTE_CHECK(t.scale.dptr == nullptr, NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output " + name + ".");
"Scale is not supported for non-FP8 output " + name + "."); NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output " + name + ".");
NVTE_CHECK(t.amax.dptr == nullptr,
"Amax is not supported for non-FP8 output " + name + ".");
NVTE_CHECK(t.scale_inv.dptr == nullptr, NVTE_CHECK(t.scale_inv.dptr == nullptr,
"Scale_inv is not supported for non-FP8 output " + name + "."); "Scale_inv is not supported for non-FP8 output " + name + ".");
} }
if (!allow_empty) { if (!allow_empty) {
NVTE_CHECK(t.data.dptr != nullptr, NVTE_CHECK(t.data.dptr != nullptr, "Output " + name + " is not allocated!");
"Output " + name + " is not allocated!");
} }
} }
} // namespace transformer_engine } // namespace transformer_engine
NVTETensor nvte_create_tensor(void *dptr, NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType dtype, float *amax,
const NVTEShape shape, float *scale, float *scale_inv) {
const NVTEDType dtype,
float *amax,
float *scale,
float *scale_inv) {
transformer_engine::Tensor *ret = new transformer_engine::Tensor; transformer_engine::Tensor *ret = new transformer_engine::Tensor;
ret->data.dptr = dptr; ret->data.dptr = dptr;
ret->data.shape = std::vector<size_t>(shape.data, shape.data + shape.ndim); ret->data.shape = std::vector<size_t>(shape.data, shape.data + shape.ndim);
...@@ -97,11 +82,11 @@ void nvte_destroy_tensor(NVTETensor tensor) { ...@@ -97,11 +82,11 @@ void nvte_destroy_tensor(NVTETensor tensor) {
NVTEDType nvte_tensor_type(const NVTETensor tensor) { NVTEDType nvte_tensor_type(const NVTETensor tensor) {
return static_cast<NVTEDType>( return static_cast<NVTEDType>(
reinterpret_cast<const transformer_engine::Tensor*>(tensor)->data.dtype); reinterpret_cast<const transformer_engine::Tensor *>(tensor)->data.dtype);
} }
NVTEShape nvte_tensor_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTEShape ret; NVTEShape ret;
ret.data = t.data.shape.data(); ret.data = t.data.shape.data();
ret.ndim = t.data.shape.size(); ret.ndim = t.data.shape.size();
...@@ -109,40 +94,40 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { ...@@ -109,40 +94,40 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
} }
void *nvte_tensor_data(const NVTETensor tensor) { void *nvte_tensor_data(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return t.data.dptr; return t.data.dptr;
} }
float *nvte_tensor_amax(const NVTETensor tensor) { float *nvte_tensor_amax(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32, NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32,
"Tensor's amax must have Float32 type!"); "Tensor's amax must have Float32 type!");
return reinterpret_cast<float*>(t.amax.dptr); return reinterpret_cast<float *>(t.amax.dptr);
} }
float *nvte_tensor_scale(const NVTETensor tensor) { float *nvte_tensor_scale(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTE_CHECK(t.scale.dtype == transformer_engine::DType::kFloat32, NVTE_CHECK(t.scale.dtype == transformer_engine::DType::kFloat32,
"Tensor's scale must have Float32 type!"); "Tensor's scale must have Float32 type!");
return reinterpret_cast<float*>(t.scale.dptr); return reinterpret_cast<float *>(t.scale.dptr);
} }
float *nvte_tensor_scale_inv(const NVTETensor tensor) { float *nvte_tensor_scale_inv(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTE_CHECK(t.scale_inv.dtype == transformer_engine::DType::kFloat32, NVTE_CHECK(t.scale_inv.dtype == transformer_engine::DType::kFloat32,
"Tensor's inverse of scale must have Float32 type!"); "Tensor's inverse of scale must have Float32 type!");
return reinterpret_cast<float*>(t.scale_inv.dptr); return reinterpret_cast<float *>(t.scale_inv.dptr);
} }
void nvte_tensor_pack_create(NVTETensorPack* pack) { void nvte_tensor_pack_create(NVTETensorPack *pack) {
for (int i = 0; i < pack->MAX_SIZE; i++) { for (int i = 0; i < pack->MAX_SIZE; i++) {
pack->tensors[i] = reinterpret_cast<NVTETensor>(new transformer_engine::Tensor); pack->tensors[i] = reinterpret_cast<NVTETensor>(new transformer_engine::Tensor);
} }
} }
void nvte_tensor_pack_destroy(NVTETensorPack* pack) { void nvte_tensor_pack_destroy(NVTETensorPack *pack) {
for (int i = 0; i < pack->MAX_SIZE; i++) { for (int i = 0; i < pack->MAX_SIZE; i++) {
auto *t = reinterpret_cast<transformer_engine::Tensor*>(pack->tensors[i]); auto *t = reinterpret_cast<transformer_engine::Tensor *>(pack->tensors[i]);
delete t; delete t;
} }
} }
...@@ -4,13 +4,12 @@ ...@@ -4,13 +4,12 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <cuda_runtime.h>
#include <transformer_engine/cast_transpose_noop.h> #include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/transpose.h> #include <transformer_engine/transpose.h>
#include <algorithm> #include <algorithm>
#include <cuda_runtime.h>
#include "../common.h" #include "../common.h"
#include "../util/rtc.h" #include "../util/rtc.h"
#include "../util/string.h" #include "../util/string.h"
...@@ -49,26 +48,18 @@ struct KernelConfig { ...@@ -49,26 +48,18 @@ struct KernelConfig {
/* Elements per L1 cache store to transposed output */ /* Elements per L1 cache store to transposed output */
size_t elements_per_store_t = 0; size_t elements_per_store_t = 0;
KernelConfig(size_t row_length, KernelConfig(size_t row_length, size_t num_rows, size_t itype_size, size_t otype_size,
size_t num_rows, size_t load_size_, size_t store_size_)
size_t itype_size, : load_size{load_size_}, store_size{store_size_} {
size_t otype_size,
size_t load_size_,
size_t store_size_)
: load_size{load_size_}
, store_size{store_size_} {
// Check that tiles are correctly aligned // Check that tiles are correctly aligned
constexpr size_t cache_line_size = 128; constexpr size_t cache_line_size = 128;
if (load_size % itype_size != 0 if (load_size % itype_size != 0 || store_size % otype_size != 0 ||
|| store_size % otype_size != 0 cache_line_size % itype_size != 0 || cache_line_size % otype_size != 0) {
|| cache_line_size % itype_size != 0
|| cache_line_size % otype_size != 0) {
return; return;
} }
const size_t row_tile_elements = load_size * THREADS_PER_WARP / itype_size; const size_t row_tile_elements = load_size * THREADS_PER_WARP / itype_size;
const size_t col_tile_elements = store_size * THREADS_PER_WARP / otype_size; const size_t col_tile_elements = store_size * THREADS_PER_WARP / otype_size;
valid = (row_length % row_tile_elements == 0 valid = (row_length % row_tile_elements == 0 && num_rows % col_tile_elements == 0);
&& num_rows % col_tile_elements == 0);
if (!valid) { if (!valid) {
return; return;
} }
...@@ -80,12 +71,9 @@ struct KernelConfig { ...@@ -80,12 +71,9 @@ struct KernelConfig {
constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs
active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm), active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm),
static_cast<size_t>(cuda::sm_count())); static_cast<size_t>(cuda::sm_count()));
elements_per_load = (std::min(cache_line_size, row_tile_elements * itype_size) elements_per_load = (std::min(cache_line_size, row_tile_elements * itype_size) / itype_size);
/ itype_size); elements_per_store_c = (std::min(cache_line_size, row_tile_elements * otype_size) / otype_size);
elements_per_store_c = (std::min(cache_line_size, row_tile_elements * otype_size) elements_per_store_t = (std::min(cache_line_size, col_tile_elements * otype_size) / otype_size);
/ otype_size);
elements_per_store_t = (std::min(cache_line_size, col_tile_elements * otype_size)
/ otype_size);
} }
/* Compare by estimated cost */ /* Compare by estimated cost */
...@@ -104,8 +92,8 @@ struct KernelConfig { ...@@ -104,8 +92,8 @@ struct KernelConfig {
const auto &st2 = other.elements_per_store_t; const auto &st2 = other.elements_per_store_t;
const auto &p2 = other.active_sm_count; const auto &p2 = other.active_sm_count;
const auto scale = l1 * sc1 * st1 * p1 * l2 * sc2 * st2 * p2; const auto scale = l1 * sc1 * st1 * p1 * l2 * sc2 * st2 * p2;
const auto cost1 = (scale/l1 + scale/sc1 + scale/st1) / p1; const auto cost1 = (scale / l1 + scale / sc1 + scale / st1) / p1;
const auto cost2 = (scale/l2 + scale/sc2 + scale/st2) / p2; const auto cost2 = (scale / l2 + scale / sc2 + scale / st2) / p2;
return cost1 < cost2; return cost1 < cost2;
} else { } else {
return this->valid && !other.valid; return this->valid && !other.valid;
...@@ -114,16 +102,14 @@ struct KernelConfig { ...@@ -114,16 +102,14 @@ struct KernelConfig {
}; };
template <size_t load_size, size_t store_size, typename IType, typename OType> template <size_t load_size, size_t store_size, typename IType, typename OType>
__global__ void __global__ void __launch_bounds__(block_size)
__launch_bounds__(block_size) cast_transpose_general_kernel(const IType *__restrict__ const input,
cast_transpose_general_kernel(const IType * __restrict__ const input, const CType *__restrict__ const noop,
const CType * __restrict__ const noop, OType *__restrict__ const output_c,
OType * __restrict__ const output_c, OType *__restrict__ const output_t,
OType * __restrict__ const output_t, const CType *__restrict__ const scale_ptr,
const CType * __restrict__ const scale_ptr, CType *__restrict__ const amax_ptr, const size_t row_length,
CType * __restrict__ const amax_ptr, const size_t num_rows) {
const size_t row_length,
const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return; if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes // Vectorized load/store sizes
...@@ -165,16 +151,16 @@ cast_transpose_general_kernel(const IType * __restrict__ const input, ...@@ -165,16 +151,16 @@ cast_transpose_general_kernel(const IType * __restrict__ const input,
// Note: Each thread loads num_iterations subtiles, computes amax, // Note: Each thread loads num_iterations subtiles, computes amax,
// casts type, and transposes in registers. // casts type, and transposes in registers.
OVecT local_output_t[nvec_in][num_iterations]; OVecT local_output_t[nvec_in][num_iterations];
#pragma unroll #pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) { for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy; const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx; const size_t j1 = tidx;
#pragma unroll #pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) { for (size_t i2 = 0; i2 < nvec_out; ++i2) {
const size_t row = tile_row + i1 * nvec_out + i2; const size_t row = tile_row + i1 * nvec_out + i2;
const size_t col = tile_col + j1 * nvec_in; const size_t col = tile_col + j1 * nvec_in;
if (row < num_rows) { if (row < num_rows) {
#pragma unroll #pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) { for (size_t j2 = 0; j2 < nvec_in; ++j2) {
if (col + j2 < row_length) { if (col + j2 < row_length) {
const CType in = input[row * row_length + col + j2]; const CType in = input[row * row_length + col + j2];
...@@ -190,24 +176,24 @@ cast_transpose_general_kernel(const IType * __restrict__ const input, ...@@ -190,24 +176,24 @@ cast_transpose_general_kernel(const IType * __restrict__ const input,
} }
// Copy transposed output from registers to global memory // Copy transposed output from registers to global memory
__shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP+1]; __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#pragma unroll #pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) { for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll #pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) { for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy; const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx; const size_t j1 = tidx;
shared_output_t[j1][i1] = local_output_t[j2][iter]; shared_output_t[j1][i1] = local_output_t[j2][iter];
} }
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) { for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidx; const size_t i1 = tidx;
const size_t j1 = tidy + iter * bdimy; const size_t j1 = tidy + iter * bdimy;
const size_t row = tile_row + i1 * nvec_out; const size_t row = tile_row + i1 * nvec_out;
const size_t col = tile_col + j1 * nvec_in + j2; const size_t col = tile_col + j1 * nvec_in + j2;
if (col < row_length) { if (col < row_length) {
#pragma unroll #pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) { for (size_t i2 = 0; i2 < nvec_out; ++i2) {
if (row + i2 < num_rows) { if (row + i2 < num_rows) {
output_t[col * num_rows + row + i2] = shared_output_t[j1][i1].data.elt[i2]; output_t[col * num_rows + row + i2] = shared_output_t[j1][i1].data.elt[i2];
...@@ -229,18 +215,15 @@ cast_transpose_general_kernel(const IType * __restrict__ const input, ...@@ -229,18 +215,15 @@ cast_transpose_general_kernel(const IType * __restrict__ const input,
} // namespace } // namespace
void cast_transpose(const Tensor &input, void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output_,
const Tensor &noop, Tensor *transposed_output_, cudaStream_t stream) {
Tensor *cast_output_,
Tensor *transposed_output_,
cudaStream_t stream) {
Tensor &cast_output = *cast_output_; Tensor &cast_output = *cast_output_;
Tensor &transposed_output = *transposed_output_; Tensor &transposed_output = *transposed_output_;
// Check no-op flag // Check no-op flag
if (noop.data.dptr != nullptr) { if (noop.data.dptr != nullptr) {
size_t numel = 1; size_t numel = 1;
for (const auto& dim : noop.data.shape) { for (const auto &dim : noop.data.shape) {
numel *= dim; numel *= dim;
} }
NVTE_CHECK(numel == 1, "Expected 1 element, but found ", numel, "."); NVTE_CHECK(numel == 1, "Expected 1 element, but found ", numel, ".");
...@@ -254,16 +237,14 @@ void cast_transpose(const Tensor &input, ...@@ -254,16 +237,14 @@ void cast_transpose(const Tensor &input,
CheckOutputTensor(transposed_output, "transposed_output"); CheckOutputTensor(transposed_output, "transposed_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output.data.shape.size() == 2, "Cast output must have 2 dimensions."); NVTE_CHECK(cast_output.data.shape.size() == 2, "Cast output must have 2 dimensions.");
NVTE_CHECK(transposed_output.data.shape.size() == 2, NVTE_CHECK(transposed_output.data.shape.size() == 2, "Transposed output must have 2 dimensions.");
"Transposed output must have 2 dimensions.");
const size_t row_length = input.data.shape[1]; const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0]; const size_t num_rows = input.data.shape[0];
NVTE_CHECK(cast_output.data.shape[0] == num_rows, "Wrong dimension of cast output."); NVTE_CHECK(cast_output.data.shape[0] == num_rows, "Wrong dimension of cast output.");
NVTE_CHECK(cast_output.data.shape[1] == row_length, "Wrong dimension of cast output."); NVTE_CHECK(cast_output.data.shape[1] == row_length, "Wrong dimension of cast output.");
NVTE_CHECK(transposed_output.data.shape[0] == row_length, NVTE_CHECK(transposed_output.data.shape[0] == row_length,
"Wrong dimension of transposed output."); "Wrong dimension of transposed output.");
NVTE_CHECK(transposed_output.data.shape[1] == num_rows, NVTE_CHECK(transposed_output.data.shape[1] == num_rows, "Wrong dimension of transposed output.");
"Wrong dimension of transposed output.");
// Check tensor pointers // Check tensor pointers
NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated."); NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated.");
...@@ -276,118 +257,111 @@ void cast_transpose(const Tensor &input, ...@@ -276,118 +257,111 @@ void cast_transpose(const Tensor &input,
NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr, NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr,
"Cast and transposed outputs need to share scale tensor."); "Cast and transposed outputs need to share scale tensor.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output.data.dtype, OutputType, input.data.dtype, InputType,
constexpr const char *itype_name = TypeInfo<InputType>::name; TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
constexpr const char *otype_name = TypeInfo<OutputType>::name; cast_output.data.dtype, OutputType,
constexpr size_t itype_size = sizeof(InputType); constexpr const char *itype_name = TypeInfo<InputType>::name;
constexpr size_t otype_size = sizeof(OutputType); constexpr const char *otype_name = TypeInfo<OutputType>::name;
constexpr size_t itype_size = sizeof(InputType);
// Choose between runtime-compiled or statically-compiled kernel constexpr size_t otype_size = sizeof(OutputType);
const bool aligned = (row_length % THREADS_PER_WARP == 0
&& num_rows % THREADS_PER_WARP == 0); // Choose between runtime-compiled or statically-compiled kernel
if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel const bool aligned =
// Pick kernel config (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0);
std::vector<KernelConfig> kernel_configs; if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel
kernel_configs.reserve(16); // Pick kernel config
auto add_config = [&](size_t load_size, size_t store_size) { std::vector<KernelConfig> kernel_configs;
kernel_configs.emplace_back(row_length, num_rows, kernel_configs.reserve(16);
itype_size, otype_size, auto add_config = [&](size_t load_size, size_t store_size) {
load_size, store_size); kernel_configs.emplace_back(row_length, num_rows, itype_size, otype_size, load_size,
}; store_size);
add_config(8, 8); };
add_config(4, 8); add_config(8, 4); add_config(8, 8);
add_config(4, 4); add_config(4, 8);
add_config(2, 8); add_config(8, 2); add_config(8, 4);
add_config(2, 4); add_config(4, 2); add_config(4, 4);
add_config(2, 2); add_config(2, 8);
add_config(1, 8); add_config(8, 1); add_config(8, 2);
add_config(1, 4); add_config(4, 1); add_config(2, 4);
add_config(1, 2); add_config(2, 1); add_config(4, 2);
add_config(1, 1); add_config(2, 2);
const auto &kernel_config = *std::min_element(kernel_configs.begin(), add_config(1, 8);
kernel_configs.end()); add_config(8, 1);
NVTE_CHECK(kernel_config.valid, "invalid kernel config"); add_config(1, 4);
const size_t load_size = kernel_config.load_size; add_config(4, 1);
const size_t store_size = kernel_config.store_size; add_config(1, 2);
const size_t num_blocks = kernel_config.num_blocks; add_config(2, 1);
add_config(1, 1);
// Compile NVRTC kernel if needed and launch const auto &kernel_config =
auto& rtc_manager = rtc::KernelManager::instance(); *std::min_element(kernel_configs.begin(), kernel_configs.end());
const std::string kernel_label = concat_strings("cast_transpose" NVTE_CHECK(kernel_config.valid, "invalid kernel config");
",itype=", itype_name, const size_t load_size = kernel_config.load_size;
",otype=", otype_name, const size_t store_size = kernel_config.store_size;
",load_size=", load_size, const size_t num_blocks = kernel_config.num_blocks;
",store_size=", store_size);
if (!rtc_manager.is_compiled(kernel_label)) { // Compile NVRTC kernel if needed and launch
std::string code = string_code_transpose_rtc_cast_transpose_cu; auto &rtc_manager = rtc::KernelManager::instance();
code = regex_replace(code, "__ITYPE__", itype_name); const std::string kernel_label = concat_strings(
code = regex_replace(code, "__OTYPE__", otype_name); "cast_transpose"
code = regex_replace(code, "__LOAD_SIZE__", load_size); ",itype=",
code = regex_replace(code, "__STORE_SIZE__", store_size); itype_name, ",otype=", otype_name, ",load_size=", load_size,
code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile); ",store_size=", store_size);
code = regex_replace(code, "__BLOCK_SIZE__", block_size); if (!rtc_manager.is_compiled(kernel_label)) {
rtc_manager.compile(kernel_label, std::string code = string_code_transpose_rtc_cast_transpose_cu;
"cast_transpose_optimized_kernel", code = regex_replace(code, "__ITYPE__", itype_name);
code, code = regex_replace(code, "__OTYPE__", otype_name);
"transformer_engine/common/transpose/rtc/cast_transpose.cu"); code = regex_replace(code, "__LOAD_SIZE__", load_size);
} code = regex_replace(code, "__STORE_SIZE__", store_size);
rtc_manager.launch(kernel_label, code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile);
num_blocks, block_size, 0, stream, code = regex_replace(code, "__BLOCK_SIZE__", block_size);
static_cast<const InputType *>(input.data.dptr), rtc_manager.compile(kernel_label, "cast_transpose_optimized_kernel", code,
reinterpret_cast<const CType *>(noop.data.dptr), "transformer_engine/common/transpose/rtc/cast_transpose.cu");
static_cast<OutputType*>(cast_output.data.dptr), }
static_cast<OutputType*>(transposed_output.data.dptr), rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream,
static_cast<const CType*>(cast_output.scale.dptr), static_cast<const InputType *>(input.data.dptr),
static_cast<CType*>(cast_output.amax.dptr), reinterpret_cast<const CType *>(noop.data.dptr),
row_length, num_rows); static_cast<OutputType *>(cast_output.data.dptr),
} else { // Statically-compiled general kernel static_cast<OutputType *>(transposed_output.data.dptr),
constexpr size_t load_size = 4; static_cast<const CType *>(cast_output.scale.dptr),
constexpr size_t store_size = 4; static_cast<CType *>(cast_output.amax.dptr), row_length, num_rows);
constexpr size_t row_tile_size = load_size / itype_size * THREADS_PER_WARP; } else { // Statically-compiled general kernel
constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP; constexpr size_t load_size = 4;
const int num_blocks = (DIVUP(row_length, row_tile_size) constexpr size_t store_size = 4;
* DIVUP(num_rows, col_tile_size)); constexpr size_t row_tile_size = load_size / itype_size * THREADS_PER_WARP;
cast_transpose_general_kernel<load_size, store_size, InputType, OutputType> constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP;
<<<num_blocks, block_size, 0, stream>>>( const int num_blocks =
static_cast<const InputType *>(input.data.dptr), (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size));
reinterpret_cast<const CType *>(noop.data.dptr), cast_transpose_general_kernel<load_size, store_size, InputType, OutputType>
static_cast<OutputType *>(cast_output.data.dptr), <<<num_blocks, block_size, 0, stream>>>(
static_cast<OutputType *>(transposed_output.data.dptr), static_cast<const InputType *>(input.data.dptr),
static_cast<const CType *>(cast_output.scale.dptr), reinterpret_cast<const CType *>(noop.data.dptr),
static_cast<CType *>(cast_output.amax.dptr), static_cast<OutputType *>(cast_output.data.dptr),
row_length, num_rows); static_cast<OutputType *>(transposed_output.data.dptr),
} static_cast<const CType *>(cast_output.scale.dptr),
); // NOLINT(*) static_cast<CType *>(cast_output.amax.dptr), row_length, num_rows);
); // NOLINT(*) }); // NOLINT(*)
); // NOLINT(*)
} }
} // namespace transformer_engine } // namespace transformer_engine
void nvte_cast_transpose(const NVTETensor input, void nvte_cast_transpose(const NVTETensor input, NVTETensor cast_output,
NVTETensor cast_output, NVTETensor transposed_output, cudaStream_t stream) {
NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose); NVTE_API_CALL(nvte_cast_transpose);
using namespace transformer_engine; using namespace transformer_engine;
auto noop = Tensor(); auto noop = Tensor();
cast_transpose(*reinterpret_cast<const Tensor*>(input), cast_transpose(*reinterpret_cast<const Tensor *>(input), noop,
noop, reinterpret_cast<Tensor *>(cast_output),
reinterpret_cast<Tensor*>(cast_output), reinterpret_cast<Tensor *>(transposed_output), stream);
reinterpret_cast<Tensor*>(transposed_output),
stream);
} }
void nvte_cast_transpose_with_noop(const NVTETensor input, void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop,
const NVTETensor noop, NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_with_noop); NVTE_API_CALL(nvte_cast_transpose_with_noop);
using namespace transformer_engine; using namespace transformer_engine;
cast_transpose(*reinterpret_cast<const Tensor*>(input), cast_transpose(*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(noop),
*reinterpret_cast<const Tensor*>(noop), reinterpret_cast<Tensor *>(cast_output),
reinterpret_cast<Tensor*>(cast_output), reinterpret_cast<Tensor *>(transposed_output), stream);
reinterpret_cast<Tensor*>(transposed_output),
stream);
} }
...@@ -4,16 +4,18 @@ ...@@ -4,16 +4,18 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <transformer_engine/transpose.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <transformer_engine/transpose.h>
#include <cfloat> #include <cfloat>
#include <iostream> #include <iostream>
#include <type_traits> #include <type_traits>
#include "../utils.cuh"
#include "../util/rtc.h"
#include "../util/string.h"
#include "../common.h" #include "../common.h"
#include "../util/math.h" #include "../util/math.h"
#include "../util/rtc.h"
#include "../util/string.h"
#include "../utils.cuh"
namespace transformer_engine { namespace transformer_engine {
...@@ -26,7 +28,7 @@ namespace { ...@@ -26,7 +28,7 @@ namespace {
constexpr size_t n_warps_per_tile = 8; constexpr size_t n_warps_per_tile = 8;
constexpr size_t desired_load_size = 8; constexpr size_t desired_load_size = 8;
constexpr size_t desired_store_size = 8; constexpr size_t desired_store_size = 8;
constexpr size_t desired_load_size_dact = 4; // dAct fusion kernels use more registers constexpr size_t desired_load_size_dact = 4; // dAct fusion kernels use more registers
constexpr size_t desired_store_size_dact = 4; constexpr size_t desired_store_size_dact = 4;
constexpr size_t threads_per_warp = static_cast<size_t>(THREADS_PER_WARP); constexpr size_t threads_per_warp = static_cast<size_t>(THREADS_PER_WARP);
...@@ -38,443 +40,411 @@ static_assert(cast_transpose_num_threads <= max_threads_per_block); ...@@ -38,443 +40,411 @@ static_assert(cast_transpose_num_threads <= max_threads_per_block);
/* Performance heuristics for optimized kernel parameters */ /* Performance heuristics for optimized kernel parameters */
struct KernelConfig { struct KernelConfig {
size_t load_size = 0; // Vector load size size_t load_size = 0; // Vector load size
size_t store_size = 0; // Vector store size to transposed output size_t store_size = 0; // Vector store size to transposed output
bool valid = false; // Whether config is valid bool valid = false; // Whether config is valid
bool is_dact = false; // Whether dact is used bool is_dact = false; // Whether dact is used
size_t num_blocks = 0; // Number of CUDA blocks size_t num_blocks = 0; // Number of CUDA blocks
size_t active_sm_count = 0; // Number of active SMs size_t active_sm_count = 0; // Number of active SMs
size_t elements_per_load = 0; // Elements per L1 cache load size_t elements_per_load = 0; // Elements per L1 cache load
size_t elements_per_load_dact = 0; // Elements per L1 cache load dact size_t elements_per_load_dact = 0; // Elements per L1 cache load dact
size_t elements_per_store_c = 0; // Elements per L1 cache store to cast output size_t elements_per_store_c = 0; // Elements per L1 cache store to cast output
size_t elements_per_store_t = 0; // Elements per L1 cache store to transposed output size_t elements_per_store_t = 0; // Elements per L1 cache store to transposed output
KernelConfig(size_t row_length, KernelConfig(size_t row_length, size_t num_rows, size_t itype_size, size_t itype2_size,
size_t num_rows, size_t otype_size, size_t load_size_, size_t store_size_, bool is_dact_)
size_t itype_size, : load_size{load_size_}, store_size{store_size_}, is_dact{is_dact_} {
size_t itype2_size, if (is_dact) {
size_t otype_size, if (load_size > desired_load_size_dact || store_size > desired_store_size_dact) {
size_t load_size_, return;
size_t store_size_, }
bool is_dact_) }
: load_size{load_size_}
, store_size{store_size_}
, is_dact{is_dact_} {
if (is_dact) {
if (load_size > desired_load_size_dact || store_size > desired_store_size_dact) {
return;
}
}
// Check that tiles are correctly aligned
constexpr size_t cache_line_size = 128;
if (load_size % itype_size != 0
|| store_size % otype_size != 0
|| cache_line_size % itype_size != 0
|| cache_line_size % otype_size != 0) {
return;
}
/* row_tile_elements */
const size_t tile_size_x = (load_size * THREADS_PER_WARP) / itype_size;
/* col_tile_elements */
const size_t tile_size_y = (store_size * THREADS_PER_WARP) / otype_size;
const size_t num_tiles_x = row_length / tile_size_x;
const size_t num_tiles_y = num_rows / tile_size_y;
valid = (row_length % tile_size_x == 0 && num_rows % tile_size_y == 0);
if (!valid) {
return;
}
// Number of CUDA blocks // Check that tiles are correctly aligned
num_blocks = num_tiles_x * num_tiles_y; constexpr size_t cache_line_size = 128;
if (load_size % itype_size != 0 || store_size % otype_size != 0 ||
// Parameters for performance model cache_line_size % itype_size != 0 || cache_line_size % otype_size != 0) {
constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs return;
active_sm_count = std::min(DIVUP(num_blocks * n_warps_per_tile, warps_per_sm), }
static_cast<size_t>(cuda::sm_count())); /* row_tile_elements */
elements_per_load = (std::min(cache_line_size, tile_size_x * itype_size) const size_t tile_size_x = (load_size * THREADS_PER_WARP) / itype_size;
/ itype_size); /* col_tile_elements */
elements_per_load_dact = (std::min(cache_line_size, tile_size_x * itype2_size) const size_t tile_size_y = (store_size * THREADS_PER_WARP) / otype_size;
/ itype2_size); const size_t num_tiles_x = row_length / tile_size_x;
elements_per_store_c = (std::min(cache_line_size, tile_size_x * otype_size) const size_t num_tiles_y = num_rows / tile_size_y;
/ otype_size);
elements_per_store_t = (std::min(cache_line_size, tile_size_y * otype_size) valid = (row_length % tile_size_x == 0 && num_rows % tile_size_y == 0);
/ otype_size); if (!valid) {
return;
} }
/* Compare by estimated cost */ // Number of CUDA blocks
bool operator<(const KernelConfig &other) const { num_blocks = num_tiles_x * num_tiles_y;
if (this->valid && other.valid) {
// cost ~ (1/elements_per_load // Parameters for performance model
// + 1/elements_per_load_dact constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs
// + 1/elements_per_store_c active_sm_count = std::min(DIVUP(num_blocks * n_warps_per_tile, warps_per_sm),
// + 1/elements_per_store_t) / active_sms static_cast<size_t>(cuda::sm_count()));
// Note: Integer arithmetic ensures stable ordering elements_per_load = (std::min(cache_line_size, tile_size_x * itype_size) / itype_size);
const auto &l1 = this->elements_per_load; elements_per_load_dact = (std::min(cache_line_size, tile_size_x * itype2_size) / itype2_size);
const auto &la1 = this->elements_per_load_dact; elements_per_store_c = (std::min(cache_line_size, tile_size_x * otype_size) / otype_size);
const auto &sc1 = this->elements_per_store_c; elements_per_store_t = (std::min(cache_line_size, tile_size_y * otype_size) / otype_size);
const auto &st1 = this->elements_per_store_t; }
const auto &p1 = this->active_sm_count;
const auto &l2 = other.elements_per_load; /* Compare by estimated cost */
const auto &la2 = other.elements_per_load_dact; bool operator<(const KernelConfig &other) const {
const auto &sc2 = other.elements_per_store_c; if (this->valid && other.valid) {
const auto &st2 = other.elements_per_store_t; // cost ~ (1/elements_per_load
const auto &p2 = other.active_sm_count; // + 1/elements_per_load_dact
const auto scale1 = l1 * sc1 * st1 * p1 * (is_dact ? la1 : 1); // + 1/elements_per_store_c
const auto scale2 = l2 * sc2 * st2 * p2 * (is_dact ? la2 : 1); // + 1/elements_per_store_t) / active_sms
const auto scale = scale1 * scale2; // Note: Integer arithmetic ensures stable ordering
const auto cost1 = (scale/l1 + scale/sc1 + scale/st1 + (is_dact ? (scale / la1) : 0)) const auto &l1 = this->elements_per_load;
/ p1; const auto &la1 = this->elements_per_load_dact;
const auto cost2 = (scale/l2 + scale/sc2 + scale/st2 + (is_dact ? (scale / la2) : 0)) const auto &sc1 = this->elements_per_store_c;
/ p2; const auto &st1 = this->elements_per_store_t;
const auto &p1 = this->active_sm_count;
return cost1 < cost2; const auto &l2 = other.elements_per_load;
} else { const auto &la2 = other.elements_per_load_dact;
return this->valid && !other.valid; const auto &sc2 = other.elements_per_store_c;
} const auto &st2 = other.elements_per_store_t;
const auto &p2 = other.active_sm_count;
const auto scale1 = l1 * sc1 * st1 * p1 * (is_dact ? la1 : 1);
const auto scale2 = l2 * sc2 * st2 * p2 * (is_dact ? la2 : 1);
const auto scale = scale1 * scale2;
const auto cost1 =
(scale / l1 + scale / sc1 + scale / st1 + (is_dact ? (scale / la1) : 0)) / p1;
const auto cost2 =
(scale / l2 + scale / sc2 + scale / st2 + (is_dact ? (scale / la2) : 0)) / p2;
return cost1 < cost2;
} else {
return this->valid && !other.valid;
} }
}
}; };
template <bool IS_DBIAS, bool IS_FULL_TILE, int nvec_in, int nvec_out, typename OVec, typename CVec,
template <bool IS_DBIAS, bool IS_FULL_TILE, int nvec_in, int nvec_out, typename CType>
typename OVec, typename CVec, typename CType>
inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out], inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out],
OVec (&out_trans)[nvec_in], OVec (&out_trans)[nvec_in],
CVec &out_dbias, // NOLINT(*) CVec &out_dbias, // NOLINT(*)
typename OVec::type *output_cast_tile, typename OVec::type *output_cast_tile,
const size_t current_place, const size_t current_place, const size_t stride,
const size_t stride,
const CType scale, const CType scale,
CType &amax, // NOLINT(*) CType &amax, // NOLINT(*)
const int dbias_shfl_src_lane, const int dbias_shfl_src_lane,
const bool valid_store) { const bool valid_store) {
using OType = typename OVec::type; using OType = typename OVec::type;
using OVecC = Vec<OType, nvec_in>; using OVecC = Vec<OType, nvec_in>;
CVec step_dbias; CVec step_dbias;
if constexpr (IS_DBIAS) { if constexpr (IS_DBIAS) {
step_dbias.clear(); step_dbias.clear();
}
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
OVecC out_cast;
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
const CType tmp = in[i].data.elt[j];
if constexpr (IS_DBIAS) {
step_dbias.data.elt[j] += tmp; // dbias: thread tile local accumulation
}
out_cast.data.elt[j] = static_cast<OType>(tmp * scale);
out_trans[j].data.elt[i] = static_cast<OType>(tmp * scale); // thread tile transpose
__builtin_assume(amax >= 0);
amax = fmaxf(fabsf(tmp), amax);
} }
if (IS_FULL_TILE || valid_store) {
#pragma unroll out_cast.store_to(output_cast_tile, current_place + stride * i);
for (unsigned int i = 0; i < nvec_out; ++i) {
OVecC out_cast;
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
const CType tmp = in[i].data.elt[j];
if constexpr (IS_DBIAS) {
step_dbias.data.elt[j] += tmp; // dbias: thread tile local accumulation
}
out_cast.data.elt[j] = static_cast<OType>(tmp * scale);
out_trans[j].data.elt[i] = static_cast<OType>(tmp * scale); // thread tile transpose
__builtin_assume(amax >= 0);
amax = fmaxf(fabsf(tmp), amax);
}
if (IS_FULL_TILE || valid_store) {
out_cast.store_to(output_cast_tile, current_place + stride * i);
}
} }
}
if constexpr (IS_DBIAS) {
#pragma unroll if constexpr (IS_DBIAS) {
for (unsigned int j = 0; j < nvec_in; ++j) { #pragma unroll
CType elt = step_dbias.data.elt[j]; for (unsigned int j = 0; j < nvec_in; ++j) {
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp CType elt = step_dbias.data.elt[j];
out_dbias.data.elt[j] += elt; elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp
} out_dbias.data.elt[j] += elt;
} }
}
} }
void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /*cast*/ void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /*cast*/
Tensor* workspace, Tensor *workspace, const int nvec_out) {
const int nvec_out) { const size_t row_length = cast_output.data.shape[1];
const size_t row_length = cast_output.data.shape[1]; const size_t num_rows = cast_output.data.shape[0];
const size_t num_rows = cast_output.data.shape[0];
const size_t tile_size_y = (nvec_out * THREADS_PER_WARP); const size_t tile_size_y = (nvec_out * THREADS_PER_WARP);
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y); const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y);
workspace->data.shape = {num_rows_partial_dbias, row_length}; workspace->data.shape = {num_rows_partial_dbias, row_length};
workspace->data.dtype = DType::kFloat32; workspace->data.dtype = DType::kFloat32;
} }
template<int nvec, typename ComputeType, typename OutputType> template <int nvec, typename ComputeType, typename OutputType>
__global__ void __global__ void __launch_bounds__(reduce_dbias_num_threads)
__launch_bounds__(reduce_dbias_num_threads) reduce_dbias_kernel(OutputType *const dbias_output, const ComputeType *const dbias_partial,
reduce_dbias_kernel(OutputType* const dbias_output, const int row_length, const int num_rows) {
const ComputeType* const dbias_partial, using ComputeVec = Vec<ComputeType, nvec>;
const int row_length, using OutputVec = Vec<OutputType, nvec>;
const int num_rows) {
using ComputeVec = Vec<ComputeType, nvec>;
using OutputVec = Vec<OutputType, nvec>;
const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_id * nvec >= row_length) { if (thread_id * nvec >= row_length) {
return; return;
} }
const ComputeType* const thread_in_base = dbias_partial + thread_id * nvec;
OutputType* const thread_out_base = dbias_output + thread_id * nvec;
const int stride_in_vec = row_length / nvec; const ComputeType *const thread_in_base = dbias_partial + thread_id * nvec;
OutputType *const thread_out_base = dbias_output + thread_id * nvec;
ComputeVec ldg_vec; const int stride_in_vec = row_length / nvec;
ComputeVec acc_vec; acc_vec.clear();
for (int i = 0; i < num_rows; ++i) {
ldg_vec.load_from(thread_in_base, i * stride_in_vec);
#pragma unroll
for (int e = 0; e < nvec; ++e) {
acc_vec.data.elt[e] += ldg_vec.data.elt[e];
}
}
OutputVec stg_vec; ComputeVec ldg_vec;
#pragma unroll ComputeVec acc_vec;
acc_vec.clear();
for (int i = 0; i < num_rows; ++i) {
ldg_vec.load_from(thread_in_base, i * stride_in_vec);
#pragma unroll
for (int e = 0; e < nvec; ++e) { for (int e = 0; e < nvec; ++e) {
stg_vec.data.elt[e] = OutputType(acc_vec.data.elt[e]); acc_vec.data.elt[e] += ldg_vec.data.elt[e];
} }
stg_vec.store_to(thread_out_base, 0); }
OutputVec stg_vec;
#pragma unroll
for (int e = 0; e < nvec; ++e) {
stg_vec.data.elt[e] = OutputType(acc_vec.data.elt[e]);
}
stg_vec.store_to(thread_out_base, 0);
} }
template <typename InputType> template <typename InputType>
void reduce_dbias(const Tensor &workspace, void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_length,
Tensor *dbias, const size_t num_rows, const int nvec_out, cudaStream_t stream) {
const size_t row_length, constexpr int reduce_dbias_store_bytes = 8; // stg.64
const size_t num_rows, constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(InputType);
const int nvec_out,
cudaStream_t stream) { NVTE_CHECK(row_length % reduce_dbias_nvec == 0, "Unsupported shape.");
constexpr int reduce_dbias_store_bytes = 8; // stg.64
constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(InputType); const size_t reduce_dbias_row_length = row_length;
const size_t reduce_dbias_num_rows =
NVTE_CHECK(row_length % reduce_dbias_nvec == 0, "Unsupported shape."); DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
const size_t reduce_dbias_num_blocks =
const size_t reduce_dbias_row_length = row_length; DIVUP(row_length, reduce_dbias_num_threads * reduce_dbias_nvec);
const size_t reduce_dbias_num_rows = DIVUP(num_rows,
static_cast<size_t>(nvec_out * THREADS_PER_WARP)); using DbiasOutputType = fp32;
const size_t reduce_dbias_num_blocks = DIVUP(row_length, reduce_dbias_kernel<reduce_dbias_nvec, DbiasOutputType, InputType>
reduce_dbias_num_threads * reduce_dbias_nvec); <<<reduce_dbias_num_blocks, reduce_dbias_num_threads, 0, stream>>>(
reinterpret_cast<InputType *>(dbias->data.dptr),
using DbiasOutputType = fp32; reinterpret_cast<const fp32 *>(workspace.data.dptr), reduce_dbias_row_length,
reduce_dbias_kernel<reduce_dbias_nvec, DbiasOutputType, InputType> reduce_dbias_num_rows);
<<<reduce_dbias_num_blocks, reduce_dbias_num_threads, 0, stream>>>
(reinterpret_cast<InputType *>(dbias->data.dptr),
reinterpret_cast<const fp32 *>(workspace.data.dptr),
reduce_dbias_row_length,
reduce_dbias_num_rows);
} }
template <bool IS_DBIAS, bool IS_DACT, typename ComputeType, typename Param, int nvec_in,
template <bool IS_DBIAS, bool IS_DACT, typename ComputeType, typename Param, int nvec_out, typename ParamOP, ComputeType (*OP)(ComputeType, const ParamOP &)>
int nvec_in, int nvec_out, typename ParamOP, __global__ void __launch_bounds__(cast_transpose_num_threads)
ComputeType (*OP)(ComputeType, const ParamOP&)> cast_transpose_fused_kernel_notaligned(const Param param, const size_t row_length,
__global__ void const size_t num_rows, const size_t num_tiles) {
__launch_bounds__(cast_transpose_num_threads) using IType = typename Param::InputType;
cast_transpose_fused_kernel_notaligned(const Param param, using IType2 = typename Param::InputType2;
const size_t row_length, using OType = typename Param::OutputType;
const size_t num_rows, using CType = typename Param::ComputeType;
const size_t num_tiles) { using IVec = Vec<IType, nvec_in>;
using IType = typename Param::InputType; using IVec2 = Vec<IType2, nvec_in>;
using IType2 = typename Param::InputType2; using OVec = Vec<OType, nvec_out>;
using OType = typename Param::OutputType; using CVec = Vec<CType, nvec_in>;
using CType = typename Param::ComputeType;
using IVec = Vec<IType, nvec_in>; extern __shared__ char scratch[];
using IVec2 = Vec<IType2, nvec_in>;
using OVec = Vec<OType, nvec_out>; const int warp_id = threadIdx.x / THREADS_PER_WARP;
using CVec = Vec<CType, nvec_in>; const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x =
extern __shared__ char scratch[]; (row_length + nvec_in * THREADS_PER_WARP - 1) / (nvec_in * THREADS_PER_WARP);
const size_t tile_id =
const int warp_id = threadIdx.x / THREADS_PER_WARP; blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile;
const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; if (tile_id >= num_tiles) {
const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) return;
/ (nvec_in * THREADS_PER_WARP); }
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile)
+ warp_id / n_warps_per_tile; const size_t tile_id_x = tile_id % num_tiles_x;
if (tile_id >= num_tiles) { const size_t tile_id_y = tile_id / num_tiles_x;
return;
} const size_t tile_offset =
(tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP;
const size_t tile_id_x = tile_id % num_tiles_x; const size_t tile_offset_transp =
const size_t tile_id_y = tile_id / num_tiles_x; (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP;
const size_t tile_offset = (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) const IType *const my_input_tile = param.input + tile_offset;
* THREADS_PER_WARP; const IType2 *const my_act_input_tile = param.act_input + tile_offset;
const size_t tile_offset_transp = (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) OType *const my_output_c_tile = param.output_c + tile_offset;
* THREADS_PER_WARP; OType *const my_output_t_tile = param.output_t + tile_offset_transp;
CType *const my_partial_dbias_tile =
const IType * const my_input_tile = param.input + tile_offset; param.workspace + (tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length);
const IType2 * const my_act_input_tile = param.act_input + tile_offset;
OType * const my_output_c_tile = param.output_c + tile_offset; const size_t stride = row_length / nvec_in;
OType * const my_output_t_tile = param.output_t + tile_offset_transp; const size_t output_stride = num_rows / nvec_out;
CType * const my_partial_dbias_tile = param.workspace const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP;
+ (tile_id_x * (nvec_in * THREADS_PER_WARP) const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP;
+ tile_id_y * row_length); const unsigned int tile_length =
row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_length_rest;
const size_t stride = row_length / nvec_in; const unsigned int tile_height =
const size_t output_stride = num_rows / nvec_out; row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_height_rest;
const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP;
const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP; OVec *const my_scratch =
const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP reinterpret_cast<OVec *>(scratch) +
: row_length_rest; (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);
const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP
: row_height_rest; CVec *const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
OVec * const my_scratch = reinterpret_cast<OVec *>(scratch) IVec in[2][nvec_out];
+ (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) IVec2 act_in[2][nvec_out];
* (THREADS_PER_WARP + 1); const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch); OVec out_space[n_iterations][nvec_in];
IVec in[2][nvec_out]; size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
IVec2 act_in[2][nvec_out]; size_t current_row = (tile_id_y * THREADS_PER_WARP + warp_id_in_tile * n_iterations) * nvec_out;
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; unsigned int my_place =
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
OVec out_space[n_iterations][nvec_in]; CType amax = 0;
const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
size_t current_row = (tile_id_y * THREADS_PER_WARP + warp_id_in_tile * n_iterations) * nvec_out; CVec partial_dbias;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) if constexpr (IS_DBIAS) {
% THREADS_PER_WARP; partial_dbias.clear();
CType amax = 0; }
const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
{
CVec partial_dbias; const bool valid_load = my_place < tile_length && warp_id_in_tile * n_iterations < tile_height;
if constexpr (IS_DBIAS) { #pragma unroll
partial_dbias.clear(); for (unsigned int i = 0; i < nvec_out; ++i) {
} if (valid_load) {
const size_t ld_offset = current_stride + my_place + stride * i;
{ in[0][i].load_from(my_input_tile, ld_offset);
const bool valid_load = my_place < tile_length && if constexpr (IS_DACT) {
warp_id_in_tile * n_iterations < tile_height; act_in[0][i].load_from(my_act_input_tile, ld_offset);
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
if (valid_load) {
const size_t ld_offset = current_stride + my_place + stride * i;
in[0][i].load_from(my_input_tile, ld_offset);
if constexpr (IS_DACT) {
act_in[0][i].load_from(my_act_input_tile, ld_offset);
}
} else {
in[0][i].clear();
if constexpr (IS_DACT) {
act_in[0][i].clear();
}
}
}
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
const bool valid_load = my_place_in < tile_length &&
warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
if (valid_load) {
const size_t ld_offset = current_stride + my_place_in + stride * (nvec_out + j);
in[current_in][j].load_from(my_input_tile, ld_offset);
if constexpr (IS_DACT) {
act_in[current_in][j].load_from(my_act_input_tile, ld_offset);
}
} else {
in[current_in][j].clear();
if constexpr (IS_DACT) {
act_in[current_in][j].clear();
}
}
}
} }
CVec after_dact[nvec_out]; // NOLINT(*) } else {
#pragma unroll in[0][i].clear();
for (unsigned int j = 0; j < nvec_out; ++j) { if constexpr (IS_DACT) {
#pragma unroll act_in[0][i].clear();
for (unsigned int k = 0; k < nvec_in; ++k) {
if constexpr (IS_DACT) {
after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k])
* OP(act_in[current_in ^ 1][j].data.elt[k], {});
} else {
after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]);
}
}
} }
const int dbias_shfl_src_lane = (my_id_in_warp + i + warp_id_in_tile * n_iterations) }
% THREADS_PER_WARP;
constexpr bool IS_FULL_TILE = false;
const bool valid_store = (my_place < tile_length)
&& (warp_id_in_tile * n_iterations + i < tile_height);
cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>
(after_dact, out_space[i], partial_dbias, my_output_c_tile, current_place,
stride, scale, amax, dbias_shfl_src_lane, valid_store);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += nvec_out * stride;
current_row += nvec_out;
} }
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll #pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) { for (unsigned int i = 0; i < n_iterations; ++i) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) const size_t current_place = current_stride + my_place;
% THREADS_PER_WARP] = out_space[j][i]; const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
const bool valid_load =
my_place_in < tile_length && warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
if (valid_load) {
const size_t ld_offset = current_stride + my_place_in + stride * (nvec_out + j);
in[current_in][j].load_from(my_input_tile, ld_offset);
if constexpr (IS_DACT) {
act_in[current_in][j].load_from(my_act_input_tile, ld_offset);
}
} else {
in[current_in][j].clear();
if constexpr (IS_DACT) {
act_in[current_in][j].clear();
}
} }
__syncthreads(); }
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) }
% THREADS_PER_WARP; CVec after_dact[nvec_out]; // NOLINT(*)
current_stride = i * output_stride #pragma unroll
+ warp_id_in_tile * n_iterations * output_stride * nvec_in; for (unsigned int j = 0; j < nvec_out; ++j) {
for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { #pragma unroll
const bool valid_store = my_place < tile_height; for (unsigned int k = 0; k < nvec_in; ++k) {
if (valid_store) { if constexpr (IS_DACT) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
current_stride + my_place); OP(act_in[current_in ^ 1][j].data.elt[k], {});
} } else {
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]);
current_stride += output_stride * nvec_in;
} }
__syncthreads(); }
} }
const int dbias_shfl_src_lane =
if constexpr (IS_DBIAS) { (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
my_dbias_scratch[threadIdx.x] = partial_dbias; constexpr bool IS_FULL_TILE = false;
__syncthreads(); const bool valid_store =
if (warp_id_in_tile == 0) { (my_place < tile_length) && (warp_id_in_tile * n_iterations + i < tile_height);
#pragma unroll
for (unsigned int i = 1; i < n_warps_per_tile; ++i) { cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>(after_dact, out_space[i], partial_dbias,
CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP]; my_output_c_tile, current_place, stride, scale,
#pragma unroll amax, dbias_shfl_src_lane, valid_store);
for (unsigned int j = 0; j < nvec_in; ++j) {
partial_dbias.data.elt[j] += tmp.data.elt[j]; my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
} current_stride += nvec_out * stride;
} current_row += nvec_out;
if (my_id_in_warp < tile_length) { }
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
} for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place =
(my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
const bool valid_store = my_place < tile_height;
if (valid_store) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
if constexpr (IS_DBIAS) {
my_dbias_scratch[threadIdx.x] = partial_dbias;
__syncthreads();
if (warp_id_in_tile == 0) {
#pragma unroll
for (unsigned int i = 1; i < n_warps_per_tile; ++i) {
CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
partial_dbias.data.elt[j] += tmp.data.elt[j];
} }
}
if (my_id_in_warp < tile_length) {
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
}
} }
}
/* warp tile amax reduce*/ /* warp tile amax reduce*/
amax = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(amax, warp_id); amax = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(amax, warp_id);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value); static_assert(std::is_same<CType, float>::value);
if (param.amax != nullptr) { if (param.amax != nullptr) {
atomicMaxFloat(param.amax, amax); atomicMaxFloat(param.amax, amax);
}
} }
}
} }
// Names of the supported activation-backward functions, indexed by the code
// returned by get_dactivation_type(); used to build the NVRTC kernel label.
// NOTE(review): entries 3-5 were hidden in the collapsed diff region and are
// inferred from the get_dactivation_type() mapping — confirm against the repo.
static const char *ActTypeToString[] = {
    "NoAct",    // 0
    "Sigmoid",  // 1
    "GeLU",     // 2
    "QGeLU",    // 3
    "SiLU",     // 4
    "ReLU",     // 5
    "SReLU"     // 6
};
// Maps the compile-time activation-backward function pointer OP to the integer
// code understood by the runtime-compiled kernel (an index into
// ActTypeToString).  Returns 0 ("NoAct") for any unrecognized function.
template <typename ComputeType, typename ParamOP, ComputeType (*OP)(ComputeType, const ParamOP &)>
int get_dactivation_type() {
  if (OP == &sigmoid<ComputeType, ComputeType>) {
    return 1;
  } else if (OP == &dgelu<ComputeType, ComputeType>) {
    return 2;
  } else if (OP == &dqgelu<ComputeType, ComputeType>) {
    return 3;
  } else if (OP == &dsilu<ComputeType, ComputeType>) {
    return 4;
  } else if (OP == &drelu<ComputeType, ComputeType>) {
    return 5;
  } else if (OP == &dsrelu<ComputeType, ComputeType>) {
    return 6;
  } else {
    return 0;
  }
}
// Host-side launcher for the fused cast + transpose kernel.
//
// Validates shapes/dtypes of the input, cast (C) output and transposed (T)
// output, then launches either a runtime-compiled (NVRTC) kernel for
// warp-aligned shapes or the statically-compiled general kernel.
//  - IS_DACT: the kernel multiplies the input elementwise by OP(act_input)
//    before casting (activation-backward fusion).
//  - IS_DBIAS: per-tile partial column sums are written into `workspace` and
//    reduced into `dbias` afterwards.  When `workspace` is empty, only its
//    required shape/dtype is populated and the function returns early.
// The C and T outputs must share amax/scale tensors (checked below).
template <bool IS_DBIAS, bool IS_DACT, typename ComputeType, typename ParamOP,
          ComputeType (*OP)(ComputeType, const ParamOP &)>
void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor *cast_output,
                          Tensor *transposed_output, Tensor *dbias, Tensor *workspace,
                          cudaStream_t stream) {
  CheckInputTensor(input, "cast_transpose_fused_input");
  CheckOutputTensor(*cast_output, "cast_output");
  CheckOutputTensor(*transposed_output, "transposed_output");

  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
  NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
  NVTE_CHECK(input.data.shape == cast_output->data.shape,
             "Input and C output must have the same shape.");
  const size_t row_length = input.data.shape[1];
  const size_t num_rows = input.data.shape[0];

  NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output.");
  NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");

  NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
             "C and T outputs need to have the same type.");
  NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
             "C and T outputs need to share amax tensor.");
  NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
             "C and T outputs need to share scale tensor.");

  if constexpr (IS_DBIAS) {
    CheckOutputTensor(*dbias, "dbias");
    NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
    NVTE_CHECK(dbias->data.shape == std::vector<size_t>{row_length}, "Wrong shape of DBias.");
  }
  if constexpr (IS_DACT) {
    CheckInputTensor(act_input, "act_input");
    NVTE_CHECK(input.data.dtype == act_input.data.dtype, "Types of both inputs must match.");
    NVTE_CHECK(input.data.shape == act_input.data.shape, "Shapes of both inputs must match.");
  }

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          cast_output->data.dtype, OutputType, using InputType2 = InputType;
          using Param = CTDBiasDActParam<InputType, InputType2, OutputType, ComputeType>;

          constexpr int itype_size = sizeof(InputType);
          constexpr int itype2_size = sizeof(InputType2);
          constexpr int otype_size = sizeof(OutputType);

          // The tuned NVRTC kernel requires both dimensions to be multiples
          // of the warp size.
          const bool aligned =
              (row_length % THREADS_PER_WARP == 0) && (num_rows % THREADS_PER_WARP == 0);
          const bool jit_compiled = aligned && rtc::is_enabled();

          size_t load_size = (IS_DACT ? desired_load_size_dact : desired_load_size);
          size_t store_size = (IS_DACT ? desired_store_size_dact : desired_store_size);
          size_t num_blocks;

          if (jit_compiled) {
            // Enumerate candidate load/store vector widths (powers of two)
            // and pick the cheapest valid configuration.
            std::vector<KernelConfig> kernel_configs;
            kernel_configs.reserve(16);
            auto add_config = [&](size_t load_size_config, size_t store_size_config) {
              kernel_configs.emplace_back(row_length, num_rows, itype_size, itype2_size,
                                          otype_size, load_size_config, store_size_config,
                                          IS_DACT);
            };
            add_config(8, 8);
            add_config(4, 8);
            add_config(8, 4);
            add_config(4, 4);
            add_config(2, 8);
            add_config(8, 2);
            add_config(2, 4);
            add_config(4, 2);
            add_config(2, 2);
            add_config(1, 8);
            add_config(8, 1);
            add_config(1, 4);
            add_config(4, 1);
            add_config(1, 2);
            add_config(2, 1);
            add_config(1, 1);

            // Select the kernel configuration with the lowest cost
            const auto &kernel_config =
                *std::min_element(kernel_configs.begin(), kernel_configs.end());
            NVTE_CHECK(kernel_config.valid, "invalid kernel config");
            load_size = kernel_config.load_size;
            store_size = kernel_config.store_size;
            num_blocks = kernel_config.num_blocks;
          }

          const size_t nvec_in = load_size / itype_size;
          const size_t nvec_out = store_size / otype_size;
          const size_t tile_size_x = nvec_in * threads_per_warp;
          const size_t tile_size_y = nvec_out * threads_per_warp;
          const size_t num_tiles_x = DIVUP(row_length, tile_size_x);
          const size_t num_tiles_y = DIVUP(num_rows, tile_size_y);
          const size_t num_tiles = num_tiles_x * num_tiles_y;
          NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
          NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
          if (!jit_compiled) {
            num_blocks = DIVUP(num_tiles * n_warps_per_tile, n_warps_per_block);
          }
          // Workspace query mode: only report the required workspace config.
          if constexpr (IS_DBIAS) {
            if (workspace->data.dptr == nullptr) {
              populate_cast_transpose_dbias_workspace_config(*cast_output, workspace, nvec_out);
              return;
            }
          }

          // nvec_out is a power of two in {1, 2, 4, 8} (load/store sizes are
          // powers of two); compute the vector size for shared-memory sizing.
          size_t VecOutputTypeSize;
          switch (nvec_out) {
            case 1:
              VecOutputTypeSize = sizeof(Vec<OutputType, 1>);
              break;
            case 2:
              VecOutputTypeSize = sizeof(Vec<OutputType, 2>);
              break;
            case 4:
              VecOutputTypeSize = sizeof(Vec<OutputType, 4>);
              break;
            case 8:
              VecOutputTypeSize = sizeof(Vec<OutputType, 8>);
              break;
          }
          size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile *
                                         (threads_per_warp + 1) * VecOutputTypeSize;

          if constexpr (IS_DBIAS) {
            size_t VecComputeTypeSize;
            switch (nvec_in) {
              case 1:
                VecComputeTypeSize = sizeof(Vec<ComputeType, 1>);
                break;
              case 2:
                VecComputeTypeSize = sizeof(Vec<ComputeType, 2>);
                break;
              case 4:
                VecComputeTypeSize = sizeof(Vec<ComputeType, 4>);
                break;
              case 8:
                VecComputeTypeSize = sizeof(Vec<ComputeType, 8>);
                break;
            }
            // The dbias reduction reuses the same dynamic shared buffer; take
            // the larger of the two requirements.
            const size_t shared_size_dbias = cast_transpose_num_threads * VecComputeTypeSize;
            if (shared_size_transpose < shared_size_dbias) {
              shared_size_transpose = shared_size_dbias;
            }
          }

          Param param;
          param.input = reinterpret_cast<const InputType *>(input.data.dptr);
          param.output_c = reinterpret_cast<OutputType *>(cast_output->data.dptr);
          param.output_t = reinterpret_cast<OutputType *>(transposed_output->data.dptr);
          param.scale_ptr = reinterpret_cast<const ComputeType *>(transposed_output->scale.dptr);
          param.amax = reinterpret_cast<ComputeType *>(transposed_output->amax.dptr);
          param.scale_inv = reinterpret_cast<ComputeType *>(cast_output->scale_inv.dptr);
          if constexpr (IS_DBIAS) {
            param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
          }
          if constexpr (IS_DACT) {
            param.act_input = reinterpret_cast<const InputType2 *>(act_input.data.dptr);
          }

          // Runtime-compiled tuned kernel
          if (jit_compiled) {
            constexpr const char *itype_name = TypeInfo<InputType>::name;
            constexpr const char *itype2_name = TypeInfo<InputType2>::name;
            constexpr const char *otype_name = TypeInfo<OutputType>::name;

            int dActType = 0;
            if constexpr (IS_DACT) {
              dActType = get_dactivation_type<ComputeType, ParamOP, OP>();
            }

            // Compile NVRTC kernel if needed and launch
            auto &rtc_manager = rtc::KernelManager::instance();
            const std::string kernel_label = concat_strings(
                "cast_transpose_fusion"
                ",itype=",
                itype_name, ",itype2=", itype2_name, ",otype=", otype_name,
                ",load_size=", load_size, ",store_size=", store_size, ",IS_DBIAS=", IS_DBIAS,
                ",IS_DACT=", IS_DACT, ",dactivationType=", ActTypeToString[dActType]);

            if (!rtc_manager.is_compiled(kernel_label)) {
              std::string code = string_code_transpose_rtc_cast_transpose_fusion_cu;
              code = regex_replace(code, "__ITYPE__", itype_name);
              code = regex_replace(code, "__ITYPE2__", itype2_name);
              code = regex_replace(code, "__OTYPE__", otype_name);
              code = regex_replace(code, "__LOAD_SIZE__", load_size);
              code = regex_replace(code, "__STORE_SIZE__", store_size);
              code = regex_replace(code, "__WARPS_PER_TILE__", n_warps_per_tile);
              code = regex_replace(code, "__BLOCK_SIZE__", cast_transpose_num_threads);
              code = regex_replace(code, "__IS_DBIAS__", IS_DBIAS);
              code = regex_replace(code, "__IS_DACT__", IS_DACT);
              code = regex_replace(code, "__DACTIVATION_TYPE__", dActType);

              rtc_manager.compile(
                  kernel_label, "cast_transpose_fusion_kernel_optimized", code,
                  "transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu");
            }
            rtc_manager.set_cache_config(kernel_label, CU_FUNC_CACHE_PREFER_SHARED);

            rtc_manager.launch(kernel_label, num_blocks, cast_transpose_num_threads,
                               shared_size_transpose, stream, param, row_length, num_rows,
                               num_tiles);
          } else {  // Statically-compiled general kernel
            constexpr size_t load_size = IS_DACT ? desired_load_size_dact : desired_load_size;
            constexpr size_t store_size = IS_DACT ? desired_store_size_dact : desired_store_size;
            constexpr size_t nvec_in = load_size / itype_size;
            constexpr size_t nvec_out = store_size / otype_size;
            NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
            NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
            cudaFuncSetAttribute(
                cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, ComputeType, Param,
                                                       nvec_in, nvec_out, Empty, OP>,
                cudaFuncAttributePreferredSharedMemoryCarveout, 100);
            cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, ComputeType, Param, nvec_in,
                                                   nvec_out, Empty, OP>
                <<<num_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
                    param, row_length, num_rows, num_tiles);
          }
          // Reduce the per-tile partial sums written by the kernel into dbias.
          if constexpr (IS_DBIAS) {
            reduce_dbias<InputType>(*workspace, dbias, row_length, num_rows, nvec_out, stream);
          });  // NOLINT(*)
  );           // NOLINT(*)
}
// Fused backward kernel for gated activations: for each element it computes
//   dact  = OP1(act) * grad * gate   (gradient w.r.t. the activation input)
//   dgate = grad * OP2(act)          (gradient w.r.t. the gate input)
// then casts both to OType, writing them row-interleaved into output_c
// (width 2*row_length) and transposed into output_t (two stacked planes),
// while accumulating the absolute-max into *amax.
//
// Aligned variant: assumes row_length and num_rows are exact multiples of the
// tile dimensions (no bounds checks; see the *_notaligned variant otherwise).
// Requires dynamic shared memory sized for the per-warp transpose scratch.
template <int nvec_in, int nvec_out, typename CType, typename IType, typename OType,
          typename ParamOP, CType (*OP1)(CType, const ParamOP &),
          CType (*OP2)(CType, const ParamOP &)>
__global__ void __launch_bounds__(cast_transpose_num_threads)
    dgated_act_cast_transpose_kernel(const IType *const input, const IType *const act_input,
                                     OType *const output_c, OType *const output_t,
                                     const CType *const scale_ptr, CType *const amax,
                                     CType *const scale_inv, const size_t row_length,
                                     const size_t num_rows, const size_t num_tiles) {
  using IVec = Vec<IType, nvec_in>;
  using OVec = Vec<OType, nvec_out>;
  using CVec = Vec<CType, nvec_in>;

  extern __shared__ char scratch[];

  const int warp_id = threadIdx.x / THREADS_PER_WARP;
  const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
  const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
  const size_t tile_id =
      blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile;
  // Guard: the grid may contain more warps than tiles.
  if (tile_id >= num_tiles) {
    return;
  }

  const size_t tile_id_x = tile_id % num_tiles_x;
  const size_t tile_id_y = tile_id / num_tiles_x;

  // act_input holds [act | gate] side by side, hence the 2 * row_length rows.
  const IType *const my_input_tile =
      input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP;
  const IType *const my_act_input_tile =
      act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP;
  const IType *const my_gate_input_tile =
      act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP +
      row_length;
  OType *const my_output_c_tile_0 =
      output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP;
  OType *const my_output_c_tile_1 =
      output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP +
      row_length;
  OType *const my_output_t_tile_0 =
      output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP;
  OType *const my_output_t_tile_1 =
      output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP +
      row_length * num_rows;
  // Per-warp shared-memory transpose scratch; +1 padding avoids bank conflicts.
  OVec *const my_scratch =
      reinterpret_cast<OVec *>(scratch) +
      (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);

  // Double-buffered loads: [2] alternates current/next iteration.
  IVec in[2][nvec_out];
  IVec act_in[2][nvec_out];
  IVec gate_in[2][nvec_out];
  const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
  constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
  OVec out_space_0[n_iterations][nvec_in];
  OVec out_space_1[n_iterations][nvec_in];

  const size_t stride = row_length / nvec_in;
  const size_t output_stride = num_rows / nvec_out;
  size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
  unsigned int my_place =
      (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
  const size_t stride2 = 2 * row_length / nvec_in;
  size_t current_stride2 = warp_id_in_tile * n_iterations * nvec_out * stride2;
  CType max = 0;
  const CType scale = scale_ptr != nullptr ? *scale_ptr : 1;
  // Unused here (IS_DBIAS is false below); only a placeholder argument.
  CVec partial_dbias;
#pragma unroll
  for (unsigned int i = 0; i < nvec_out; ++i) {
    in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
    act_in[0][i].load_from(my_act_input_tile, current_stride2 + my_place + stride2 * i);
    gate_in[0][i].load_from(my_gate_input_tile, current_stride2 + my_place + stride2 * i);
  }
#pragma unroll
  for (unsigned int i = 0; i < n_iterations; ++i) {
    const size_t current_place = current_stride2 + my_place;
    const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
    const unsigned int current_in = (i + 1) % 2;
    // Prefetch the next iteration's vectors into the other buffer half.
    if (i < n_iterations - 1) {
#pragma unroll
      for (unsigned int j = 0; j < nvec_out; ++j) {
        in[current_in][j].load_from(my_input_tile,
                                    current_stride + my_place_in + stride * (nvec_out + j));
        act_in[current_in][j].load_from(my_act_input_tile,
                                        current_stride2 + my_place_in + stride2 * (nvec_out + j));
        gate_in[current_in][j].load_from(my_gate_input_tile,
                                         current_stride2 + my_place_in + stride2 * (nvec_out + j));
      }
    }
    CVec after_dact[nvec_out];   // NOLINT(*)
    CVec after_dgate[nvec_out];  // NOLINT(*)
#pragma unroll
    for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
      for (unsigned int k = 0; k < nvec_in; ++k) {
        after_dact[j].data.elt[k] = OP1(act_in[current_in ^ 1][j].data.elt[k], {}) *
                                    CType(in[current_in ^ 1][j].data.elt[k]) *
                                    CType(gate_in[current_in ^ 1][j].data.elt[k]);
        after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
                                     OP2(act_in[current_in ^ 1][j].data.elt[k], {});
      }
    }
    OVec out_trans_0[nvec_in];  // NOLINT(*)
    OVec out_trans_1[nvec_in];  // NOLINT(*)
    constexpr bool IS_DBIAS = false;
    constexpr bool IS_FULL_TILE = true;
    constexpr bool valid_store = true;
    constexpr int dbias_shfl_src_lane = 0;
    cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>(after_dact, out_trans_0, partial_dbias,
                                                    my_output_c_tile_0, current_place, stride2,
                                                    scale, max, dbias_shfl_src_lane, valid_store);
    cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>(after_dgate, out_trans_1, partial_dbias,
                                                    my_output_c_tile_1, current_place, stride2,
                                                    scale, max, dbias_shfl_src_lane, valid_store);
#pragma unroll
    for (unsigned int j = 0; j < nvec_in; ++j) {
      out_space_0[i][j].data.vec = out_trans_0[j].data.vec;
      out_space_1[i][j].data.vec = out_trans_1[j].data.vec;
    }
    my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
    current_stride += nvec_out * stride;
    current_stride2 += nvec_out * stride2;
  }

  // Stage the transposed vectors through shared memory, one plane at a time.
  for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
    for (unsigned int j = 0; j < n_iterations; ++j) {
      my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
                 THREADS_PER_WARP] = out_space_0[j][i];
    }
    __syncthreads();
    my_place =
        (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
    current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
    for (unsigned int j = 0; j < n_iterations; ++j) {
      my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_0,
                                                              current_stride + my_place);
      my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
      current_stride += output_stride * nvec_in;
    }
    __syncthreads();
#pragma unroll
    for (unsigned int j = 0; j < n_iterations; ++j) {
      my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
                 THREADS_PER_WARP] = out_space_1[j][i];
    }
    __syncthreads();
    my_place =
        (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
    current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
    for (unsigned int j = 0; j < n_iterations; ++j) {
      my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_1,
                                                              current_stride + my_place);
      my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
      current_stride += output_stride * nvec_in;
    }
    __syncthreads();
  }

  /* warp tile amax reduce*/
  max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);

  if (threadIdx.x == 0) {
    static_assert(std::is_same<CType, float>::value);
    if (amax != nullptr) {
      atomicMaxFloat(amax, max);
    }
    if (scale_inv != nullptr) {
      reciprocal<float>(scale_inv, scale);
    }
  }
}
template <int nvec_in, int nvec_out, typename CType, typename IType, typename OType,
          typename ParamOP, CType (*OP1)(CType, const ParamOP &),
          CType (*OP2)(CType, const ParamOP &)>
__global__ void __launch_bounds__(cast_transpose_num_threads)
    dgated_act_cast_transpose_kernel_notaligned(const IType *const input,
                                                const IType *const act_input, OType *const output_c,
                                                OType *const output_t, const CType *const scale_ptr,
                                                CType *const amax, CType *const scale_inv,
                                                const size_t row_length, const size_t num_rows,
                                                const size_t num_tiles) {
  // Partial-tile ("not aligned") variant of the fused gated-activation backward
  // cast+transpose kernel.  For each element it computes
  //   dact  = OP1(act) * grad * gate
  //   dgate = grad * OP2(act)
  // casts both results to OType into output_c (act half and gate half of each row),
  // and also writes the transposed casts into output_t, while accumulating the
  // absolute maximum into *amax.  Unlike the aligned variant, every load and store
  // is guarded by tile_length/tile_height so edge tiles (where row_length or
  // num_rows is not a multiple of nvec * THREADS_PER_WARP) stay in bounds.
  // Requires dynamic shared memory for the per-warp transpose scratch; the host
  // launcher sizes it as (threads / n_warps_per_tile) * (THREADS_PER_WARP + 1) OVecs.
  using IVec = Vec<IType, nvec_in>;
  using OVec = Vec<OType, nvec_out>;
  using CVec = Vec<CType, nvec_in>;

  extern __shared__ char scratch[];

  const int warp_id = threadIdx.x / THREADS_PER_WARP;
  const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
  const size_t num_tiles_x =
      (row_length + nvec_in * THREADS_PER_WARP - 1) / (nvec_in * THREADS_PER_WARP);
  // Each tile is handled cooperatively by n_warps_per_tile warps.
  const size_t tile_id =
      blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile;
  if (tile_id >= num_tiles) return;
  const size_t tile_id_x = tile_id % num_tiles_x;
  const size_t tile_id_y = tile_id / num_tiles_x;

  // input (grad) has row_length columns; act_input and output_c are gated, i.e.
  // 2 * row_length columns: [activation half | gate half].
  const IType *const my_input_tile =
      input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP;
  const IType *const my_act_input_tile =
      act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP;
  const IType *const my_gate_input_tile =
      act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP +
      row_length;
  OType *const my_output_c_tile_0 =
      output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP;
  OType *const my_output_c_tile_1 =
      output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP +
      row_length;
  OType *const my_output_t_tile_0 =
      output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP;
  OType *const my_output_t_tile_1 =
      output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP +
      row_length * num_rows;
  const size_t stride = row_length / nvec_in;
  const size_t stride2 = 2 * row_length / nvec_in;
  const size_t output_stride = num_rows / nvec_out;
  // Valid extent of this (possibly partial) edge tile, in vector units.
  const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP;
  const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP;
  const unsigned int tile_length =
      row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_length_rest;
  const unsigned int tile_height =
      row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_height_rest;
  // Per-warp-slot scratch with a +1 column pad to avoid shared-memory bank
  // conflicts during the transpose exchange.
  OVec *const my_scratch =
      reinterpret_cast<OVec *>(scratch) +
      (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);
  // Double-buffered registers: buffer (i+1)%2 is prefetched while (i)%2 is consumed.
  IVec in[2][nvec_out];
  IVec act_in[2][nvec_out];
  IVec gate_in[2][nvec_out];
  const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
  constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
  OVec out_space_0[n_iterations][nvec_in];
  OVec out_space_1[n_iterations][nvec_in];
  size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
  size_t current_stride2 = warp_id_in_tile * n_iterations * nvec_out * stride2;
  // Rotated lane offset so that consecutive iterations touch different banks/columns.
  unsigned int my_place =
      (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
  CType max = 0;
  const CType scale = scale_ptr != nullptr ? *scale_ptr : 1;
  // Only needed to satisfy cast_and_transpose_regs' interface: IS_DBIAS is false
  // below, so this value is never accumulated or stored.
  CVec partial_dbias;
  {
    // Prologue: prime buffer 0, clearing out-of-range lanes so the amax/compute
    // path sees zeros instead of garbage.
    const bool valid_load = my_place < tile_length && warp_id_in_tile * n_iterations < tile_height;
#pragma unroll
    for (unsigned int i = 0; i < nvec_out; ++i) {
      if (valid_load) {
        in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
        act_in[0][i].load_from(my_act_input_tile, current_stride2 + my_place + stride2 * i);
        gate_in[0][i].load_from(my_gate_input_tile, current_stride2 + my_place + stride2 * i);
      } else {
        in[0][i].clear();
        act_in[0][i].clear();
        gate_in[0][i].clear();
      }
    }
  }
#pragma unroll
  for (unsigned int i = 0; i < n_iterations; ++i) {
    const size_t current_place = current_stride2 + my_place;
    const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
    const unsigned int current_in = (i + 1) % 2;
    if (i < n_iterations - 1) {
      {
        // Prefetch the next iteration's rows into the other buffer (guarded).
        const bool valid_load =
            my_place_in < tile_length && warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll
        for (unsigned int j = 0; j < nvec_out; ++j) {
          if (valid_load) {
            in[current_in][j].load_from(my_input_tile,
                                        current_stride + my_place_in + stride * (nvec_out + j));
            act_in[current_in][j].load_from(
                my_act_input_tile, current_stride2 + my_place_in + stride2 * (nvec_out + j));
            gate_in[current_in][j].load_from(
                my_gate_input_tile, current_stride2 + my_place_in + stride2 * (nvec_out + j));
          } else {
            in[current_in][j].clear();
            act_in[current_in][j].clear();
            gate_in[current_in][j].clear();
          }
        }
      }
    }
    CVec after_dact[nvec_out];   // NOLINT(*)
    CVec after_dgate[nvec_out];  // NOLINT(*)
#pragma unroll
    for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
      for (unsigned int k = 0; k < nvec_in; ++k) {
        // dact = OP1(act) * grad * gate ; dgate = grad * OP2(act).
        after_dact[j].data.elt[k] = OP1(act_in[current_in ^ 1][j].data.elt[k], {}) *
                                    CType(in[current_in ^ 1][j].data.elt[k]) *
                                    CType(gate_in[current_in ^ 1][j].data.elt[k]);
        after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
                                     OP2(act_in[current_in ^ 1][j].data.elt[k], {});
      }
    }
    OVec out_trans_0[nvec_in];  // NOLINT(*)
    OVec out_trans_1[nvec_in];  // NOLINT(*)
    constexpr bool IS_DBIAS = false;
    constexpr bool IS_FULL_TILE = false;
    constexpr int dbias_shfl_src_lane = 0;
    const bool valid_store =
        (my_place < tile_length) && (warp_id_in_tile * n_iterations + i < tile_height);
    // Cast+store the non-transposed outputs and produce the transposed register
    // fragments; also folds this thread's absolute maxima into `max`.
    cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>(after_dact, out_trans_0, partial_dbias,
                                                    my_output_c_tile_0, current_place, stride2,
                                                    scale, max, dbias_shfl_src_lane, valid_store);
    cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>(after_dgate, out_trans_1, partial_dbias,
                                                    my_output_c_tile_1, current_place, stride2,
                                                    scale, max, dbias_shfl_src_lane, valid_store);
#pragma unroll
    for (unsigned int j = 0; j < nvec_in; ++j) {
      out_space_0[i][j].data.vec = out_trans_0[j].data.vec;
      out_space_1[i][j].data.vec = out_trans_1[j].data.vec;
    }
    my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
    current_stride += nvec_out * stride;
    current_stride2 += nvec_out * stride2;
  }

  // Transpose epilogue: exchange register fragments through shared memory and
  // write both transposed halves, with every store bounds-checked for edge tiles.
  for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
    for (unsigned int j = 0; j < n_iterations; ++j) {
      my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
                 THREADS_PER_WARP] = out_space_0[j][i];
    }
    __syncthreads();
    my_place =
        (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
    current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
    for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
      const bool valid_store = my_place < tile_height;
      if (valid_store) {
        my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_0,
                                                                current_stride + my_place);
      }
      my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
      current_stride += output_stride * nvec_in;
    }
    __syncthreads();
#pragma unroll
    for (unsigned int j = 0; j < n_iterations; ++j) {
      my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
                 THREADS_PER_WARP] = out_space_1[j][i];
    }
    __syncthreads();
    my_place =
        (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
    current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
    for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
      const bool valid_store = my_place < tile_height;
      if (valid_store) {
        my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_1,
                                                                current_stride + my_place);
      }
      my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
      current_stride += output_stride * nvec_in;
    }
    __syncthreads();
  }

  /* warp tile amax reduce*/
  max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);

  if (threadIdx.x == 0) {
    static_assert(std::is_same<CType, float>::value);
    if (amax != nullptr) {
      atomicMaxFloat(amax, max);
    }
    if (scale_inv != nullptr) {
      reciprocal<float>(scale_inv, scale);
    }
  }
}
template <typename ComputeType, typename ParamOP, ComputeType (*OP1)(ComputeType, const ParamOP &),
          ComputeType (*OP2)(ComputeType, const ParamOP &)>
void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input,
                               Tensor *cast_output, Tensor *transposed_output,
                               cudaStream_t stream) {
  // Host launcher for the fused gated-activation backward cast+transpose.
  //   input           : incoming gradient, shape [num_rows, row_length]
  //   gated_act_input : forward activation input, shape [num_rows, 2 * row_length]
  //   cast_output     : casted result, shape [num_rows, 2 * row_length]
  //   transposed_output: casted+transposed result, shape [2 * row_length, num_rows]
  // Validates shapes/dtypes and the shared scale/amax/scale_inv tensors, then
  // dispatches either the aligned (full-tile) or the bounds-checked kernel on
  // `stream`.
  CheckInputTensor(input, "dgated_act_cast_transpose_input");
  CheckInputTensor(gated_act_input, "dgated_act_cast_transpose_gated_act_input");
  CheckOutputTensor(*cast_output, "dgated_act_cast_transpose_cast_output");
  CheckOutputTensor(*transposed_output, "dgated_act_cast_transpose_transposed_output");

  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(gated_act_input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
  NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
  const size_t row_length = input.data.shape[1];
  const size_t num_rows = input.data.shape[0];

  NVTE_CHECK(gated_act_input.data.shape[0] == num_rows, "Wrong dimension of output.");
  NVTE_CHECK(gated_act_input.data.shape[1] == row_length * 2, "Wrong dimension of output.");
  NVTE_CHECK(cast_output->data.shape[0] == num_rows, "Wrong dimension of output.");
  NVTE_CHECK(cast_output->data.shape[1] == row_length * 2, "Wrong dimension of output.");
  NVTE_CHECK(transposed_output->data.shape[0] == row_length * 2, "Wrong dimension of T output.");
  NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");

  NVTE_CHECK(input.data.dtype == gated_act_input.data.dtype, "Types of both inputs must match.");

  NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
             "C and T outputs need to have the same type.");
  NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
             "C and T outputs need to share amax tensor.");
  NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
             "C and T outputs need to share scale tensor.");
  NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr,
             "C and T outputs need to share scale inverse tensor.");

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          cast_output->data.dtype, OutputType, using InputType2 = InputType;
          /* dact fusion kernel uses more registers */
          constexpr int desired_load_size_dact = 4;
          constexpr int desired_store_size_dact = 4; constexpr int itype_size = sizeof(InputType);
          constexpr int otype_size = sizeof(OutputType);
          constexpr int nvec_in = desired_load_size_dact / itype_size;
          constexpr int nvec_out = desired_store_size_dact / otype_size;

          NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
          NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
          const size_t n_tiles =
              DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
              DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
          const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
          const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);

          // Aligned path only when every tile is full; otherwise use the
          // bounds-checked "notaligned" kernel.
          const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
                                 num_rows % (nvec_out * THREADS_PER_WARP) == 0;
          // Dynamic shared memory: one (THREADS_PER_WARP + 1)-wide padded row of
          // OVecs per warp-slot participating in a tile.
          const size_t shmem_size = cast_transpose_num_threads / n_warps_per_tile *
                                    (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>);
          if (full_tile) {
            cudaFuncSetAttribute(
                dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType,
                                                 OutputType, Empty, OP1, OP2>,
                cudaFuncAttributePreferredSharedMemoryCarveout, 100);

            dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType, OutputType,
                                             Empty, OP1, OP2>
                <<<n_blocks, cast_transpose_num_threads, shmem_size, stream>>>(
                    reinterpret_cast<const InputType *>(input.data.dptr),
                    reinterpret_cast<const InputType *>(gated_act_input.data.dptr),
                    reinterpret_cast<OutputType *>(cast_output->data.dptr),
                    reinterpret_cast<OutputType *>(transposed_output->data.dptr),
                    reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
                    reinterpret_cast<fp32 *>(cast_output->amax.dptr),
                    reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr), row_length, num_rows,
                    n_tiles);
          } else {
            cudaFuncSetAttribute(
                dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType,
                                                            InputType, OutputType, Empty, OP1, OP2>,
                cudaFuncAttributePreferredSharedMemoryCarveout, 100);
            dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType, InputType,
                                                        OutputType, Empty, OP1, OP2>
                <<<n_blocks, cast_transpose_num_threads, shmem_size, stream>>>(
                    reinterpret_cast<const InputType *>(input.data.dptr),
                    reinterpret_cast<const InputType *>(gated_act_input.data.dptr),
                    reinterpret_cast<OutputType *>(cast_output->data.dptr),
                    reinterpret_cast<OutputType *>(transposed_output->data.dptr),
                    reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
                    reinterpret_cast<fp32 *>(cast_output->amax.dptr),
                    reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr), row_length, num_rows,
                    n_tiles);
          });  // NOLINT(*)
  );           // NOLINT(*)
}
} // namespace } // namespace
} // namespace transformer_engine } // namespace transformer_engine
using ComputeType = typename transformer_engine::fp32; using ComputeType = typename transformer_engine::fp32;
void nvte_cast_transpose_dbias(const NVTETensor input, void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor cast_output,
NVTETensor cast_output, NVTETensor transposed_output, NVTETensor dbias, NVTETensor workspace,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias); NVTE_API_CALL(nvte_cast_transpose_dbias);
using namespace transformer_engine; using namespace transformer_engine;
constexpr bool IS_DBIAS = true; constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = false; constexpr bool IS_DACT = false;
constexpr const NVTETensor activation_input = nullptr; constexpr const NVTETensor activation_input = nullptr;
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, nullptr>( cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, nullptr>(
*reinterpret_cast<const Tensor*>(input), *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(activation_input),
*reinterpret_cast<const Tensor*>(activation_input), reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor*>(cast_output), reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(dbias),
reinterpret_cast<Tensor*>(workspace),
stream);
} }
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input,
const NVTETensor act_input, NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor cast_output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTETensor transposed_output, NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu);
NVTETensor dbias, using namespace transformer_engine;
NVTETensor workspace,
cudaStream_t stream) { constexpr bool IS_DBIAS = true;
NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu); constexpr bool IS_DACT = true;
using namespace transformer_engine;
constexpr auto dActivation = &dgelu<fp32, fp32>;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true; cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(act_input),
constexpr auto dActivation = &dgelu<fp32, fp32>; reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(act_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(dbias),
reinterpret_cast<Tensor*>(workspace),
stream);
} }
void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor silu_input,
const NVTETensor silu_input, NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor cast_output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTETensor transposed_output, NVTE_API_CALL(nvte_cast_transpose_dbias_dsilu);
NVTETensor dbias, using namespace transformer_engine;
NVTETensor workspace,
cudaStream_t stream) { constexpr bool IS_DBIAS = true;
NVTE_API_CALL(nvte_cast_transpose_dbias_dsilu); constexpr bool IS_DACT = true;
using namespace transformer_engine;
constexpr auto dActivation = &dsilu<fp32, fp32>;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true; cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(silu_input),
constexpr auto dActivation = &dsilu<fp32, fp32>; reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(silu_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(dbias),
reinterpret_cast<Tensor*>(workspace),
stream);
} }
void nvte_cast_transpose_dbias_drelu(const NVTETensor input, void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor relu_input,
const NVTETensor relu_input, NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor cast_output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTETensor transposed_output, NVTE_API_CALL(nvte_cast_transpose_dbias_drelu);
NVTETensor dbias, using namespace transformer_engine;
NVTETensor workspace,
cudaStream_t stream) { constexpr bool IS_DBIAS = true;
NVTE_API_CALL(nvte_cast_transpose_dbias_drelu); constexpr bool IS_DACT = true;
using namespace transformer_engine;
constexpr auto dActivation = &drelu<fp32, fp32>;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true; cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(relu_input),
constexpr auto dActivation = &drelu<fp32, fp32>; reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(relu_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(dbias),
reinterpret_cast<Tensor*>(workspace),
stream);
} }
void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor srelu_input,
const NVTETensor srelu_input, NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor cast_output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTETensor transposed_output, NVTE_API_CALL(nvte_cast_transpose_dbias_dsrelu);
NVTETensor dbias, using namespace transformer_engine;
NVTETensor workspace,
cudaStream_t stream) { constexpr bool IS_DBIAS = true;
NVTE_API_CALL(nvte_cast_transpose_dbias_dsrelu); constexpr bool IS_DACT = true;
using namespace transformer_engine;
constexpr auto dActivation = &dsrelu<fp32, fp32>;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true; cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(srelu_input),
constexpr auto dActivation = &dsrelu<fp32, fp32>; reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(srelu_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(dbias),
reinterpret_cast<Tensor*>(workspace),
stream);
} }
void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor qgelu_input,
const NVTETensor qgelu_input, NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor cast_output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTETensor transposed_output, NVTE_API_CALL(nvte_cast_transpose_dbias_dqgelu);
NVTETensor dbias, using namespace transformer_engine;
NVTETensor workspace,
cudaStream_t stream) { constexpr bool IS_DBIAS = true;
NVTE_API_CALL(nvte_cast_transpose_dbias_dqgelu); constexpr bool IS_DACT = true;
using namespace transformer_engine;
constexpr auto dActivation = &dqgelu<fp32, fp32>;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true; cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(qgelu_input),
constexpr auto dActivation = &dqgelu<fp32, fp32>; reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(qgelu_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(dbias),
reinterpret_cast<Tensor*>(workspace),
stream);
} }
void nvte_dgeglu_cast_transpose(const NVTETensor input, void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
const NVTETensor gated_act_input, NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu_cast_transpose); NVTE_API_CALL(nvte_dgeglu_cast_transpose);
using namespace transformer_engine; using namespace transformer_engine;
constexpr auto dActivation = &dgelu<fp32, fp32>; constexpr auto dActivation = &dgelu<fp32, fp32>;
constexpr auto Activation = &gelu<fp32, fp32>; constexpr auto Activation = &gelu<fp32, fp32>;
dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>( dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
*reinterpret_cast<const Tensor*>(input), *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
*reinterpret_cast<const Tensor*>(gated_act_input), reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor*>(cast_output), stream);
reinterpret_cast<Tensor*>(transposed_output),
stream);
} }
void nvte_dswiglu_cast_transpose(const NVTETensor input, void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu_input,
const NVTETensor swiglu_input, NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_dswiglu_cast_transpose); NVTE_API_CALL(nvte_dswiglu_cast_transpose);
using namespace transformer_engine; using namespace transformer_engine;
constexpr auto dActivation = &dsilu<fp32, fp32>; constexpr auto dActivation = &dsilu<fp32, fp32>;
constexpr auto Activation = &silu<fp32, fp32>; constexpr auto Activation = &silu<fp32, fp32>;
dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>( dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
*reinterpret_cast<const Tensor*>(input), *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(swiglu_input),
*reinterpret_cast<const Tensor*>(swiglu_input), reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor*>(cast_output), stream);
reinterpret_cast<Tensor*>(transposed_output),
stream);
} }
void nvte_dreglu_cast_transpose(const NVTETensor input, void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
const NVTETensor gated_act_input, NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_dreglu_cast_transpose); NVTE_API_CALL(nvte_dreglu_cast_transpose);
using namespace transformer_engine; using namespace transformer_engine;
constexpr auto dActivation = &drelu<fp32, fp32>; constexpr auto dActivation = &drelu<fp32, fp32>;
constexpr auto Activation = &relu<fp32, fp32>; constexpr auto Activation = &relu<fp32, fp32>;
dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>( dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
*reinterpret_cast<const Tensor*>(input), *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
*reinterpret_cast<const Tensor*>(gated_act_input), reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor*>(cast_output), stream);
reinterpret_cast<Tensor*>(transposed_output),
stream);
} }
void nvte_dsreglu_cast_transpose(const NVTETensor input, void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
const NVTETensor gated_act_input, NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_dsreglu_cast_transpose); NVTE_API_CALL(nvte_dsreglu_cast_transpose);
using namespace transformer_engine; using namespace transformer_engine;
constexpr auto dActivation = &dsrelu<fp32, fp32>; constexpr auto dActivation = &dsrelu<fp32, fp32>;
constexpr auto Activation = &srelu<fp32, fp32>; constexpr auto Activation = &srelu<fp32, fp32>;
dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>( dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
*reinterpret_cast<const Tensor*>(input), *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
*reinterpret_cast<const Tensor*>(gated_act_input), reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor*>(cast_output), stream);
reinterpret_cast<Tensor*>(transposed_output),
stream);
} }
void nvte_dqgeglu_cast_transpose(const NVTETensor input, void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
const NVTETensor gated_act_input, NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgeglu_cast_transpose); NVTE_API_CALL(nvte_dqgeglu_cast_transpose);
using namespace transformer_engine; using namespace transformer_engine;
constexpr auto dActivation = &dqgelu<fp32, fp32>; constexpr auto dActivation = &dqgelu<fp32, fp32>;
constexpr auto Activation = &qgelu<fp32, fp32>; constexpr auto Activation = &qgelu<fp32, fp32>;
dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>( dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
*reinterpret_cast<const Tensor*>(input), *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
*reinterpret_cast<const Tensor*>(gated_act_input), reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor*>(cast_output), stream);
reinterpret_cast<Tensor*>(transposed_output),
stream);
} }
...@@ -4,13 +4,15 @@ ...@@ -4,13 +4,15 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <transformer_engine/transpose.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <iostream> #include <transformer_engine/transpose.h>
#include <cfloat> #include <cfloat>
#include <iostream>
#include <vector> #include <vector>
#include "../utils.cuh"
#include "../common.h" #include "../common.h"
#include "../utils.cuh"
namespace transformer_engine { namespace transformer_engine {
...@@ -40,21 +42,14 @@ struct MultiCastTransposeArgs { ...@@ -40,21 +42,14 @@ struct MultiCastTransposeArgs {
int row_length_list[kMaxTensorsPerKernel]; int row_length_list[kMaxTensorsPerKernel];
// Prefix sum (with leading zero) of CUDA blocks needed for each // Prefix sum (with leading zero) of CUDA blocks needed for each
// tensor // tensor
int block_range[kMaxTensorsPerKernel+1]; int block_range[kMaxTensorsPerKernel + 1];
// Number of tensors being processed by kernel // Number of tensors being processed by kernel
int num_tensors; int num_tensors;
}; };
template < template <int nvec_in, int nvec_out, bool aligned, typename CType, typename IType, typename OType>
int nvec_in, __global__ void __launch_bounds__(threads_per_block)
int nvec_out, multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
bool aligned,
typename CType,
typename IType,
typename OType>
__global__ void
__launch_bounds__(threads_per_block)
multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
using IVec = Vec<IType, nvec_in>; using IVec = Vec<IType, nvec_in>;
using OVecC = Vec<OType, nvec_in>; using OVecC = Vec<OType, nvec_in>;
using OVecT = Vec<OType, nvec_out>; using OVecT = Vec<OType, nvec_out>;
...@@ -79,7 +74,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) { ...@@ -79,7 +74,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
// Find tensor corresponding to block // Find tensor corresponding to block
int tensor_id = 0; int tensor_id = 0;
while (args.block_range[tensor_id+1] <= bid) { while (args.block_range[tensor_id + 1] <= bid) {
++tensor_id; ++tensor_id;
} }
const IType* input = reinterpret_cast<const IType*>(args.input_list[tensor_id]); const IType* input = reinterpret_cast<const IType*>(args.input_list[tensor_id]);
...@@ -104,11 +99,11 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) { ...@@ -104,11 +99,11 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
// type, and transposes in registers. // type, and transposes in registers.
OVecT local_output_t[nvec_in][n_iterations]; OVecT local_output_t[nvec_in][n_iterations];
CType local_amax = 0; CType local_amax = 0;
#pragma unroll #pragma unroll
for (int iter = 0; iter < n_iterations; ++iter) { for (int iter = 0; iter < n_iterations; ++iter) {
const int i1 = tidy + iter * bdimy; const int i1 = tidy + iter * bdimy;
const int j1 = tidx; const int j1 = tidx;
#pragma unroll #pragma unroll
for (int i2 = 0; i2 < nvec_out; ++i2) { for (int i2 = 0; i2 < nvec_out; ++i2) {
const int row = tile_row + i1 * nvec_out + i2; const int row = tile_row + i1 * nvec_out + i2;
const int col = tile_col + j1 * nvec_in; const int col = tile_col + j1 * nvec_in;
...@@ -119,7 +114,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) { ...@@ -119,7 +114,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
} else { } else {
local_input.clear(); local_input.clear();
if (row < num_rows) { if (row < num_rows) {
#pragma unroll #pragma unroll
for (int j2 = 0; j2 < nvec_in; ++j2) { for (int j2 = 0; j2 < nvec_in; ++j2) {
if (col + j2 < row_length) { if (col + j2 < row_length) {
local_input.data.elt[j2] = input[row * row_length + col + j2]; local_input.data.elt[j2] = input[row * row_length + col + j2];
...@@ -127,7 +122,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) { ...@@ -127,7 +122,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
} }
} }
} }
#pragma unroll #pragma unroll
for (int j2 = 0; j2 < nvec_in; ++j2) { for (int j2 = 0; j2 < nvec_in; ++j2) {
const CType x = CType(local_input.data.elt[j2]); const CType x = CType(local_input.data.elt[j2]);
const OType y = OType(scale * x); const OType y = OType(scale * x);
...@@ -140,7 +135,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) { ...@@ -140,7 +135,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
local_output_c.store_to(&output_c[row * row_length + col]); local_output_c.store_to(&output_c[row * row_length + col]);
} else { } else {
if (row < num_rows) { if (row < num_rows) {
#pragma unroll #pragma unroll
for (int j2 = 0; j2 < nvec_in; ++j2) { for (int j2 = 0; j2 < nvec_in; ++j2) {
if (col + j2 < row_length) { if (col + j2 < row_length) {
output_c[row * row_length + col + j2] = local_output_c.data.elt[j2]; output_c[row * row_length + col + j2] = local_output_c.data.elt[j2];
...@@ -152,17 +147,17 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) { ...@@ -152,17 +147,17 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
} }
// Copy transposed output from registers to global memory // Copy transposed output from registers to global memory
__shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP+1]; __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#pragma unroll #pragma unroll
for (int j2 = 0; j2 < nvec_in; ++j2) { for (int j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll #pragma unroll
for (int iter = 0; iter < n_iterations; ++iter) { for (int iter = 0; iter < n_iterations; ++iter) {
const int i1 = tidy + iter * bdimy; const int i1 = tidy + iter * bdimy;
const int j1 = tidx; const int j1 = tidx;
shared_output_t[j1][i1] = local_output_t[j2][iter]; shared_output_t[j1][i1] = local_output_t[j2][iter];
} }
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
for (int iter = 0; iter < n_iterations; ++iter) { for (int iter = 0; iter < n_iterations; ++iter) {
const int i1 = tidx; const int i1 = tidx;
const int j1 = tidy + iter * bdimy; const int j1 = tidy + iter * bdimy;
...@@ -172,7 +167,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) { ...@@ -172,7 +167,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
shared_output_t[j1][i1].store_to(&output_t[col * num_rows + row]); shared_output_t[j1][i1].store_to(&output_t[col * num_rows + row]);
} else { } else {
if (col < row_length) { if (col < row_length) {
#pragma unroll #pragma unroll
for (int i2 = 0; i2 < nvec_out; ++i2) { for (int i2 = 0; i2 < nvec_out; ++i2) {
if (row + i2 < num_rows) { if (row + i2 < num_rows) {
output_t[col * num_rows + row + i2] = shared_output_t[j1][i1].data.elt[i2]; output_t[col * num_rows + row + i2] = shared_output_t[j1][i1].data.elt[i2];
...@@ -196,8 +191,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) { ...@@ -196,8 +191,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
void multi_cast_transpose(const std::vector<Tensor*> input_list, void multi_cast_transpose(const std::vector<Tensor*> input_list,
std::vector<Tensor*> cast_output_list, std::vector<Tensor*> cast_output_list,
std::vector<Tensor*> transposed_output_list, std::vector<Tensor*> transposed_output_list, cudaStream_t stream) {
cudaStream_t stream) {
// Check that number of tensors is valid // Check that number of tensors is valid
NVTE_CHECK(cast_output_list.size() == input_list.size(), NVTE_CHECK(cast_output_list.size() == input_list.size(),
"Number of input and C output tensors must match"); "Number of input and C output tensors must match");
...@@ -218,15 +212,11 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, ...@@ -218,15 +212,11 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
CheckInputTensor(cast_output, "multi_cast_output_" + std::to_string(tensor_id)); CheckInputTensor(cast_output, "multi_cast_output_" + std::to_string(tensor_id));
CheckInputTensor(transposed_output, "multi_transpose_output_" + std::to_string(tensor_id)); CheckInputTensor(transposed_output, "multi_transpose_output_" + std::to_string(tensor_id));
NVTE_CHECK(input.data.dtype == itype, NVTE_CHECK(input.data.dtype == itype, "Input tensor types do not match.");
"Input tensor types do not match."); NVTE_CHECK(cast_output.data.dtype == otype, "C output tensor types do not match.");
NVTE_CHECK(cast_output.data.dtype == otype, NVTE_CHECK(transposed_output.data.dtype == otype, "T output tensor types do not match.");
"C output tensor types do not match.");
NVTE_CHECK(transposed_output.data.dtype == otype,
"T output tensor types do not match.");
NVTE_CHECK(input.data.shape.size() == 2, NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions.");
"Input tensor must have 2 dimensions.");
NVTE_CHECK(cast_output.data.shape == input.data.shape, NVTE_CHECK(cast_output.data.shape == input.data.shape,
"C output tensor shape does not match input tensor."); "C output tensor shape does not match input tensor.");
NVTE_CHECK(transposed_output.data.shape.size() == 2, NVTE_CHECK(transposed_output.data.shape.size() == 2,
...@@ -251,27 +241,28 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, ...@@ -251,27 +241,28 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
// Launch kernel if argument struct is full // Launch kernel if argument struct is full
if (kernel_args_aligned.num_tensors == kMaxTensorsPerKernel) { if (kernel_args_aligned.num_tensors == kMaxTensorsPerKernel) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(itype, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(otype, OutputType, itype, InputType,
constexpr int nvec_in = desired_load_size / sizeof(InputType); TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
constexpr int nvec_out = desired_store_size / sizeof(OutputType); otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType);
const int n_blocks = kernel_args_aligned.block_range[kernel_args_aligned.num_tensors]; constexpr int nvec_out = desired_store_size / sizeof(OutputType);
multi_cast_transpose_kernel<nvec_in, nvec_out, true, fp32, InputType, OutputType> const int n_blocks = kernel_args_aligned.block_range[kernel_args_aligned.num_tensors];
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_aligned); multi_cast_transpose_kernel<nvec_in, nvec_out, true, fp32, InputType, OutputType>
); // NOLINT(*) <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_aligned);); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
kernel_args_aligned.num_tensors = 0; kernel_args_aligned.num_tensors = 0;
} }
if (kernel_args_unaligned.num_tensors == kMaxTensorsPerKernel) { if (kernel_args_unaligned.num_tensors == kMaxTensorsPerKernel) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(itype, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(otype, OutputType, itype, InputType,
constexpr int nvec_in = desired_load_size / sizeof(InputType); TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
constexpr int nvec_out = desired_store_size / sizeof(OutputType); otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType);
const int n_blocks = kernel_args_unaligned.block_range[kernel_args_unaligned.num_tensors]; constexpr int nvec_out = desired_store_size / sizeof(OutputType);
multi_cast_transpose_kernel<nvec_in, nvec_out, false, fp32, InputType, OutputType> const int n_blocks =
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_unaligned); kernel_args_unaligned.block_range[kernel_args_unaligned.num_tensors];
); // NOLINT(*) multi_cast_transpose_kernel<nvec_in, nvec_out, false, fp32, InputType, OutputType>
); // NOLINT(*) <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_unaligned);); // NOLINT(*)
); // NOLINT(*)
kernel_args_unaligned.num_tensors = 0; kernel_args_unaligned.num_tensors = 0;
} }
...@@ -283,8 +274,8 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, ...@@ -283,8 +274,8 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
const int num_tiles = num_tiles_m * num_tiles_n; const int num_tiles = num_tiles_m * num_tiles_n;
// Figure out whether to use aligned or unaligned kernel // Figure out whether to use aligned or unaligned kernel
const bool aligned = ((num_tiles_m * tile_dim_m == num_rows) const bool aligned =
&& (num_tiles_n * tile_dim_n == row_length)); ((num_tiles_m * tile_dim_m == num_rows) && (num_tiles_n * tile_dim_n == row_length));
auto& kernel_args = aligned ? kernel_args_aligned : kernel_args_unaligned; auto& kernel_args = aligned ? kernel_args_aligned : kernel_args_unaligned;
// Add tensor to kernel argument struct // Add tensor to kernel argument struct
...@@ -296,53 +287,48 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, ...@@ -296,53 +287,48 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr; kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr;
kernel_args.num_rows_list[pos] = num_rows; kernel_args.num_rows_list[pos] = num_rows;
kernel_args.row_length_list[pos] = row_length; kernel_args.row_length_list[pos] = row_length;
kernel_args.block_range[pos+1] = kernel_args.block_range[pos] + num_tiles; kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles;
kernel_args.num_tensors++; kernel_args.num_tensors++;
} }
// Launch kernel // Launch kernel
if (kernel_args_aligned.num_tensors > 0) { if (kernel_args_aligned.num_tensors > 0) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(itype, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(otype, OutputType, itype, InputType,
constexpr int nvec_in = desired_load_size / sizeof(InputType); TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
constexpr int nvec_out = desired_store_size / sizeof(OutputType); otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType);
const int n_blocks = kernel_args_aligned.block_range[kernel_args_aligned.num_tensors]; constexpr int nvec_out = desired_store_size / sizeof(OutputType);
multi_cast_transpose_kernel<nvec_in, nvec_out, true, fp32, InputType, OutputType> const int n_blocks = kernel_args_aligned.block_range[kernel_args_aligned.num_tensors];
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_aligned); multi_cast_transpose_kernel<nvec_in, nvec_out, true, fp32, InputType, OutputType>
); // NOLINT(*) <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_aligned);); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
} }
if (kernel_args_unaligned.num_tensors > 0) { if (kernel_args_unaligned.num_tensors > 0) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(itype, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(otype, OutputType, itype, InputType,
constexpr int nvec_in = desired_load_size / sizeof(InputType); TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
constexpr int nvec_out = desired_store_size / sizeof(OutputType); otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType);
const int n_blocks = kernel_args_unaligned.block_range[kernel_args_unaligned.num_tensors]; constexpr int nvec_out = desired_store_size / sizeof(OutputType);
multi_cast_transpose_kernel<nvec_in, nvec_out, false, fp32, InputType, OutputType> const int n_blocks =
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_unaligned); kernel_args_unaligned.block_range[kernel_args_unaligned.num_tensors];
); // NOLINT(*) multi_cast_transpose_kernel<nvec_in, nvec_out, false, fp32, InputType, OutputType>
); // NOLINT(*) <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_unaligned);); // NOLINT(*)
); // NOLINT(*)
} }
} }
} // namespace transformer_engine } // namespace transformer_engine
void nvte_multi_cast_transpose(size_t num_tensors, void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list,
const NVTETensor* input_list, NVTETensor* cast_output_list, NVTETensor* transposed_output_list,
NVTETensor* cast_output_list,
NVTETensor* transposed_output_list,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_cast_transpose); NVTE_API_CALL(nvte_multi_cast_transpose);
using namespace transformer_engine; using namespace transformer_engine;
std::vector<Tensor*> input_list_, std::vector<Tensor*> input_list_, cast_output_list_, transposed_output_list_;
cast_output_list_, transposed_output_list_;
for (size_t i = 0; i < num_tensors; ++i) { for (size_t i = 0; i < num_tensors; ++i) {
input_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(input_list[i]))); input_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(input_list[i])));
cast_output_list_.push_back(reinterpret_cast<Tensor*>(cast_output_list[i])); cast_output_list_.push_back(reinterpret_cast<Tensor*>(cast_output_list[i]));
transposed_output_list_.push_back(reinterpret_cast<Tensor*>(transposed_output_list[i])); transposed_output_list_.push_back(reinterpret_cast<Tensor*>(transposed_output_list[i]));
} }
multi_cast_transpose(input_list_, multi_cast_transpose(input_list_, cast_output_list_, transposed_output_list_, stream);
cast_output_list_,
transposed_output_list_,
stream);
} }
...@@ -21,16 +21,11 @@ constexpr size_t block_size = __BLOCK_SIZE__; ...@@ -21,16 +21,11 @@ constexpr size_t block_size = __BLOCK_SIZE__;
} // namespace } // namespace
__global__ void __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel(
__launch_bounds__(block_size) const IType* __restrict__ const input, const CType* __restrict__ const noop,
cast_transpose_optimized_kernel(const IType * __restrict__ const input, OType* __restrict__ const output_c, OType* __restrict__ const output_t,
const CType * __restrict__ const noop, const CType* __restrict__ const scale_ptr, CType* __restrict__ const amax_ptr,
OType * __restrict__ const output_c, const size_t row_length, const size_t num_rows) {
OType * __restrict__ const output_t,
const CType * __restrict__ const scale_ptr,
CType * __restrict__ const amax_ptr,
const size_t row_length,
const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return; if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes // Vectorized load/store sizes
...@@ -73,18 +68,18 @@ cast_transpose_optimized_kernel(const IType * __restrict__ const input, ...@@ -73,18 +68,18 @@ cast_transpose_optimized_kernel(const IType * __restrict__ const input,
// Note: Each thread loads num_iterations subtiles, computes amax, // Note: Each thread loads num_iterations subtiles, computes amax,
// casts type, and transposes in registers. // casts type, and transposes in registers.
OVecT local_output_t[nvec_in][num_iterations]; OVecT local_output_t[nvec_in][num_iterations];
#pragma unroll #pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) { for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy; const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx; const size_t j1 = tidx;
#pragma unroll #pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) { for (size_t i2 = 0; i2 < nvec_out; ++i2) {
const size_t row = tile_row + i1 * nvec_out + i2; const size_t row = tile_row + i1 * nvec_out + i2;
const size_t col = tile_col + j1 * nvec_in; const size_t col = tile_col + j1 * nvec_in;
IVec local_input; IVec local_input;
OVecC local_output_c; OVecC local_output_c;
local_input.load_from(&input[row * row_length + col]); local_input.load_from(&input[row * row_length + col]);
#pragma unroll #pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) { for (size_t j2 = 0; j2 < nvec_in; ++j2) {
const CType in = static_cast<CType>(local_input.data.elt[j2]); const CType in = static_cast<CType>(local_input.data.elt[j2]);
const OType out = OType(in * scale); const OType out = OType(in * scale);
...@@ -98,17 +93,17 @@ cast_transpose_optimized_kernel(const IType * __restrict__ const input, ...@@ -98,17 +93,17 @@ cast_transpose_optimized_kernel(const IType * __restrict__ const input,
} }
// Copy from registers to shared memory to global memory // Copy from registers to shared memory to global memory
__shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP+1]; __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#pragma unroll #pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) { for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll #pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) { for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy; const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx; const size_t j1 = tidx;
shared_output_t[j1][i1] = local_output_t[j2][iter]; shared_output_t[j1][i1] = local_output_t[j2][iter];
} }
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) { for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidx; const size_t i1 = tidx;
const size_t j1 = tidy + iter * bdimy; const size_t j1 = tidy + iter * bdimy;
......
...@@ -4,25 +4,25 @@ ...@@ -4,25 +4,25 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "utils.cuh"
#include "util/math.h" #include "util/math.h"
#include "utils.cuh"
using namespace transformer_engine; using namespace transformer_engine;
namespace { namespace {
// Parameters // Parameters
using CType = float; using CType = float;
using IType = __ITYPE__; using IType = __ITYPE__;
using IType2 = __ITYPE2__; using IType2 = __ITYPE2__;
using OType = __OTYPE__; using OType = __OTYPE__;
constexpr size_t LOAD_SIZE = __LOAD_SIZE__; constexpr size_t LOAD_SIZE = __LOAD_SIZE__;
constexpr size_t STORE_SIZE = __STORE_SIZE__; constexpr size_t STORE_SIZE = __STORE_SIZE__;
constexpr size_t WARPS_PER_TILE = __WARPS_PER_TILE__; constexpr size_t WARPS_PER_TILE = __WARPS_PER_TILE__;
constexpr size_t BLOCK_SIZE = __BLOCK_SIZE__; constexpr size_t BLOCK_SIZE = __BLOCK_SIZE__;
constexpr bool IS_DBIAS = __IS_DBIAS__; constexpr bool IS_DBIAS = __IS_DBIAS__;
constexpr bool IS_DACT = __IS_DACT__; constexpr bool IS_DACT = __IS_DACT__;
constexpr size_t DACT_TYPE = __DACTIVATION_TYPE__; constexpr size_t DACT_TYPE = __DACTIVATION_TYPE__;
constexpr size_t NVEC_IN = LOAD_SIZE / sizeof(IType); constexpr size_t NVEC_IN = LOAD_SIZE / sizeof(IType);
constexpr size_t NVEC_OUT = STORE_SIZE / sizeof(OType); constexpr size_t NVEC_OUT = STORE_SIZE / sizeof(OType);
...@@ -32,218 +32,209 @@ using IVec2 = Vec<IType2, NVEC_IN>; ...@@ -32,218 +32,209 @@ using IVec2 = Vec<IType2, NVEC_IN>;
using OVec = Vec<OType, NVEC_OUT>; using OVec = Vec<OType, NVEC_OUT>;
using Param = CTDBiasDActParam<IType, IType2, OType, CType>; using Param = CTDBiasDActParam<IType, IType2, OType, CType>;
using OP = CType (*)(const CType, const Empty&); using OP = CType (*)(const CType, const Empty &);
constexpr OP Activation[] = { constexpr OP Activation[] = {
nullptr, // 0 nullptr, // 0
&dsigmoid<CType, CType>, // 1 &dsigmoid<CType, CType>, // 1
&dgelu<CType, CType>, // 2 &dgelu<CType, CType>, // 2
&dqgelu<CType, CType>, // 3 &dqgelu<CType, CType>, // 3
&dsilu<CType, CType>, // 4 &dsilu<CType, CType>, // 4
&drelu<CType, CType>, // 5 &drelu<CType, CType>, // 5
&dsrelu<CType, CType> // 6 &dsrelu<CType, CType> // 6
}; };
} // namespace } // namespace
inline __device__ void inline __device__ void cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_OUT],
cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_OUT], OVec (&out_trans)[NVEC_IN],
OVec (&out_trans)[NVEC_IN], CVec &out_dbias, // NOLINT(*)
CVec &out_dbias, // NOLINT(*) typename OVec::type *output_cast_tile,
typename OVec::type *output_cast_tile, const size_t current_place,
const size_t current_place, const size_t stride, const CType scale,
const size_t stride, CType &amax, // NOLINT(*)
const CType scale, const int dbias_shfl_src_lane) {
CType &amax, // NOLINT(*) using OVecC = Vec<OType, NVEC_IN>;
const int dbias_shfl_src_lane) {
using OVecC = Vec<OType, NVEC_IN>; CVec step_dbias;
if constexpr (IS_DBIAS) {
CVec step_dbias; step_dbias.clear();
if constexpr (IS_DBIAS) { }
step_dbias.clear();
} #pragma unroll
for (unsigned int i = 0; i < NVEC_OUT; ++i) {
#pragma unroll OVecC out_cast;
for (unsigned int i = 0; i < NVEC_OUT; ++i) { #pragma unroll
OVecC out_cast; for (unsigned int j = 0; j < NVEC_IN; ++j) {
#pragma unroll const CType tmp = in[i].data.elt[j];
for (unsigned int j = 0; j < NVEC_IN; ++j) { if constexpr (IS_DBIAS) {
const CType tmp = in[i].data.elt[j]; step_dbias.data.elt[j] += tmp; // dbias: thread tile local accumulation
if constexpr (IS_DBIAS) { }
step_dbias.data.elt[j] += tmp; // dbias: thread tile local accumulation out_cast.data.elt[j] = static_cast<OType>(tmp * scale);
} out_trans[j].data.elt[i] = static_cast<OType>(tmp * scale); // thread tile transpose
out_cast.data.elt[j] = static_cast<OType>(tmp * scale);
out_trans[j].data.elt[i] = static_cast<OType>(tmp * scale); // thread tile transpose __builtin_assume(amax >= 0);
amax = fmaxf(fabsf(tmp), amax);
__builtin_assume(amax >= 0);
amax = fmaxf(fabsf(tmp), amax);
}
out_cast.store_to(output_cast_tile, current_place + stride * i);
} }
out_cast.store_to(output_cast_tile, current_place + stride * i);
if constexpr (IS_DBIAS) { }
#pragma unroll
for (unsigned int j = 0; j < NVEC_IN; ++j) { if constexpr (IS_DBIAS) {
CType elt = step_dbias.data.elt[j]; #pragma unroll
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp for (unsigned int j = 0; j < NVEC_IN; ++j) {
out_dbias.data.elt[j] += elt; CType elt = step_dbias.data.elt[j];
} elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp
out_dbias.data.elt[j] += elt;
} }
}
} }
__global__ void __global__ void __launch_bounds__(BLOCK_SIZE)
__launch_bounds__(BLOCK_SIZE) cast_transpose_fusion_kernel_optimized(const Param param, const size_t row_length,
cast_transpose_fusion_kernel_optimized(const Param param, const size_t num_rows, const size_t num_tiles) {
const size_t row_length, extern __shared__ char scratch[];
const size_t num_rows,
const size_t num_tiles) { const int warp_id = threadIdx.x / THREADS_PER_WARP;
extern __shared__ char scratch[]; const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = row_length / (NVEC_IN * THREADS_PER_WARP);
const int warp_id = threadIdx.x / THREADS_PER_WARP; const size_t tile_id =
const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; blockIdx.x * blockDim.x / (THREADS_PER_WARP * WARPS_PER_TILE) + warp_id / WARPS_PER_TILE;
const size_t num_tiles_x = row_length / (NVEC_IN * THREADS_PER_WARP); if (tile_id >= num_tiles) {
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * WARPS_PER_TILE) return;
+ warp_id / WARPS_PER_TILE; }
if (tile_id >= num_tiles) {
return; const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const size_t tile_offset =
(tile_id_x * NVEC_IN + tile_id_y * row_length * NVEC_OUT) * THREADS_PER_WARP;
const size_t tile_offset_transp =
(tile_id_y * NVEC_OUT + tile_id_x * num_rows * NVEC_IN) * THREADS_PER_WARP;
const IType *const my_input_tile = param.input + tile_offset;
const IType2 *const my_act_input_tile = param.act_input + tile_offset;
OType *const my_output_c_tile = param.output_c + tile_offset;
OType *const my_output_t_tile = param.output_t + tile_offset_transp;
CType *const my_partial_dbias_tile =
param.workspace + (tile_id_x * (NVEC_IN * THREADS_PER_WARP) + tile_id_y * row_length);
OVec *const my_scratch =
reinterpret_cast<OVec *>(scratch) +
(my_id_in_warp + warp_id / WARPS_PER_TILE * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);
CVec *const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][NVEC_OUT];
IVec2 act_in[2][NVEC_OUT];
const unsigned int warp_id_in_tile = warp_id % WARPS_PER_TILE;
constexpr unsigned int n_iterations = THREADS_PER_WARP / WARPS_PER_TILE;
OVec out_space[n_iterations][NVEC_IN];
const size_t stride = row_length / NVEC_IN;
const size_t output_stride = num_rows / NVEC_OUT;
size_t current_stride = warp_id_in_tile * n_iterations * NVEC_OUT * stride;
size_t current_row = (tile_id_y * THREADS_PER_WARP + warp_id_in_tile * n_iterations) * NVEC_OUT;
unsigned int my_place =
(my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
CType amax = 0.0f;
const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
CVec partial_dbias;
if constexpr (IS_DBIAS) {
partial_dbias.clear();
}
#pragma unroll
for (unsigned int i = 0; i < NVEC_OUT; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
if constexpr (IS_DACT) {
act_in[0][i].load_from(my_act_input_tile, current_stride + my_place + stride * i);
} }
}
const size_t tile_id_x = tile_id % num_tiles_x; #pragma unroll
const size_t tile_id_y = tile_id / num_tiles_x; for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const size_t tile_offset = (tile_id_x * NVEC_IN + tile_id_y * row_length * NVEC_OUT) const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
* THREADS_PER_WARP; const unsigned int current_in = (i + 1) % 2;
const size_t tile_offset_transp = (tile_id_y * NVEC_OUT + tile_id_x * num_rows * NVEC_IN) if (i < n_iterations - 1) {
* THREADS_PER_WARP; #pragma unroll
for (unsigned int j = 0; j < NVEC_OUT; ++j) {
const IType * const my_input_tile = param.input + tile_offset; const size_t ld_offset = current_stride + my_place_in + stride * (NVEC_OUT + j);
const IType2 * const my_act_input_tile = param.act_input + tile_offset; in[current_in][j].load_from(my_input_tile, ld_offset);
OType * const my_output_c_tile = param.output_c + tile_offset;
OType * const my_output_t_tile = param.output_t + tile_offset_transp;
CType * const my_partial_dbias_tile = param.workspace
+ (tile_id_x * (NVEC_IN * THREADS_PER_WARP)
+ tile_id_y * row_length);
OVec * const my_scratch = reinterpret_cast<OVec *>(scratch)
+ (my_id_in_warp + warp_id / WARPS_PER_TILE * THREADS_PER_WARP)
* (THREADS_PER_WARP + 1);
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][NVEC_OUT];
IVec2 act_in[2][NVEC_OUT];
const unsigned int warp_id_in_tile = warp_id % WARPS_PER_TILE;
constexpr unsigned int n_iterations = THREADS_PER_WARP / WARPS_PER_TILE;
OVec out_space[n_iterations][NVEC_IN];
const size_t stride = row_length / NVEC_IN;
const size_t output_stride = num_rows / NVEC_OUT;
size_t current_stride = warp_id_in_tile * n_iterations * NVEC_OUT * stride;
size_t current_row = (tile_id_y * THREADS_PER_WARP + warp_id_in_tile * n_iterations) * NVEC_OUT;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations)
% THREADS_PER_WARP;
CType amax = 0.0f;
const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
CVec partial_dbias;
if constexpr (IS_DBIAS) {
partial_dbias.clear();
}
#pragma unroll
for (unsigned int i = 0; i < NVEC_OUT; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
if constexpr (IS_DACT) { if constexpr (IS_DACT) {
act_in[0][i].load_from(my_act_input_tile, current_stride + my_place + stride * i); act_in[current_in][j].load_from(my_act_input_tile, ld_offset);
} }
}
} }
#pragma unroll CVec in_cast_fp32[NVEC_OUT]; // NOLINT(*)
for (unsigned int i = 0; i < n_iterations; ++i) { #pragma unroll
const size_t current_place = current_stride + my_place; for (unsigned int j = 0; j < NVEC_OUT; ++j) {
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; #pragma unroll
const unsigned int current_in = (i + 1) % 2; for (unsigned int k = 0; k < NVEC_IN; ++k) {
if (i < n_iterations - 1) { if constexpr (IS_DACT) {
#pragma unroll in_cast_fp32[j].data.elt[k] =
for (unsigned int j = 0; j < NVEC_OUT; ++j) { static_cast<CType>(in[current_in ^ 1][j].data.elt[k]) *
const size_t ld_offset = current_stride + my_place_in + stride * (NVEC_OUT + j); Activation[DACT_TYPE](act_in[current_in ^ 1][j].data.elt[k], {});
in[current_in][j].load_from(my_input_tile, ld_offset); } else {
if constexpr (IS_DACT) { in_cast_fp32[j].data.elt[k] = static_cast<CType>(in[current_in ^ 1][j].data.elt[k]);
act_in[current_in][j].load_from(my_act_input_tile, ld_offset);
}
}
}
CVec in_cast_fp32[NVEC_OUT]; // NOLINT(*)
#pragma unroll
for (unsigned int j = 0; j < NVEC_OUT; ++j) {
#pragma unroll
for (unsigned int k = 0; k < NVEC_IN; ++k) {
if constexpr (IS_DACT) {
in_cast_fp32[j].data.elt[k] =
static_cast<CType>(in[current_in ^ 1][j].data.elt[k])
* Activation[DACT_TYPE](act_in[current_in ^ 1][j].data.elt[k], {});
} else {
in_cast_fp32[j].data.elt[k] =
static_cast<CType>(in[current_in ^ 1][j].data.elt[k]);
}
}
} }
}
}
const int dbias_shfl_src_lane = (my_id_in_warp + i + warp_id_in_tile * n_iterations) const int dbias_shfl_src_lane =
% THREADS_PER_WARP; (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
cast_and_transpose_regs_optimized(in_cast_fp32, out_space[i], partial_dbias, cast_and_transpose_regs_optimized(in_cast_fp32, out_space[i], partial_dbias, my_output_c_tile,
my_output_c_tile, current_place, current_place, stride, scale, amax, dbias_shfl_src_lane);
stride, scale, amax, dbias_shfl_src_lane);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += NVEC_OUT * stride; current_stride += NVEC_OUT * stride;
current_row += NVEC_OUT; current_row += NVEC_OUT;
} }
#pragma unroll #pragma unroll
for (unsigned int i = 0; i < NVEC_IN; ++i) { for (unsigned int i = 0; i < NVEC_IN; ++i) {
#pragma unroll #pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) { for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
- j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i]; THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations)
% THREADS_PER_WARP;
current_stride = i * output_stride
+ warp_id_in_tile * n_iterations * output_stride * NVEC_IN;
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * NVEC_IN;
}
__syncthreads();
} }
__syncthreads();
if constexpr (IS_DBIAS) { my_place =
my_dbias_scratch[threadIdx.x] = partial_dbias; (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
__syncthreads(); current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * NVEC_IN;
if (warp_id_in_tile == 0) { for (unsigned int j = 0; j < n_iterations; ++j) {
#pragma unroll my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
for (unsigned int i = 1; i < WARPS_PER_TILE; ++i) { current_stride + my_place);
CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP]; my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
#pragma unroll current_stride += output_stride * NVEC_IN;
for (unsigned int j = 0; j < NVEC_IN; ++j) { }
partial_dbias.data.elt[j] += tmp.data.elt[j]; __syncthreads();
} }
}
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp); if constexpr (IS_DBIAS) {
my_dbias_scratch[threadIdx.x] = partial_dbias;
__syncthreads();
if (warp_id_in_tile == 0) {
#pragma unroll
for (unsigned int i = 1; i < WARPS_PER_TILE; ++i) {
CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
#pragma unroll
for (unsigned int j = 0; j < NVEC_IN; ++j) {
partial_dbias.data.elt[j] += tmp.data.elt[j];
} }
}
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
} }
}
// warp tile amax reduce // warp tile amax reduce
const CType max_block = reduce_max<BLOCK_SIZE/THREADS_PER_WARP>(amax, warp_id); const CType max_block = reduce_max<BLOCK_SIZE / THREADS_PER_WARP>(amax, warp_id);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
if (param.amax != nullptr) { if (param.amax != nullptr) {
atomicMaxFloat(param.amax, max_block); atomicMaxFloat(param.amax, max_block);
}
} }
}
} }
...@@ -19,13 +19,10 @@ constexpr size_t block_size = __BLOCK_SIZE__; ...@@ -19,13 +19,10 @@ constexpr size_t block_size = __BLOCK_SIZE__;
} // namespace } // namespace
__global__ void __global__ void __launch_bounds__(block_size)
__launch_bounds__(block_size) transpose_optimized_kernel(const Type* __restrict__ const input, const float* const noop,
transpose_optimized_kernel(const Type * __restrict__ const input, Type* __restrict__ const output, const size_t row_length,
const float * const noop, const size_t num_rows) {
Type * __restrict__ const output,
const size_t row_length,
const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return; if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes // Vectorized load/store sizes
...@@ -63,17 +60,17 @@ transpose_optimized_kernel(const Type * __restrict__ const input, ...@@ -63,17 +60,17 @@ transpose_optimized_kernel(const Type * __restrict__ const input,
// Note: Each thread loads num_iterations subtiles and transposes in // Note: Each thread loads num_iterations subtiles and transposes in
// registers. // registers.
OVec local_output[nvec_in][num_iterations]; OVec local_output[nvec_in][num_iterations];
#pragma unroll #pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) { for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy; const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx; const size_t j1 = tidx;
#pragma unroll #pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) { for (size_t i2 = 0; i2 < nvec_out; ++i2) {
const size_t row = tile_row + i1 * nvec_out + i2; const size_t row = tile_row + i1 * nvec_out + i2;
const size_t col = tile_col + j1 * nvec_in; const size_t col = tile_col + j1 * nvec_in;
IVec local_input; IVec local_input;
local_input.load_from(&input[row * row_length + col]); local_input.load_from(&input[row * row_length + col]);
#pragma unroll #pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) { for (size_t j2 = 0; j2 < nvec_in; ++j2) {
local_output[j2][iter].data.elt[i2] = local_input.data.elt[j2]; local_output[j2][iter].data.elt[i2] = local_input.data.elt[j2];
} }
...@@ -81,17 +78,17 @@ transpose_optimized_kernel(const Type * __restrict__ const input, ...@@ -81,17 +78,17 @@ transpose_optimized_kernel(const Type * __restrict__ const input,
} }
// Copy from registers to shared memory to global memory // Copy from registers to shared memory to global memory
__shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP+1]; __shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#pragma unroll #pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) { for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll #pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) { for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy; const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx; const size_t j1 = tidx;
shared_output[j1][i1] = local_output[j2][iter]; shared_output[j1][i1] = local_output[j2][iter];
} }
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) { for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidx; const size_t i1 = tidx;
const size_t j1 = tidy + iter * bdimy; const size_t j1 = tidy + iter * bdimy;
......
...@@ -4,13 +4,12 @@ ...@@ -4,13 +4,12 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <cuda_runtime.h>
#include <transformer_engine/cast_transpose_noop.h> #include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/transpose.h> #include <transformer_engine/transpose.h>
#include <algorithm> #include <algorithm>
#include <cuda_runtime.h>
#include "../common.h" #include "../common.h"
#include "../util/rtc.h" #include "../util/rtc.h"
#include "../util/string.h" #include "../util/string.h"
...@@ -46,24 +45,18 @@ struct KernelConfig { ...@@ -46,24 +45,18 @@ struct KernelConfig {
/* Elements per L1 cache store */ /* Elements per L1 cache store */
size_t elements_per_store = 0; size_t elements_per_store = 0;
KernelConfig(size_t row_length, KernelConfig(size_t row_length, size_t num_rows, size_t type_size, size_t load_size_,
size_t num_rows,
size_t type_size,
size_t load_size_,
size_t store_size_) size_t store_size_)
: load_size{load_size_} : load_size{load_size_}, store_size{store_size_} {
, store_size{store_size_} {
// Check that tiles are correctly aligned // Check that tiles are correctly aligned
constexpr size_t cache_line_size = 128; constexpr size_t cache_line_size = 128;
if (load_size % type_size != 0 if (load_size % type_size != 0 || store_size % type_size != 0 ||
|| store_size % type_size != 0 cache_line_size % type_size != 0) {
|| cache_line_size % type_size != 0) {
return; return;
} }
const size_t row_tile_elements = load_size * THREADS_PER_WARP / type_size; const size_t row_tile_elements = load_size * THREADS_PER_WARP / type_size;
const size_t col_tile_elements = store_size * THREADS_PER_WARP / type_size; const size_t col_tile_elements = store_size * THREADS_PER_WARP / type_size;
valid = (row_length % row_tile_elements == 0 valid = (row_length % row_tile_elements == 0 && num_rows % col_tile_elements == 0);
&& num_rows % col_tile_elements == 0);
if (!valid) { if (!valid) {
return; return;
} }
...@@ -75,10 +68,8 @@ struct KernelConfig { ...@@ -75,10 +68,8 @@ struct KernelConfig {
constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs
active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm), active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm),
static_cast<size_t>(cuda::sm_count())); static_cast<size_t>(cuda::sm_count()));
elements_per_load = (std::min(cache_line_size, row_tile_elements * type_size) elements_per_load = (std::min(cache_line_size, row_tile_elements * type_size) / type_size);
/ type_size); elements_per_store = (std::min(cache_line_size, col_tile_elements * type_size) / type_size);
elements_per_store = (std::min(cache_line_size, col_tile_elements * type_size)
/ type_size);
} }
/* Compare by estimated cost */ /* Compare by estimated cost */
...@@ -93,8 +84,8 @@ struct KernelConfig { ...@@ -93,8 +84,8 @@ struct KernelConfig {
const auto &s2 = other.elements_per_store; const auto &s2 = other.elements_per_store;
const auto &p2 = other.active_sm_count; const auto &p2 = other.active_sm_count;
const auto scale = l1 * s1 * p1 * l2 * s2 * p2; const auto scale = l1 * s1 * p1 * l2 * s2 * p2;
const auto cost1 = (scale/l1 + scale/s1) / p1; const auto cost1 = (scale / l1 + scale / s1) / p1;
const auto cost2 = (scale/l2 + scale/s2) / p2; const auto cost2 = (scale / l2 + scale / s2) / p2;
return cost1 < cost2; return cost1 < cost2;
} else { } else {
return this->valid && !other.valid; return this->valid && !other.valid;
...@@ -103,13 +94,10 @@ struct KernelConfig { ...@@ -103,13 +94,10 @@ struct KernelConfig {
}; };
template <size_t load_size, size_t store_size, typename Type> template <size_t load_size, size_t store_size, typename Type>
__global__ void __global__ void __launch_bounds__(block_size)
__launch_bounds__(block_size) transpose_general_kernel(const Type *__restrict__ const input, const fp32 *const noop,
transpose_general_kernel(const Type * __restrict__ const input, Type *__restrict__ const output, const size_t row_length,
const fp32 * const noop, const size_t num_rows) {
Type * __restrict__ const output,
const size_t row_length,
const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return; if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes // Vectorized load/store sizes
...@@ -147,25 +135,25 @@ transpose_general_kernel(const Type * __restrict__ const input, ...@@ -147,25 +135,25 @@ transpose_general_kernel(const Type * __restrict__ const input,
// Note: Each thread loads num_iterations subtiles and transposes in // Note: Each thread loads num_iterations subtiles and transposes in
// registers. // registers.
OVec local_output[nvec_in][num_iterations]; OVec local_output[nvec_in][num_iterations];
#pragma unroll #pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) { for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy; const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx; const size_t j1 = tidx;
#pragma unroll #pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) { for (size_t i2 = 0; i2 < nvec_out; ++i2) {
const size_t row = tile_row + i1 * nvec_out + i2; const size_t row = tile_row + i1 * nvec_out + i2;
const size_t col = tile_col + j1 * nvec_in; const size_t col = tile_col + j1 * nvec_in;
IVec local_input; IVec local_input;
local_input.clear(); local_input.clear();
if (row < num_rows) { if (row < num_rows) {
#pragma unroll #pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) { for (size_t j2 = 0; j2 < nvec_in; ++j2) {
if (col + j2 < row_length) { if (col + j2 < row_length) {
local_input.data.elt[j2] = input[row * row_length + col + j2]; local_input.data.elt[j2] = input[row * row_length + col + j2];
} }
} }
} }
#pragma unroll #pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) { for (size_t j2 = 0; j2 < nvec_in; ++j2) {
local_output[j2][iter].data.elt[i2] = local_input.data.elt[j2]; local_output[j2][iter].data.elt[i2] = local_input.data.elt[j2];
} }
...@@ -173,24 +161,24 @@ transpose_general_kernel(const Type * __restrict__ const input, ...@@ -173,24 +161,24 @@ transpose_general_kernel(const Type * __restrict__ const input,
} }
// Copy transposed output from registers to global memory // Copy transposed output from registers to global memory
__shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP+1]; __shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#pragma unroll #pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) { for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll #pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) { for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy; const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx; const size_t j1 = tidx;
shared_output[j1][i1] = local_output[j2][iter]; shared_output[j1][i1] = local_output[j2][iter];
} }
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) { for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidx; const size_t i1 = tidx;
const size_t j1 = tidy + iter * bdimy; const size_t j1 = tidy + iter * bdimy;
const size_t row = tile_row + i1 * nvec_out; const size_t row = tile_row + i1 * nvec_out;
const size_t col = tile_col + j1 * nvec_in + j2; const size_t col = tile_col + j1 * nvec_in + j2;
if (col < row_length) { if (col < row_length) {
#pragma unroll #pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) { for (size_t i2 = 0; i2 < nvec_out; ++i2) {
if (row + i2 < num_rows) { if (row + i2 < num_rows) {
output[col * num_rows + row + i2] = shared_output[j1][i1].data.elt[i2]; output[col * num_rows + row + i2] = shared_output[j1][i1].data.elt[i2];
...@@ -204,10 +192,7 @@ transpose_general_kernel(const Type * __restrict__ const input, ...@@ -204,10 +192,7 @@ transpose_general_kernel(const Type * __restrict__ const input,
} // namespace } // namespace
void transpose(const Tensor &input, void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream) {
const Tensor &noop,
Tensor *output_,
cudaStream_t stream) {
Tensor &output = *output_; Tensor &output = *output_;
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output.data.shape.size() == 2, "Output must have 2 dimensions."); NVTE_CHECK(output.data.shape.size() == 2, "Output must have 2 dimensions.");
...@@ -219,121 +204,106 @@ void transpose(const Tensor &input, ...@@ -219,121 +204,106 @@ void transpose(const Tensor &input,
NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated."); NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated."); NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated.");
NVTE_CHECK(input.data.dtype == output.data.dtype, NVTE_CHECK(input.data.dtype == output.data.dtype, "Input and output type must match.");
"Input and output type must match.");
// Number of elements in tensor // Number of elements in tensor
auto numel = [] (const Tensor &tensor) -> size_t { auto numel = [](const Tensor &tensor) -> size_t {
size_t acc = 1; size_t acc = 1;
for (const auto& dim : tensor.data.shape) { for (const auto &dim : tensor.data.shape) {
acc *= dim; acc *= dim;
} }
return acc; return acc;
}; };
if (noop.data.dptr != nullptr) { if (noop.data.dptr != nullptr) {
NVTE_CHECK(numel(noop) == 1, NVTE_CHECK(numel(noop) == 1, "Expected 1 element, ", "but found ", numel(noop), ".");
"Expected 1 element, ",
"but found ", numel(noop), ".");
NVTE_CHECK(noop.data.dtype == DType::kFloat32); NVTE_CHECK(noop.data.dtype == DType::kFloat32);
NVTE_CHECK(noop.data.dptr != nullptr); NVTE_CHECK(noop.data.dptr != nullptr);
} }
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(input.data.dtype, Type, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
constexpr const char *type_name = TypeInfo<Type>::name; input.data.dtype, Type, constexpr const char *type_name = TypeInfo<Type>::name;
constexpr size_t type_size = sizeof(Type); constexpr size_t type_size = sizeof(Type);
// Choose between runtime-compiled or statically-compiled kernel // Choose between runtime-compiled or statically-compiled kernel
const bool aligned = (row_length % THREADS_PER_WARP == 0 const bool aligned = (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0);
&& num_rows % THREADS_PER_WARP == 0); if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel
if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel // Pick kernel config
// Pick kernel config std::vector<KernelConfig> kernel_configs;
std::vector<KernelConfig> kernel_configs; kernel_configs.reserve(16);
kernel_configs.reserve(16); auto add_config = [&](size_t load_size, size_t store_size) {
auto add_config = [&](size_t load_size, size_t store_size) { kernel_configs.emplace_back(row_length, num_rows, type_size, load_size, store_size);
kernel_configs.emplace_back(row_length, num_rows, type_size, };
load_size, store_size); add_config(8, 8);
}; add_config(4, 8);
add_config(8, 8); add_config(8, 4);
add_config(4, 8); add_config(8, 4); add_config(4, 4);
add_config(4, 4); add_config(2, 8);
add_config(2, 8); add_config(8, 2); add_config(8, 2);
add_config(2, 4); add_config(4, 2); add_config(2, 4);
add_config(2, 2); add_config(4, 2);
add_config(1, 8); add_config(8, 1); add_config(2, 2);
add_config(1, 4); add_config(4, 1); add_config(1, 8);
add_config(1, 2); add_config(2, 1); add_config(8, 1);
add_config(1, 1); add_config(1, 4);
const auto &kernel_config = *std::min_element(kernel_configs.begin(), add_config(4, 1);
kernel_configs.end()); add_config(1, 2);
NVTE_CHECK(kernel_config.valid, "invalid kernel config"); add_config(2, 1);
const size_t load_size = kernel_config.load_size; add_config(1, 1);
const size_t store_size = kernel_config.store_size; const auto &kernel_config = *std::min_element(kernel_configs.begin(), kernel_configs.end());
const size_t num_blocks = kernel_config.num_blocks; NVTE_CHECK(kernel_config.valid, "invalid kernel config");
const size_t load_size = kernel_config.load_size;
// Compile NVRTC kernel if needed and launch const size_t store_size = kernel_config.store_size;
auto& rtc_manager = rtc::KernelManager::instance(); const size_t num_blocks = kernel_config.num_blocks;
const std::string kernel_label = concat_strings("transpose"
",type=", type_name, // Compile NVRTC kernel if needed and launch
",load_size=", load_size, auto &rtc_manager = rtc::KernelManager::instance();
",store_size=", store_size); const std::string kernel_label = concat_strings(
if (!rtc_manager.is_compiled(kernel_label)) { "transpose"
std::string code = string_code_transpose_rtc_transpose_cu; ",type=",
code = regex_replace(code, "__TYPE__", type_name); type_name, ",load_size=", load_size, ",store_size=", store_size);
code = regex_replace(code, "__LOAD_SIZE__", load_size); if (!rtc_manager.is_compiled(kernel_label)) {
code = regex_replace(code, "__STORE_SIZE__", store_size); std::string code = string_code_transpose_rtc_transpose_cu;
code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile); code = regex_replace(code, "__TYPE__", type_name);
code = regex_replace(code, "__BLOCK_SIZE__", block_size); code = regex_replace(code, "__LOAD_SIZE__", load_size);
rtc_manager.compile(kernel_label, code = regex_replace(code, "__STORE_SIZE__", store_size);
"transpose_optimized_kernel", code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile);
code, code = regex_replace(code, "__BLOCK_SIZE__", block_size);
"transformer_engine/common/transpose/rtc/transpose.cu"); rtc_manager.compile(kernel_label, "transpose_optimized_kernel", code,
} "transformer_engine/common/transpose/rtc/transpose.cu");
rtc_manager.launch(kernel_label, }
num_blocks, block_size, 0, stream, rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream,
static_cast<const Type *>(input.data.dptr), static_cast<const Type *>(input.data.dptr),
static_cast<const fp32 *>(noop.data.dptr), static_cast<const fp32 *>(noop.data.dptr),
static_cast<Type*>(output.data.dptr), static_cast<Type *>(output.data.dptr), row_length, num_rows);
row_length, num_rows); } else { // Statically-compiled general kernel
} else { // Statically-compiled general kernel constexpr size_t load_size = 4;
constexpr size_t load_size = 4; constexpr size_t store_size = 4;
constexpr size_t store_size = 4; constexpr size_t row_tile_size = load_size / type_size * THREADS_PER_WARP;
constexpr size_t row_tile_size = load_size / type_size * THREADS_PER_WARP; constexpr size_t col_tile_size = store_size / type_size * THREADS_PER_WARP;
constexpr size_t col_tile_size = store_size / type_size * THREADS_PER_WARP; const int num_blocks = (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size));
const int num_blocks = (DIVUP(row_length, row_tile_size) transpose_general_kernel<load_size, store_size, Type>
* DIVUP(num_rows, col_tile_size)); <<<num_blocks, block_size, 0, stream>>>(static_cast<const Type *>(input.data.dptr),
transpose_general_kernel<load_size, store_size, Type><<<num_blocks, block_size, 0, stream>>>( static_cast<const fp32 *>(noop.data.dptr),
static_cast<const Type *>(input.data.dptr), static_cast<Type *>(output.data.dptr),
static_cast<const fp32 *>(noop.data.dptr), row_length, num_rows);
static_cast<Type *>(output.data.dptr), }); // NOLINT(*)
row_length, num_rows);
}
); // NOLINT(*)
} }
} // namespace transformer_engine } // namespace transformer_engine
void nvte_transpose(const NVTETensor input, void nvte_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose); NVTE_API_CALL(nvte_transpose);
using namespace transformer_engine; using namespace transformer_engine;
auto noop = Tensor(); auto noop = Tensor();
transpose(*reinterpret_cast<const Tensor*>(input), transpose(*reinterpret_cast<const Tensor *>(input), noop, reinterpret_cast<Tensor *>(output),
noop,
reinterpret_cast<Tensor*>(output),
stream); stream);
} }
void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
void nvte_transpose_with_noop(const NVTETensor input,
const NVTETensor noop,
NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose_with_noop); NVTE_API_CALL(nvte_transpose_with_noop);
using namespace transformer_engine; using namespace transformer_engine;
transpose(*reinterpret_cast<const Tensor*>(input), transpose(*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(noop),
*reinterpret_cast<const Tensor*>(noop), reinterpret_cast<Tensor *>(output), stream);
reinterpret_cast<Tensor*>(output),
stream);
} }
...@@ -4,18 +4,19 @@ ...@@ -4,18 +4,19 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <transformer_engine/transpose.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <transformer_engine/transpose.h>
#include <cfloat> #include <cfloat>
#include <iostream> #include <iostream>
#include <type_traits> #include <type_traits>
#include "../utils.cuh"
#include "../common.h" #include "../common.h"
#include "../utils.cuh"
namespace transformer_engine { namespace transformer_engine {
template <int nvec_in, int nvec_out, template <int nvec_in, int nvec_out, typename IVec, typename OVec, typename CVec, typename CType>
typename IVec, typename OVec, typename CVec, typename CType>
inline __device__ void transpose_regs_partial_dbias(const IVec (&in)[nvec_out], inline __device__ void transpose_regs_partial_dbias(const IVec (&in)[nvec_out],
OVec (&out_trans)[nvec_in], OVec (&out_trans)[nvec_in],
CVec &out_dbias, // NOLINT(*) CVec &out_dbias, // NOLINT(*)
...@@ -24,7 +25,8 @@ inline __device__ void transpose_regs_partial_dbias(const IVec (&in)[nvec_out], ...@@ -24,7 +25,8 @@ inline __device__ void transpose_regs_partial_dbias(const IVec (&in)[nvec_out],
using T = typename OVec::type; using T = typename OVec::type;
using OVecC = Vec<T, nvec_in>; using OVecC = Vec<T, nvec_in>;
CVec step_dbias; step_dbias.clear(); CVec step_dbias;
step_dbias.clear();
#pragma unroll #pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) { for (unsigned int i = 0; i < nvec_out; ++i) {
...@@ -61,24 +63,21 @@ namespace { ...@@ -61,24 +63,21 @@ namespace {
template <typename IType, typename OType, typename CType> template <typename IType, typename OType, typename CType>
struct TDBiasParam { struct TDBiasParam {
using InputType = IType; using InputType = IType;
using OutputType = OType; using OutputType = OType;
using ComputeType = CType; using ComputeType = CType;
const IType *input; const IType *input;
OType *output_t; OType *output_t;
const CType *scale_inv; const CType *scale_inv;
CType *workspace; CType *workspace;
}; };
} // namespace } // namespace
template <int nvec_in, int nvec_out, typename Param> template <int nvec_in, int nvec_out, typename Param>
__global__ void __global__ void __launch_bounds__(cast_transpose_num_threads)
__launch_bounds__(cast_transpose_num_threads) transpose_dbias_kernel(const Param param, const size_t row_length, const size_t num_rows,
transpose_dbias_kernel(const Param param, const size_t num_tiles) {
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
using IType = typename Param::InputType; using IType = typename Param::InputType;
using OType = typename Param::OutputType; using OType = typename Param::OutputType;
using CType = typename Param::ComputeType; using CType = typename Param::ComputeType;
...@@ -92,27 +91,24 @@ transpose_dbias_kernel(const Param param, ...@@ -92,27 +91,24 @@ transpose_dbias_kernel(const Param param,
const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP); const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
// const size_t num_tiles_y = num_rows / (nvec * THREADS_PER_WARP); // const size_t num_tiles_y = num_rows / (nvec * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + const size_t tile_id =
warp_id / n_warps_per_tile; blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return; if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x; const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x; const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = param.input + (tile_id_x * nvec_in + const IType *const my_input_tile =
tile_id_y * row_length * nvec_out) * param.input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP;
THREADS_PER_WARP; OType *const my_output_t_tile =
OType * const my_output_t_tile = param.output_t + (tile_id_y * nvec_out + param.output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP;
tile_id_x * num_rows * nvec_in) * CType *const my_partial_dbias_tile =
THREADS_PER_WARP; param.workspace + (tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length);
CType * const my_partial_dbias_tile = param.workspace +
(tile_id_x * (nvec_in * THREADS_PER_WARP) +
tile_id_y * row_length);
OVec * const my_scratch = reinterpret_cast<OVec *>(scratch) + OVec *const my_scratch =
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * reinterpret_cast<OVec *>(scratch) +
(THREADS_PER_WARP + 1); (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch); CVec *const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][nvec_out]; IVec in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
...@@ -123,9 +119,8 @@ transpose_dbias_kernel(const Param param, ...@@ -123,9 +119,8 @@ transpose_dbias_kernel(const Param param,
const size_t stride = row_length / nvec_in; const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out; const size_t output_stride = num_rows / nvec_out;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - unsigned int my_place =
warp_id_in_tile * n_iterations) % (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
THREADS_PER_WARP;
const CType scale_inv = param.scale_inv != nullptr ? *param.scale_inv : 1; const CType scale_inv = param.scale_inv != nullptr ? *param.scale_inv : 1;
partial_dbias.clear(); partial_dbias.clear();
...@@ -147,11 +142,8 @@ transpose_dbias_kernel(const Param param, ...@@ -147,11 +142,8 @@ transpose_dbias_kernel(const Param param,
} }
OVec out_trans[nvec_in]; // NOLINT(*) OVec out_trans[nvec_in]; // NOLINT(*)
transpose_regs_partial_dbias( transpose_regs_partial_dbias(
in[current_in ^ 1], in[current_in ^ 1], out_trans, partial_dbias, scale_inv,
out_trans, (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP);
partial_dbias,
scale_inv,
(my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP);
#pragma unroll #pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) { for (unsigned int j = 0; j < nvec_in; ++j) {
...@@ -164,14 +156,13 @@ transpose_dbias_kernel(const Param param, ...@@ -164,14 +156,13 @@ transpose_dbias_kernel(const Param param,
for (unsigned int i = 0; i < nvec_in; ++i) { for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll #pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) { for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP - my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i]; THREADS_PER_WARP] = out_space[j][i];
} }
__syncthreads(); __syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % my_place =
THREADS_PER_WARP; (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
current_stride = i * output_stride + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; j < n_iterations; ++j) { for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place); current_stride + my_place);
...@@ -199,12 +190,9 @@ transpose_dbias_kernel(const Param param, ...@@ -199,12 +190,9 @@ transpose_dbias_kernel(const Param param,
} }
template <int nvec_in, int nvec_out, typename Param> template <int nvec_in, int nvec_out, typename Param>
__global__ void __global__ void __launch_bounds__(cast_transpose_num_threads)
__launch_bounds__(cast_transpose_num_threads) transpose_dbias_kernel_notaligned(const Param param, const size_t row_length,
transpose_dbias_kernel_notaligned(const Param param, const size_t num_rows, const size_t num_tiles) {
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
using IType = typename Param::InputType; using IType = typename Param::InputType;
using OType = typename Param::OutputType; using OType = typename Param::OutputType;
using CType = typename Param::ComputeType; using CType = typename Param::ComputeType;
...@@ -216,38 +204,35 @@ transpose_dbias_kernel_notaligned(const Param param, ...@@ -216,38 +204,35 @@ transpose_dbias_kernel_notaligned(const Param param,
const int warp_id = threadIdx.x / THREADS_PER_WARP; const int warp_id = threadIdx.x / THREADS_PER_WARP;
const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) / const size_t num_tiles_x =
(nvec_in * THREADS_PER_WARP); (row_length + nvec_in * THREADS_PER_WARP - 1) / (nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + const size_t tile_id =
warp_id / n_warps_per_tile; blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return; if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x; const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x; const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = param.input + (tile_id_x * nvec_in + const IType *const my_input_tile =
tile_id_y * row_length * nvec_out) * param.input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP;
THREADS_PER_WARP; OType *const my_output_t_tile =
OType * const my_output_t_tile = param.output_t + (tile_id_y * nvec_out + param.output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP;
tile_id_x * num_rows * nvec_in) * CType *const my_partial_dbias_tile =
THREADS_PER_WARP; param.workspace + (tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length);
CType * const my_partial_dbias_tile = param.workspace +
(tile_id_x * (nvec_in * THREADS_PER_WARP) +
tile_id_y * row_length);
const size_t stride = row_length / nvec_in; const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out; const size_t output_stride = num_rows / nvec_out;
const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP; const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP;
const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP; const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP;
const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP const unsigned int tile_length =
: row_length_rest; row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_length_rest;
const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP const unsigned int tile_height =
: row_height_rest; row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_height_rest;
OVec * const my_scratch = reinterpret_cast<OVec *>(scratch) + OVec *const my_scratch =
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * reinterpret_cast<OVec *>(scratch) +
(THREADS_PER_WARP + 1); (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch); CVec *const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][nvec_out]; IVec in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
...@@ -256,16 +241,14 @@ transpose_dbias_kernel_notaligned(const Param param, ...@@ -256,16 +241,14 @@ transpose_dbias_kernel_notaligned(const Param param,
CVec partial_dbias; CVec partial_dbias;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - unsigned int my_place =
warp_id_in_tile * n_iterations) % (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
THREADS_PER_WARP;
const CType scale_inv = param.scale_inv != nullptr ? *param.scale_inv : 1; const CType scale_inv = param.scale_inv != nullptr ? *param.scale_inv : 1;
partial_dbias.clear(); partial_dbias.clear();
{ {
const bool valid_load = my_place < tile_length && const bool valid_load = my_place < tile_length && warp_id_in_tile * n_iterations < tile_height;
warp_id_in_tile * n_iterations < tile_height;
#pragma unroll #pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) { for (unsigned int i = 0; i < nvec_out; ++i) {
if (valid_load) { if (valid_load) {
...@@ -280,8 +263,8 @@ transpose_dbias_kernel_notaligned(const Param param, ...@@ -280,8 +263,8 @@ transpose_dbias_kernel_notaligned(const Param param,
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2; const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) { if (i < n_iterations - 1) {
const bool valid_load = my_place_in < tile_length && const bool valid_load =
warp_id_in_tile * n_iterations + i + 1 < tile_height; my_place_in < tile_length && warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll #pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) { for (unsigned int j = 0; j < nvec_out; ++j) {
if (valid_load) { if (valid_load) {
...@@ -294,11 +277,8 @@ transpose_dbias_kernel_notaligned(const Param param, ...@@ -294,11 +277,8 @@ transpose_dbias_kernel_notaligned(const Param param,
} }
OVec out_trans[nvec_in]; // NOLINT(*) OVec out_trans[nvec_in]; // NOLINT(*)
transpose_regs_partial_dbias( transpose_regs_partial_dbias(
in[current_in ^ 1], in[current_in ^ 1], out_trans, partial_dbias, scale_inv,
out_trans, (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP);
partial_dbias,
scale_inv,
(my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP);
#pragma unroll #pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) { for (unsigned int j = 0; j < nvec_in; ++j) {
...@@ -311,14 +291,13 @@ transpose_dbias_kernel_notaligned(const Param param, ...@@ -311,14 +291,13 @@ transpose_dbias_kernel_notaligned(const Param param,
for (unsigned int i = 0; i < nvec_in; ++i) { for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll #pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) { for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP - my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i]; THREADS_PER_WARP] = out_space[j][i];
} }
__syncthreads(); __syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % my_place =
THREADS_PER_WARP; (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
current_stride = i * output_stride + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
const bool valid_store = my_place < tile_height; const bool valid_store = my_place < tile_height;
if (valid_store) { if (valid_store) {
...@@ -352,27 +331,25 @@ transpose_dbias_kernel_notaligned(const Param param, ...@@ -352,27 +331,25 @@ transpose_dbias_kernel_notaligned(const Param param,
constexpr size_t reduce_dbias_num_threads = 256; constexpr size_t reduce_dbias_num_threads = 256;
template<int nvec, typename ComputeType, typename OutputType> template <int nvec, typename ComputeType, typename OutputType>
__global__ void __global__ void __launch_bounds__(reduce_dbias_num_threads)
__launch_bounds__(reduce_dbias_num_threads) reduce_dbias_kernel(OutputType *const dbias_output, const ComputeType *const dbias_partial,
reduce_dbias_kernel(OutputType* const dbias_output, const int row_length, const int num_rows) {
const ComputeType* const dbias_partial,
const int row_length,
const int num_rows) {
using ComputeVec = Vec<ComputeType, nvec>; using ComputeVec = Vec<ComputeType, nvec>;
using OutputVec = Vec<OutputType, nvec>; using OutputVec = Vec<OutputType, nvec>;
const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_id * nvec >= row_length) return; if (thread_id * nvec >= row_length) return;
const ComputeType* const thread_in_base = dbias_partial + thread_id * nvec; const ComputeType *const thread_in_base = dbias_partial + thread_id * nvec;
OutputType* const thread_out_base = dbias_output + thread_id * nvec; OutputType *const thread_out_base = dbias_output + thread_id * nvec;
const int stride_in_vec = row_length / nvec; const int stride_in_vec = row_length / nvec;
ComputeVec ldg_vec; ComputeVec ldg_vec;
ComputeVec acc_vec; acc_vec.clear(); ComputeVec acc_vec;
acc_vec.clear();
for (int i = 0; i < num_rows; ++i) { for (int i = 0; i < num_rows; ++i) {
ldg_vec.load_from(thread_in_base, i * stride_in_vec); ldg_vec.load_from(thread_in_base, i * stride_in_vec);
#pragma unroll #pragma unroll
...@@ -381,7 +358,7 @@ reduce_dbias_kernel(OutputType* const dbias_output, ...@@ -381,7 +358,7 @@ reduce_dbias_kernel(OutputType* const dbias_output,
} }
} }
OutputVec stg_vec; OutputVec stg_vec;
#pragma unroll #pragma unroll
for (int e = 0; e < nvec; ++e) { for (int e = 0; e < nvec; ++e) {
stg_vec.data.elt[e] = OutputType(acc_vec.data.elt[e]); stg_vec.data.elt[e] = OutputType(acc_vec.data.elt[e]);
...@@ -390,10 +367,9 @@ reduce_dbias_kernel(OutputType* const dbias_output, ...@@ -390,10 +367,9 @@ reduce_dbias_kernel(OutputType* const dbias_output,
} }
void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/
Tensor* workspace, Tensor *workspace, const int nvec_out) {
const int nvec_out) {
const size_t row_length = input.data.shape[1]; const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0]; const size_t num_rows = input.data.shape[0];
const size_t tile_size_y = (nvec_out * THREADS_PER_WARP); const size_t tile_size_y = (nvec_out * THREADS_PER_WARP);
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
...@@ -405,37 +381,28 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/ ...@@ -405,37 +381,28 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/
} }
template <typename BiasType> template <typename BiasType>
void reduce_dbias(const Tensor &workspace, Tensor *dbias, void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_length,
const size_t row_length, const size_t num_rows, const int nvec_out, const size_t num_rows, const int nvec_out, cudaStream_t stream) {
cudaStream_t stream) { constexpr int reduce_dbias_store_bytes = 8; // stg.64
constexpr int reduce_dbias_store_bytes = 8; // stg.64 constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(BiasType);
constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(BiasType);
NVTE_CHECK(row_length % reduce_dbias_nvec == 0, "Unsupported shape."); NVTE_CHECK(row_length % reduce_dbias_nvec == 0, "Unsupported shape.");
const size_t reduce_dbias_row_length = row_length; const size_t reduce_dbias_row_length = row_length;
const size_t reduce_dbias_num_rows = DIVUP(num_rows, const size_t reduce_dbias_num_rows =
static_cast<size_t>(nvec_out * DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
THREADS_PER_WARP)); const size_t reduce_dbias_num_blocks =
const size_t reduce_dbias_num_blocks = DIVUP(row_length, DIVUP(row_length, reduce_dbias_num_threads * reduce_dbias_nvec);
reduce_dbias_num_threads * reduce_dbias_nvec);
reduce_dbias_kernel<reduce_dbias_nvec, fp32, BiasType> reduce_dbias_kernel<reduce_dbias_nvec, fp32, BiasType>
<<<reduce_dbias_num_blocks, <<<reduce_dbias_num_blocks, reduce_dbias_num_threads, 0, stream>>>(
reduce_dbias_num_threads, reinterpret_cast<BiasType *>(dbias->data.dptr),
0, reinterpret_cast<const fp32 *>(workspace.data.dptr), reduce_dbias_row_length,
stream>>>( reduce_dbias_num_rows);
reinterpret_cast<BiasType *>(dbias->data.dptr),
reinterpret_cast<const fp32 *>(workspace.data.dptr),
reduce_dbias_row_length,
reduce_dbias_num_rows);
} }
void fp8_transpose_dbias(const Tensor &input, void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor *dbias,
Tensor *transposed_output, Tensor *workspace, cudaStream_t stream) {
Tensor *dbias,
Tensor *workspace,
cudaStream_t stream) {
CheckInputTensor(input, "fp8_transpose_dbias_input"); CheckInputTensor(input, "fp8_transpose_dbias_input");
CheckOutputTensor(*transposed_output, "transposed_output"); CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias"); CheckOutputTensor(*dbias, "dbias");
...@@ -449,82 +416,71 @@ void fp8_transpose_dbias(const Tensor &input, ...@@ -449,82 +416,71 @@ void fp8_transpose_dbias(const Tensor &input,
NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output."); NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->data.dtype == input.data.dtype, NVTE_CHECK(transposed_output->data.dtype == input.data.dtype,
"T output must have the same type as input."); "T output must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias."); NVTE_CHECK(dbias->data.shape == std::vector<size_t>{row_length}, "Wrong shape of DBias.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dbias->data.dtype, BiasType, TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(input.data.dtype, Type, dbias->data.dtype, BiasType,
constexpr int type_size = sizeof(Type); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
constexpr int nvec_in = desired_load_size / type_size; input.data.dtype, Type, constexpr int type_size = sizeof(Type);
constexpr int nvec_out = desired_store_size / type_size; constexpr int nvec_in = desired_load_size / type_size;
constexpr int nvec_out = desired_store_size / type_size;
if (workspace->data.dptr == nullptr) {
populate_transpose_dbias_workspace_config(input, workspace, nvec_out); if (workspace->data.dptr == nullptr) {
return; populate_transpose_dbias_workspace_config(input, workspace, nvec_out);
} return;
}
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) * NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP)); const size_t n_tiles =
const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP; DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block); DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 && const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
num_rows % (nvec_out * THREADS_PER_WARP) == 0;
const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
using ComputeType = fp32; num_rows % (nvec_out * THREADS_PER_WARP) == 0;
constexpr size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) * using ComputeType = fp32; constexpr size_t shared_size_transpose =
sizeof(Vec<Type, nvec_out>); cast_transpose_num_threads / n_warps_per_tile *
constexpr size_t shared_size_dbias = cast_transpose_num_threads * (THREADS_PER_WARP + 1) * sizeof(Vec<Type, nvec_out>);
sizeof(Vec<ComputeType, nvec_in>); constexpr size_t shared_size_dbias =
static_assert(shared_size_transpose >= shared_size_dbias); cast_transpose_num_threads * sizeof(Vec<ComputeType, nvec_in>);
using Param = TDBiasParam<Type, Type, ComputeType>; static_assert(shared_size_transpose >= shared_size_dbias);
Param param; using Param = TDBiasParam<Type, Type, ComputeType>; Param param;
param.input = reinterpret_cast<const Type *>(input.data.dptr); param.input = reinterpret_cast<const Type *>(input.data.dptr);
param.output_t = reinterpret_cast<Type *>(transposed_output->data.dptr); param.output_t = reinterpret_cast<Type *>(transposed_output->data.dptr);
param.scale_inv = reinterpret_cast<const ComputeType *>(transposed_output->scale_inv.dptr); param.scale_inv =
param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr); reinterpret_cast<const ComputeType *>(transposed_output->scale_inv.dptr);
param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
if (full_tile) {
cudaFuncSetAttribute(transpose_dbias_kernel<nvec_in, nvec_out, Param>, if (full_tile) {
cudaFuncAttributePreferredSharedMemoryCarveout, cudaFuncSetAttribute(transpose_dbias_kernel<nvec_in, nvec_out, Param>,
100); cudaFuncAttributePreferredSharedMemoryCarveout, 100);
transpose_dbias_kernel<nvec_in, nvec_out, Param> transpose_dbias_kernel<nvec_in, nvec_out, Param>
<<<n_blocks, <<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
cast_transpose_num_threads, param, row_length, num_rows, n_tiles);
shared_size_transpose, } else {
stream>>>(param, row_length, num_rows, n_tiles); cudaFuncSetAttribute(transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>,
} else { cudaFuncAttributePreferredSharedMemoryCarveout, 100);
cudaFuncSetAttribute(transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>, transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>
cudaFuncAttributePreferredSharedMemoryCarveout, <<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
100); param, row_length, num_rows, n_tiles);
transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param> }
<<<n_blocks,
cast_transpose_num_threads, reduce_dbias<BiasType>(*workspace, dbias, row_length, num_rows, nvec_out,
shared_size_transpose, stream);); // NOLINT(*)
stream>>>(param, row_length, num_rows, n_tiles); ); // NOLINT(*)
}
reduce_dbias<BiasType>(*workspace, dbias, row_length, num_rows, nvec_out, stream);
); // NOLINT(*)
); // NOLINT(*)
} }
} // namespace transformer_engine } // namespace transformer_engine
void nvte_fp8_transpose_dbias(const NVTETensor input, void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_output,
NVTETensor transposed_output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_transpose_dbias); NVTE_API_CALL(nvte_fp8_transpose_dbias);
using namespace transformer_engine; using namespace transformer_engine;
fp8_transpose_dbias(*reinterpret_cast<const Tensor*>(input), fp8_transpose_dbias(
reinterpret_cast<Tensor*>(transposed_output), *reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor*>(dbias), reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
reinterpret_cast<Tensor*>(workspace),
stream);
} }
...@@ -5,9 +5,10 @@ ...@@ -5,9 +5,10 @@
************************************************************************/ ************************************************************************/
#include <transformer_engine/cast.h> #include <transformer_engine/cast.h>
#include "../common.h" #include "../common.h"
#include "../utils.cuh"
#include "../util/vectorized_pointwise.h" #include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
namespace transformer_engine { namespace transformer_engine {
...@@ -15,9 +16,7 @@ namespace detail { ...@@ -15,9 +16,7 @@ namespace detail {
struct Empty {}; struct Empty {};
__device__ inline fp32 identity(fp32 value, const Empty&) { __device__ inline fp32 identity(fp32 value, const Empty &) { return value; }
return value;
}
struct DequantizeParam { struct DequantizeParam {
const fp32 *scale_inv; const fp32 *scale_inv;
...@@ -29,83 +28,63 @@ __device__ inline fp32 dequantize_func(fp32 value, const DequantizeParam &param) ...@@ -29,83 +28,63 @@ __device__ inline fp32 dequantize_func(fp32 value, const DequantizeParam &param)
} // namespace detail } // namespace detail
void fp8_quantize(const Tensor &input, void fp8_quantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "cast_input"); CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output"); CheckOutputTensor(*output, "cast_output");
NVTE_CHECK(!is_fp8_dtype(input.data.dtype), NVTE_CHECK(!is_fp8_dtype(input.data.dtype), "Input must be in higher precision.");
"Input must be in higher precision.");
NVTE_CHECK(is_fp8_dtype(output->data.dtype), NVTE_CHECK(is_fp8_dtype(output->data.dtype), "Output must have FP8 type.");
"Output must have FP8 type.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
const size_t N = product(input.data.shape); const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output->data.dtype, OType, input.data.dtype, IType,
constexpr int nvec = 32 / sizeof(IType); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
VectorizedUnaryKernelLauncher<nvec, detail::Empty, detail::identity>( output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
reinterpret_cast<const IType*>(input.data.dptr), VectorizedUnaryKernelLauncher<nvec, detail::Empty, detail::identity>(
reinterpret_cast<OType*>(output->data.dptr), reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr), reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<fp32*>(output->amax.dptr), reinterpret_cast<const fp32 *>(output->scale.dptr),
N, reinterpret_cast<fp32 *>(output->amax.dptr), N, {},
{}, stream);); // NOLINT(*)
stream); ); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
} }
void fp8_dequantize(const Tensor &input, void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "cast_input"); CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output"); CheckOutputTensor(*output, "cast_output");
NVTE_CHECK(is_fp8_dtype(input.data.dtype), NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
"Input must have FP8 type.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
"Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
const size_t N = product(input.data.shape); const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(input.data.dtype, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(output->data.dtype, OType, input.data.dtype, IType,
constexpr int nvec = 32 / sizeof(OType); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
detail::DequantizeParam p; output->data.dtype, OType, constexpr int nvec = 32 / sizeof(OType);
p.scale_inv = reinterpret_cast<const fp32*>(input.scale_inv.dptr); detail::DequantizeParam p;
VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>( p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
reinterpret_cast<const IType*>(input.data.dptr), VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
reinterpret_cast<OType*>(output->data.dptr), reinterpret_cast<const IType *>(input.data.dptr),
nullptr, reinterpret_cast<OType *>(output->data.dptr), nullptr, nullptr, N, p,
nullptr, stream);); // NOLINT(*)
N, ); // NOLINT(*)
p,
stream);
); // NOLINT(*)
); // NOLINT(*)
} }
} // namespace transformer_engine } // namespace transformer_engine
void nvte_fp8_quantize(const NVTETensor input, void nvte_fp8_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_quantize); NVTE_API_CALL(nvte_fp8_quantize);
using namespace transformer_engine; using namespace transformer_engine;
fp8_quantize(*reinterpret_cast<const Tensor*>(input), fp8_quantize(*reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(output),
reinterpret_cast<Tensor*>(output),
stream); stream);
} }
void nvte_fp8_dequantize(const NVTETensor input, void nvte_fp8_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_dequantize); NVTE_API_CALL(nvte_fp8_dequantize);
using namespace transformer_engine; using namespace transformer_engine;
fp8_dequantize(*reinterpret_cast<const Tensor*>(input), fp8_dequantize(*reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(output),
reinterpret_cast<Tensor*>(output),
stream); stream);
} }
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
************************************************************************/ ************************************************************************/
#include <dlfcn.h> #include <dlfcn.h>
#include <filesystem> #include <filesystem>
#include "../common.h" #include "../common.h"
...@@ -40,27 +41,21 @@ class Library { ...@@ -40,27 +41,21 @@ class Library {
#endif // _WIN32 or _WIN64 or __WINDOW__ #endif // _WIN32 or _WIN64 or __WINDOW__
} }
Library(const Library&) = delete; // move-only Library(const Library &) = delete; // move-only
Library(Library&& other) noexcept { Library(Library &&other) noexcept { swap(*this, other); }
swap(*this, other);
}
Library& operator=(Library other) noexcept { Library &operator=(Library other) noexcept {
// Copy-and-swap idiom // Copy-and-swap idiom
swap(*this, other); swap(*this, other);
return *this; return *this;
} }
friend void swap(Library& first, Library& second) noexcept; friend void swap(Library &first, Library &second) noexcept;
void *get() noexcept { void *get() noexcept { return handle_; }
return handle_;
}
const void *get() const noexcept { const void *get() const noexcept { return handle_; }
return handle_;
}
/*! \brief Get pointer corresponding to symbol in shared library */ /*! \brief Get pointer corresponding to symbol in shared library */
void *get_symbol(const char *symbol) { void *get_symbol(const char *symbol) {
...@@ -78,13 +73,13 @@ class Library { ...@@ -78,13 +73,13 @@ class Library {
void *handle_ = nullptr; void *handle_ = nullptr;
}; };
void swap(Library& first, Library& second) noexcept { void swap(Library &first, Library &second) noexcept {
using std::swap; using std::swap;
swap(first.handle_, second.handle_); swap(first.handle_, second.handle_);
} }
/*! \brief Lazily-initialized shared library for CUDA driver */ /*! \brief Lazily-initialized shared library for CUDA driver */
Library& cuda_driver_lib() { Library &cuda_driver_lib() {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
constexpr char lib_name[] = "nvcuda.dll"; constexpr char lib_name[] = "nvcuda.dll";
#else #else
...@@ -98,9 +93,7 @@ Library& cuda_driver_lib() { ...@@ -98,9 +93,7 @@ Library& cuda_driver_lib() {
namespace cuda_driver { namespace cuda_driver {
void *get_symbol(const char *symbol) { void *get_symbol(const char *symbol) { return cuda_driver_lib().get_symbol(symbol); }
return cuda_driver_lib().get_symbol(symbol);
}
} // namespace cuda_driver } // namespace cuda_driver
......
...@@ -7,10 +7,10 @@ ...@@ -7,10 +7,10 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_ #ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_ #define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_
#include <string>
#include <cuda.h> #include <cuda.h>
#include <string>
#include "../common.h" #include "../common.h"
#include "../util/string.h" #include "../util/string.h"
...@@ -35,7 +35,7 @@ void *get_symbol(const char *symbol); ...@@ -35,7 +35,7 @@ void *get_symbol(const char *symbol);
template <typename... ArgTs> template <typename... ArgTs>
inline CUresult call(const char *symbol, ArgTs... args) { inline CUresult call(const char *symbol, ArgTs... args) {
using FuncT = CUresult(ArgTs...); using FuncT = CUresult(ArgTs...);
FuncT *func = reinterpret_cast<FuncT*>(get_symbol(symbol)); FuncT *func = reinterpret_cast<FuncT *>(get_symbol(symbol));
return (*func)(args...); return (*func)(args...);
} }
...@@ -43,23 +43,20 @@ inline CUresult call(const char *symbol, ArgTs... args) { ...@@ -43,23 +43,20 @@ inline CUresult call(const char *symbol, ArgTs... args) {
} // namespace transformer_engine } // namespace transformer_engine
#define NVTE_CHECK_CUDA_DRIVER(expr) \ #define NVTE_CHECK_CUDA_DRIVER(expr) \
do { \ do { \
const CUresult status_NVTE_CHECK_CUDA_DRIVER = (expr); \ const CUresult status_NVTE_CHECK_CUDA_DRIVER = (expr); \
if (status_NVTE_CHECK_CUDA_DRIVER != CUDA_SUCCESS) { \ if (status_NVTE_CHECK_CUDA_DRIVER != CUDA_SUCCESS) { \
const char *desc_NVTE_CHECK_CUDA_DRIVER; \ const char *desc_NVTE_CHECK_CUDA_DRIVER; \
::transformer_engine::cuda_driver::call( \ ::transformer_engine::cuda_driver::call("cuGetErrorString", status_NVTE_CHECK_CUDA_DRIVER, \
"cuGetErrorString", \ &desc_NVTE_CHECK_CUDA_DRIVER); \
status_NVTE_CHECK_CUDA_DRIVER, \ NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER); \
&desc_NVTE_CHECK_CUDA_DRIVER); \ } \
NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER); \
} \
} while (false) } while (false)
#define NVTE_CALL_CHECK_CUDA_DRIVER(symbol, ...) \ #define NVTE_CALL_CHECK_CUDA_DRIVER(symbol, ...) \
do { \ do { \
NVTE_CHECK_CUDA_DRIVER( \ NVTE_CHECK_CUDA_DRIVER(::transformer_engine::cuda_driver::call(#symbol, __VA_ARGS__)); \
::transformer_engine::cuda_driver::call(#symbol, __VA_ARGS__)); \
} while (false) } while (false)
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_ #endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_
...@@ -4,12 +4,13 @@ ...@@ -4,12 +4,13 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "../util/cuda_runtime.h"
#include <filesystem> #include <filesystem>
#include <mutex> #include <mutex>
#include "../common.h" #include "../common.h"
#include "../util/cuda_driver.h" #include "../util/cuda_driver.h"
#include "../util/cuda_runtime.h"
#include "../util/system.h" #include "../util/system.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -24,7 +25,7 @@ namespace { ...@@ -24,7 +25,7 @@ namespace {
} // namespace } // namespace
int num_devices() { int num_devices() {
auto query_num_devices = [] () -> int { auto query_num_devices = []() -> int {
int count; int count;
NVTE_CHECK_CUDA(cudaGetDeviceCount(&count)); NVTE_CHECK_CUDA(cudaGetDeviceCount(&count));
return count; return count;
...@@ -54,10 +55,10 @@ int sm_arch(int device_id) { ...@@ -54,10 +55,10 @@ int sm_arch(int device_id) {
device_id = current_device(); device_id = current_device();
} }
NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID"); NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
auto init = [&] () { auto init = [&]() {
cudaDeviceProp prop; cudaDeviceProp prop;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id)); NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id));
cache[device_id] = 10*prop.major + prop.minor; cache[device_id] = 10 * prop.major + prop.minor;
}; };
std::call_once(flags[device_id], init); std::call_once(flags[device_id], init);
return cache[device_id]; return cache[device_id];
...@@ -70,7 +71,7 @@ int sm_count(int device_id) { ...@@ -70,7 +71,7 @@ int sm_count(int device_id) {
device_id = current_device(); device_id = current_device();
} }
NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID"); NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
auto init = [&] () { auto init = [&]() {
cudaDeviceProp prop; cudaDeviceProp prop;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id)); NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id));
cache[device_id] = prop.multiProcessorCount; cache[device_id] = prop.multiProcessorCount;
...@@ -90,12 +91,11 @@ const std::string &include_directory(bool required) { ...@@ -90,12 +91,11 @@ const std::string &include_directory(bool required) {
if (need_to_check_env) { if (need_to_check_env) {
// Search for CUDA headers in common paths // Search for CUDA headers in common paths
using Path = std::filesystem::path; using Path = std::filesystem::path;
std::vector<std::pair<std::string, Path>> search_paths = { std::vector<std::pair<std::string, Path>> search_paths = {{"NVTE_CUDA_INCLUDE_DIR", ""},
{"NVTE_CUDA_INCLUDE_DIR", ""}, {"CUDA_HOME", ""},
{"CUDA_HOME", ""}, {"CUDA_DIR", ""},
{"CUDA_DIR", ""}, {"", string_path_cuda_include},
{"", string_path_cuda_include}, {"", "/usr/local/cuda"}};
{"", "/usr/local/cuda"}};
for (auto &[env, p] : search_paths) { for (auto &[env, p] : search_paths) {
if (p.empty()) { if (p.empty()) {
p = getenv<Path>(env.c_str()); p = getenv<Path>(env.c_str());
...@@ -131,10 +131,11 @@ const std::string &include_directory(bool required) { ...@@ -131,10 +131,11 @@ const std::string &include_directory(bool required) {
message += p; message += p;
} }
} }
message += (". " message +=
"Specify path to CUDA Toolkit headers " (". "
"with NVTE_CUDA_INCLUDE_DIR " "Specify path to CUDA Toolkit headers "
"or disable NVRTC support with NVTE_DISABLE_NVRTC=1."); "with NVTE_CUDA_INCLUDE_DIR "
"or disable NVRTC support with NVTE_DISABLE_NVRTC=1.");
NVTE_ERROR(message); NVTE_ERROR(message);
} }
need_to_check_env = false; need_to_check_env = false;
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_ #define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <string> #include <string>
namespace transformer_engine { namespace transformer_engine {
......
...@@ -7,83 +7,76 @@ ...@@ -7,83 +7,76 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ #ifndef TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ #define TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
#include <stdexcept>
#include <cublas_v2.h> #include <cublas_v2.h>
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <cudnn.h> #include <cudnn.h>
#include <nvrtc.h> #include <nvrtc.h>
#include <stdexcept>
#include "../util/string.h" #include "../util/string.h"
#define NVTE_ERROR(...) \ #define NVTE_ERROR(...) \
do { \ do { \
throw ::std::runtime_error( \ throw ::std::runtime_error(::transformer_engine::concat_strings( \
::transformer_engine::concat_strings( \ __FILE__ ":", __LINE__, " in function ", __func__, ": ", \
__FILE__ ":", __LINE__, \ ::transformer_engine::concat_strings(__VA_ARGS__))); \
" in function ", __func__, ": ", \
::transformer_engine::concat_strings(__VA_ARGS__))); \
} while (false) } while (false)
#define NVTE_CHECK(expr, ...) \ #define NVTE_CHECK(expr, ...) \
do { \ do { \
if (!(expr)) { \ if (!(expr)) { \
NVTE_ERROR("Assertion failed: " #expr ". ", \ NVTE_ERROR("Assertion failed: " #expr ". ", \
::transformer_engine::concat_strings(__VA_ARGS__)); \ ::transformer_engine::concat_strings(__VA_ARGS__)); \
} \ } \
} while (false) } while (false)
#define NVTE_CHECK_CUDA(expr) \ #define NVTE_CHECK_CUDA(expr) \
do { \ do { \
const cudaError_t status_NVTE_CHECK_CUDA = (expr); \ const cudaError_t status_NVTE_CHECK_CUDA = (expr); \
if (status_NVTE_CHECK_CUDA != cudaSuccess) { \ if (status_NVTE_CHECK_CUDA != cudaSuccess) { \
NVTE_ERROR("CUDA Error: ", \ NVTE_ERROR("CUDA Error: ", cudaGetErrorString(status_NVTE_CHECK_CUDA)); \
cudaGetErrorString(status_NVTE_CHECK_CUDA)); \ } \
} \
} while (false) } while (false)
#define NVTE_CHECK_CUBLAS(expr) \ #define NVTE_CHECK_CUBLAS(expr) \
do { \ do { \
const cublasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \ const cublasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \
if (status_NVTE_CHECK_CUBLAS != CUBLAS_STATUS_SUCCESS) { \ if (status_NVTE_CHECK_CUBLAS != CUBLAS_STATUS_SUCCESS) { \
NVTE_ERROR("cuBLAS Error: ", \ NVTE_ERROR("cuBLAS Error: ", cublasGetStatusString(status_NVTE_CHECK_CUBLAS)); \
cublasGetStatusString(status_NVTE_CHECK_CUBLAS)); \ } \
} \
} while (false) } while (false)
#define NVTE_CHECK_CUDNN(expr) \ #define NVTE_CHECK_CUDNN(expr) \
do { \ do { \
const cudnnStatus_t status_NVTE_CHECK_CUDNN = (expr); \ const cudnnStatus_t status_NVTE_CHECK_CUDNN = (expr); \
if (status_NVTE_CHECK_CUDNN != CUDNN_STATUS_SUCCESS) { \ if (status_NVTE_CHECK_CUDNN != CUDNN_STATUS_SUCCESS) { \
NVTE_ERROR("cuDNN Error: ", \ NVTE_ERROR("cuDNN Error: ", cudnnGetErrorString(status_NVTE_CHECK_CUDNN), \
cudnnGetErrorString(status_NVTE_CHECK_CUDNN), \ ". " \
". " \ "For more information, enable cuDNN error logging " \
"For more information, enable cuDNN error logging " \ "by setting CUDNN_LOGERR_DBG=1 and " \
"by setting CUDNN_LOGERR_DBG=1 and " \ "CUDNN_LOGDEST_DBG=stderr in the environment."); \
"CUDNN_LOGDEST_DBG=stderr in the environment."); \ } \
} \
} while (false) } while (false)
#define NVTE_CHECK_CUDNN_FE(expr) \ #define NVTE_CHECK_CUDNN_FE(expr) \
do { \ do { \
const auto error = (expr); \ const auto error = (expr); \
if (error.is_bad()) { \ if (error.is_bad()) { \
NVTE_ERROR("cuDNN Error: ", \ NVTE_ERROR("cuDNN Error: ", error.err_msg, \
error.err_msg, \ ". " \
". " \ "For more information, enable cuDNN error logging " \
"For more information, enable cuDNN error logging " \ "by setting CUDNN_LOGERR_DBG=1 and " \
"by setting CUDNN_LOGERR_DBG=1 and " \ "CUDNN_LOGDEST_DBG=stderr in the environment."); \
"CUDNN_LOGDEST_DBG=stderr in the environment."); \ } \
} \
} while (false) } while (false)
#define NVTE_CHECK_NVRTC(expr) \ #define NVTE_CHECK_NVRTC(expr) \
do { \ do { \
const nvrtcResult status_NVTE_CHECK_NVRTC = (expr); \ const nvrtcResult status_NVTE_CHECK_NVRTC = (expr); \
if (status_NVTE_CHECK_NVRTC != NVRTC_SUCCESS) { \ if (status_NVTE_CHECK_NVRTC != NVRTC_SUCCESS) { \
NVTE_ERROR("NVRTC Error: ", \ NVTE_ERROR("NVRTC Error: ", nvrtcGetErrorString(status_NVTE_CHECK_NVRTC)); \
nvrtcGetErrorString(status_NVTE_CHECK_NVRTC)); \ } \
} \
} while (false) } while (false)
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ #endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
...@@ -13,75 +13,73 @@ struct Empty {}; ...@@ -13,75 +13,73 @@ struct Empty {};
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType gelu(const IType val, const Empty&) { __device__ inline OType gelu(const IType val, const Empty&) {
const float cval = val; const float cval = val;
return cval * (0.5F + 0.5F * tanhf(cval * (0.79788456F + 0.03567741F * cval * cval))); return cval * (0.5F + 0.5F * tanhf(cval * (0.79788456F + 0.03567741F * cval * cval)));
} }
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType dgelu(const IType val, const Empty&) { __device__ inline OType dgelu(const IType val, const Empty&) {
const float cval = val; const float cval = val;
const float tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval)); const float tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
return 0.5f * cval * ((1.f - tanh_out * tanh_out) * return 0.5f * cval * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * cval * cval)) +
(0.79788456f + 0.1070322243f * cval * cval)) + 0.5f * (1.f + tanh_out);
0.5f * (1.f + tanh_out);
} }
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType sigmoid(const IType val, const Empty&) { __device__ inline OType sigmoid(const IType val, const Empty&) {
const float cval = val; const float cval = val;
return 1.f / (1.f + expf(-cval)); return 1.f / (1.f + expf(-cval));
} }
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType dsigmoid(const IType val, const Empty& e) { __device__ inline OType dsigmoid(const IType val, const Empty& e) {
const float cval = val; const float cval = val;
const float s = sigmoid<float, float>(cval, e); const float s = sigmoid<float, float>(cval, e);
return s * (1.f - s); return s * (1.f - s);
} }
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType qgelu(const IType val, const Empty& e) { __device__ inline OType qgelu(const IType val, const Empty& e) {
const float cval = val; const float cval = val;
return cval * sigmoid<float, float>(1.702f * cval, e); return cval * sigmoid<float, float>(1.702f * cval, e);
} }
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType dqgelu(const IType val, const Empty& e) { __device__ inline OType dqgelu(const IType val, const Empty& e) {
const float cval = val; const float cval = val;
return cval * dsigmoid<float, float>(1.702f * cval, e) + return cval * dsigmoid<float, float>(1.702f * cval, e) + sigmoid<float, float>(1.702f * cval, e);
sigmoid<float, float>(1.702f * cval, e);
} }
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType silu(const IType val, const Empty& e) { __device__ inline OType silu(const IType val, const Empty& e) {
const float cval = val; const float cval = val;
return cval * sigmoid<float, float>(cval, e); return cval * sigmoid<float, float>(cval, e);
} }
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType dsilu(const IType val, const Empty& e) { __device__ inline OType dsilu(const IType val, const Empty& e) {
const float cval = val; const float cval = val;
return cval * dsigmoid<float, float>(cval, e) + sigmoid<float, float>(cval, e); return cval * dsigmoid<float, float>(cval, e) + sigmoid<float, float>(cval, e);
} }
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType relu(IType value, const Empty &) { __device__ inline OType relu(IType value, const Empty&) {
return fmaxf(value, 0.f); return fmaxf(value, 0.f);
} }
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType drelu(IType value, const Empty &) { __device__ inline OType drelu(IType value, const Empty&) {
return value > 0.f ? 1.f : 0.f; return value > 0.f ? 1.f : 0.f;
} }
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType srelu(IType value, const Empty &) { __device__ inline OType srelu(IType value, const Empty&) {
return value > 0 ? value * value : 0.f; return value > 0 ? value * value : 0.f;
} }
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType dsrelu(IType value, const Empty &) { __device__ inline OType dsrelu(IType value, const Empty&) {
return fmaxf(2.f * value, 0.f); return fmaxf(2.f * value, 0.f);
} }
} // namespace transformer_engine } // namespace transformer_engine
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "../util/rtc.h"
#include <cstdlib> #include <cstdlib>
#include <iostream> #include <iostream>
#include <utility> #include <utility>
...@@ -13,8 +15,6 @@ ...@@ -13,8 +15,6 @@
#include "../util/string.h" #include "../util/string.h"
#include "../util/system.h" #include "../util/system.h"
#include "../util/rtc.h"
namespace transformer_engine { namespace transformer_engine {
namespace rtc { namespace rtc {
...@@ -22,8 +22,8 @@ namespace rtc { ...@@ -22,8 +22,8 @@ namespace rtc {
namespace { namespace {
// Strings with headers for RTC kernels // Strings with headers for RTC kernels
#include "string_code_utils_cuh.h"
#include "string_code_util_math_h.h" #include "string_code_util_math_h.h"
#include "string_code_utils_cuh.h"
/*! \brief Latest compute capability that NVRTC supports /*! \brief Latest compute capability that NVRTC supports
* *
...@@ -56,29 +56,25 @@ bool is_enabled() { ...@@ -56,29 +56,25 @@ bool is_enabled() {
} }
Kernel::Kernel(std::string mangled_name, std::string compiled_code) Kernel::Kernel(std::string mangled_name, std::string compiled_code)
: mangled_name_{std::move(mangled_name)} : mangled_name_{std::move(mangled_name)},
, compiled_code_{std::move(compiled_code)} compiled_code_{std::move(compiled_code)},
, modules_(cuda::num_devices(), null_module) modules_(cuda::num_devices(), null_module),
, functions_(cuda::num_devices(), null_function) functions_(cuda::num_devices(), null_function),
, init_flags_{std::make_unique<std::vector<std::once_flag>>(cuda::num_devices())} { init_flags_{std::make_unique<std::vector<std::once_flag>>(cuda::num_devices())} {}
}
Kernel::~Kernel() { Kernel::~Kernel() {
for (int device_id=0; device_id<static_cast<int>(modules_.size()); ++device_id) { for (int device_id = 0; device_id < static_cast<int>(modules_.size()); ++device_id) {
// Unload CUDA modules if needed // Unload CUDA modules if needed
if (modules_[device_id] != null_module) { if (modules_[device_id] != null_module) {
CUdevice device; CUdevice device;
CUcontext context; CUcontext context;
if (cuda_driver::call("cuDeviceGet", &device, device_id) if (cuda_driver::call("cuDeviceGet", &device, device_id) != CUDA_SUCCESS) {
!= CUDA_SUCCESS) {
continue; continue;
} }
if (cuda_driver::call("cuDevicePrimaryCtxRetain", &context, device) if (cuda_driver::call("cuDevicePrimaryCtxRetain", &context, device) != CUDA_SUCCESS) {
!= CUDA_SUCCESS) {
continue; continue;
} }
if (cuda_driver::call("cuCtxSetCurrent", context) if (cuda_driver::call("cuCtxSetCurrent", context) != CUDA_SUCCESS) {
!= CUDA_SUCCESS) {
continue; continue;
} }
cuda_driver::call("cuModuleUnload", modules_[device_id]); cuda_driver::call("cuModuleUnload", modules_[device_id]);
...@@ -87,9 +83,7 @@ Kernel::~Kernel() { ...@@ -87,9 +83,7 @@ Kernel::~Kernel() {
} }
} }
Kernel::Kernel(Kernel&& other) noexcept { Kernel::Kernel(Kernel&& other) noexcept { swap(*this, other); }
swap(*this, other);
}
Kernel& Kernel::operator=(Kernel other) noexcept { Kernel& Kernel::operator=(Kernel other) noexcept {
// Copy-and-swap idiom // Copy-and-swap idiom
...@@ -108,7 +102,7 @@ void swap(Kernel& first, Kernel& second) noexcept { ...@@ -108,7 +102,7 @@ void swap(Kernel& first, Kernel& second) noexcept {
CUfunction Kernel::get_function(int device_id) { CUfunction Kernel::get_function(int device_id) {
// Load kernel on device if needed // Load kernel on device if needed
auto load_on_device = [&] () { auto load_on_device = [&]() {
// Set driver context to proper device // Set driver context to proper device
CUdevice device; CUdevice device;
CUcontext context; CUcontext context;
...@@ -117,15 +111,11 @@ CUfunction Kernel::get_function(int device_id) { ...@@ -117,15 +111,11 @@ CUfunction Kernel::get_function(int device_id) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context); NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context);
// Load function into driver context // Load function into driver context
NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleLoadDataEx, NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleLoadDataEx, &modules_[device_id], compiled_code_.c_str(),
&modules_[device_id], 0, // numOptions
compiled_code_.c_str(), nullptr, // options
0, // numOptions nullptr); // optionValues
nullptr, // options NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleGetFunction, &functions_[device_id], modules_[device_id],
nullptr); // optionValues
NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleGetFunction,
&functions_[device_id],
modules_[device_id],
mangled_name_.c_str()); mangled_name_.c_str());
// Reset driver context // Reset driver context
...@@ -147,10 +137,8 @@ KernelManager& KernelManager::instance() { ...@@ -147,10 +137,8 @@ KernelManager& KernelManager::instance() {
return instance_; return instance_;
} }
void KernelManager::compile(const std::string &kernel_label, void KernelManager::compile(const std::string& kernel_label, const std::string& kernel_name,
const std::string &kernel_name, const std::string& code, const std::string& filename) {
const std::string &code,
const std::string &filename) {
std::lock_guard<std::mutex> lock_guard_(lock_); std::lock_guard<std::mutex> lock_guard_(lock_);
// Choose whether to compile to PTX or cubin // Choose whether to compile to PTX or cubin
...@@ -162,9 +150,9 @@ void KernelManager::compile(const std::string &kernel_label, ...@@ -162,9 +150,9 @@ void KernelManager::compile(const std::string &kernel_label,
// Compilation flags // Compilation flags
std::vector<std::string> opts = { std::vector<std::string> opts = {
#if NDEBUG == 0 #if NDEBUG == 0
"-G", "-G",
#endif #endif
"--std=c++17"}; "--std=c++17"};
if (compile_ptx) { if (compile_ptx) {
opts.push_back(concat_strings("--gpu-architecture=compute_", compile_sm_arch)); opts.push_back(concat_strings("--gpu-architecture=compute_", compile_sm_arch));
} else { } else {
...@@ -181,20 +169,14 @@ void KernelManager::compile(const std::string &kernel_label, ...@@ -181,20 +169,14 @@ void KernelManager::compile(const std::string &kernel_label,
constexpr int num_headers = 2; constexpr int num_headers = 2;
constexpr const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h}; constexpr const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h};
constexpr const char* include_names[num_headers] = {"utils.cuh", "util/math.h"}; constexpr const char* include_names[num_headers] = {"utils.cuh", "util/math.h"};
NVTE_CHECK_NVRTC(nvrtcCreateProgram(&program, NVTE_CHECK_NVRTC(nvrtcCreateProgram(&program, code.c_str(), filename.c_str(), num_headers,
code.c_str(), headers, include_names));
filename.c_str(),
num_headers,
headers,
include_names));
NVTE_CHECK_NVRTC(nvrtcAddNameExpression(program, kernel_name.c_str())); NVTE_CHECK_NVRTC(nvrtcAddNameExpression(program, kernel_name.c_str()));
const nvrtcResult compile_result = nvrtcCompileProgram(program, const nvrtcResult compile_result =
opts_ptrs.size(), nvrtcCompileProgram(program, opts_ptrs.size(), opts_ptrs.data());
opts_ptrs.data());
if (compile_result != NVRTC_SUCCESS) { if (compile_result != NVRTC_SUCCESS) {
// Display log if compilation failed // Display log if compilation failed
std::string log = concat_strings("NVRTC compilation log for ", std::string log = concat_strings("NVRTC compilation log for ", filename, ":\n");
filename, ":\n");
const size_t log_offset = log.size(); const size_t log_offset = log.size();
size_t log_size; size_t log_size;
NVTE_CHECK_NVRTC(nvrtcGetProgramLogSize(program, &log_size)); NVTE_CHECK_NVRTC(nvrtcGetProgramLogSize(program, &log_size));
...@@ -206,10 +188,8 @@ void KernelManager::compile(const std::string &kernel_label, ...@@ -206,10 +188,8 @@ void KernelManager::compile(const std::string &kernel_label,
} }
// Get mangled function name // Get mangled function name
const char *mangled_name; const char* mangled_name;
NVTE_CHECK_NVRTC(nvrtcGetLoweredName(program, NVTE_CHECK_NVRTC(nvrtcGetLoweredName(program, kernel_name.c_str(), &mangled_name));
kernel_name.c_str(),
&mangled_name));
// Get compiled code // Get compiled code
std::string compiled_code; std::string compiled_code;
...@@ -234,20 +214,19 @@ void KernelManager::compile(const std::string &kernel_label, ...@@ -234,20 +214,19 @@ void KernelManager::compile(const std::string &kernel_label,
NVTE_CHECK_NVRTC(nvrtcDestroyProgram(&program)); NVTE_CHECK_NVRTC(nvrtcDestroyProgram(&program));
} }
void KernelManager::set_cache_config(const std::string &kernel_label, CUfunc_cache cache_config) { void KernelManager::set_cache_config(const std::string& kernel_label, CUfunc_cache cache_config) {
const int device_id = cuda::current_device(); const int device_id = cuda::current_device();
const auto key = get_kernel_cache_key(kernel_label, device_id); const auto key = get_kernel_cache_key(kernel_label, device_id);
NVTE_CHECK(kernel_cache_.count(key) > 0, NVTE_CHECK(kernel_cache_.count(key) > 0, "Attempted to configure RTC kernel before compilation");
"Attempted to configure RTC kernel before compilation");
kernel_cache_.at(key).set_function_cache_config(device_id, cache_config); kernel_cache_.at(key).set_function_cache_config(device_id, cache_config);
} }
bool KernelManager::is_compiled(const std::string &kernel_label, int device_id) const { bool KernelManager::is_compiled(const std::string& kernel_label, int device_id) const {
const auto key = get_kernel_cache_key(kernel_label, device_id); const auto key = get_kernel_cache_key(kernel_label, device_id);
return kernel_cache_.count(key) > 0; return kernel_cache_.count(key) > 0;
} }
std::string KernelManager::get_kernel_cache_key(const std::string &kernel_label, std::string KernelManager::get_kernel_cache_key(const std::string& kernel_label,
int device_id) const { int device_id) const {
return concat_strings("sm=", cuda::sm_arch(device_id), ",", kernel_label); return concat_strings("sm=", cuda::sm_arch(device_id), ",", kernel_label);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment