Unverified commit 9416519d, authored by Kirthi Shankar Sivamani and committed by GitHub

Apply formatting (#929)



* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
@@ -13,167 +13,147 @@ using namespace transformer_engine::rmsnorm;
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL>
void launch_tuned_(LaunchParams<BwdParams> &launch_params,
const bool configure_params) { // NOLINT(*)
using Kernel_traits =
rmsnorm::Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
auto kernel = &rmsnorm_bwd_tuned_kernel<Kernel_traits>;
if (configure_params) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col =
launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M *
Kernel_traits::CTAS_PER_ROW *
sizeof(typename Kernel_traits::reduce_t) * 2;
}
return;
}
if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row;
if (ctas_per_row == 1) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(
launch_params.params);
} else {
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), Kernel_traits::SMEM_BYTES,
stream);
}
using Kernel_traits_f =
Kernel_traits_finalize<HIDDEN_SIZE, weight_t, input_t, output_t, compute_t, index_t,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
auto kernel_f = &rmsnorm::rmsnorm_bwd_finalize_tuned_kernel<Kernel_traits_f>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(
launch_params.params);
}
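Note: the configure_params branch implements a two-phase protocol. The caller invokes the launcher once with configure_params = true, which only fills in ctas_per_row/ctas_per_col and the workspace/barrier sizes; it then allocates those buffers and invokes the launcher again to actually launch. A minimal sketch of the calling side, assuming hypothetical allocate_workspace/allocate_barrier helpers and assuming BwdParams exposes workspace/barrier fields (the real driver lives elsewhere in the library):

void *allocate_workspace(size_t bytes);  // hypothetical caller-side allocator
int *allocate_barrier(size_t count);     // hypothetical caller-side allocator

void run_bwd(void (*launcher)(LaunchParams<BwdParams> &, const bool),
             LaunchParams<BwdParams> &lp) {
  launcher(lp, /*configure_params=*/true);   // phase 1: compute grid and buffer sizes only
  if (lp.workspace_bytes > 0) {
    lp.params.workspace = allocate_workspace(lp.workspace_bytes);  // field names assumed
    lp.params.barrier = allocate_barrier(lp.barrier_size);
  }
  launcher(lp, /*configure_params=*/false);  // phase 2: launch the kernels
}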
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL>
void launch_general_(LaunchParams<BwdParams> &launch_params,
const bool configure_params) { // NOLINT(*)
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
// Instantiate kernel
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
1, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
auto kernel = &rmsnorm_bwd_general_kernel<Kernel_traits>;
// Configure kernel params
const int rows = launch_params.params.rows;
const int cols = launch_params.params.cols;
int ctas_per_col = launch_params.params.ctas_per_col;
int ctas_per_row = launch_params.params.ctas_per_row;
if (configure_params) {
int ctas_per_sm;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel,
Kernel_traits::THREADS_PER_CTA, 0);
const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
launch_params.params.ctas_per_row = ctas_per_row;
launch_params.params.ctas_per_col = ctas_per_col;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (launch_params.params.ctas_per_row > 1) {
launch_params.barrier_size = 2 * ctas_per_col;
launch_params.workspace_bytes =
(ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2);
}
return;
}
// Launch kernel
auto stream = launch_params.stream;
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
if (ctas_per_row == 1) {
kernel<<<grid, block, 0, stream>>>(launch_params.params);
} else {
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream);
}
// Launch finalization kernel
constexpr uint32_t WARPS_M_FINAL = 4;
constexpr uint32_t WARPS_N_FINAL = 1;
constexpr uint32_t ELTS_N_PER_CTA_FINAL =
(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t));
auto kernel_final =
&rmsnorm_bwd_finalize_general_kernel<weight_t, compute_t, WARPS_M_FINAL, WARPS_N_FINAL,
BYTES_PER_LDG_FINAL, Kernel_traits::THREADS_PER_WARP>;
dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL);
dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1);
kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
}
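For intuition on the grid sizing above, take illustrative numbers: cols = 1280 and HIDDEN_SIZE = 1024 give ctas_per_row = ceil_div(1280, 1024) = 2; with rows = 4096, WARPS_M = 4, multiprocessorCount = 108 and ctas_per_sm = 2, occupancy caps the column dimension at ctas_per_col = min(ceil_div(4096, 4), (108 * 2) / 2) = min(1024, 108) = 108, so each CTA iterates over multiple row blocks.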
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \
WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
void rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<BwdParams> &launch_params, const bool configure_params) { \
launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, \
WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE>(launch_params, \
configure_params); \
} \
static BwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \
BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
void rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<BwdParams> &launch_params, const bool configure_params) { \
launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, WARPS_M, WARPS_N, \
BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
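For reference, an invocation such as REGISTER_BWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4) (argument values illustrative) expands mechanically to a named launcher plus a static registrar whose constructor records the function at program startup:

void rmsnorm_bwd_tuned_768_fp32_fp32_fp32_fp32(LaunchParams<BwdParams> &launch_params,
                                               const bool configure_params) {
  launch_tuned_<fp32, fp32, fp32, fp32, uint32_t, 768, 1, 4, 1, 16, 4>(launch_params,
                                                                       configure_params);
}
static BwdTunedRegistrar<fp32, fp32, fp32, fp32, 768> reg_tuned_768_fp32_fp32_fp32_fp32(
    rmsnorm_bwd_tuned_768_fp32_fp32_fp32_fp32);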
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -13,122 +13,120 @@ using namespace transformer_engine::rmsnorm;
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
int BYTES_PER_LDG>
void launch_tuned_(LaunchParams<FwdParams> &launch_params,
const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>;
auto kernel = &rmsnorm_fwd_tuned_kernel<Kernel_traits>;
if (configure_params) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);
launch_params.params.ctas_per_row = CTAS_PER_ROW;
launch_params.params.ctas_per_col =
launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M *
Kernel_traits::CTAS_PER_ROW *
sizeof(typename Kernel_traits::Stats::stats_t) * 2;
}
return;
}
if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Kernel_traits::SMEM_BYTES_FWD));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
auto ctas_per_row = launch_params.params.ctas_per_row;
if (ctas_per_row == 1) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(
launch_params.params);
} else {
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, // NOLINT(*)
Kernel_traits::SMEM_BYTES_FWD, stream);
}
}
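When ctas_per_row > 1, several CTAs cooperate on a single row and must synchronize across the whole grid during the reduction, which is why the launchers switch from a plain <<<...>>> launch to cudaLaunchCooperativeKernel. A self-contained sketch of that launch pattern (not TE code; compile with nvcc -rdc=true on a device supporting cooperative launch):

#include <cooperative_groups.h>
#include <cuda_runtime.h>
#include <cstdio>

namespace cg = cooperative_groups;

// Each CTA publishes a partial value, the whole grid synchronizes, then one
// CTA combines the partials; grid-wide sync is only legal for cooperative launches.
__global__ void two_phase_sum(float *partials, float *out) {
  if (threadIdx.x == 0) partials[blockIdx.x] = static_cast<float>(blockIdx.x);
  cg::this_grid().sync();
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    float acc = 0.f;
    for (unsigned i = 0; i < gridDim.x; ++i) acc += partials[i];
    *out = acc;
  }
}

int main() {
  int supported = 0;
  cudaDeviceGetAttribute(&supported, cudaDevAttrCooperativeLaunch, 0);
  if (!supported) return 0;
  float *partials, *out;
  cudaMalloc(&partials, 8 * sizeof(float));
  cudaMalloc(&out, sizeof(float));
  dim3 grid(8), block(32);
  void *args[] = {&partials, &out};
  cudaLaunchCooperativeKernel(reinterpret_cast<void *>(two_phase_sum), grid, block, args, 0,
                              nullptr);
  float host = 0.f;
  cudaMemcpy(&host, out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f\n", host);  // expect 28 = 0 + 1 + ... + 7
  cudaFree(partials);
  cudaFree(out);
  return 0;
}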
template <typename weight_t, typename input_t, typename output_t, typename compute_t,
typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG>
void launch_general_(LaunchParams<FwdParams> &launch_params,
const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
1, WARPS_M, WARPS_N, BYTES_PER_LDG>;
auto kernel = &rmsnorm_fwd_general_kernel<Kernel_traits>;
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
// Configure kernel params
const int rows = launch_params.params.rows;
const int cols = launch_params.params.cols;
int ctas_per_col = launch_params.params.ctas_per_col;
int ctas_per_row = launch_params.params.ctas_per_row;
if (configure_params) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0);
const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
launch_params.params.ctas_per_row = ctas_per_row;
launch_params.params.ctas_per_col = ctas_per_col;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if (launch_params.params.ctas_per_row > 1) {
launch_params.barrier_size = 2 * ctas_per_col;
launch_params.workspace_bytes =
(ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2);
}
return;
}
// Launch kernel
auto stream = launch_params.stream;
dim3 grid(ctas_per_row * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
if (ctas_per_row == 1) {
kernel<<<grid, block, 0, stream>>>(launch_params.params);
} else {
void *params_ = reinterpret_cast<void *>(&launch_params.params);
cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
reinterpret_cast<void **>(&params_), 0, stream);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \
WARPS_M, WARPS_N, BYTES_PER_LDG) \
void rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<FwdParams> &launch_params, const bool configure_params) { \
launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, \
WARPS_N, BYTES_PER_LDG>(launch_params, configure_params); \
} \
static FwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \
BYTES_PER_LDG) \
void rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<FwdParams> &launch_params, const bool configure_params) { \
launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, WARPS_M, WARPS_N, \
BYTES_PER_LDG>(launch_params, configure_params); \
} \
static FwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -9,6 +9,7 @@
#include <cfloat>
#include <cstdio>
#include "../utils.cuh"
namespace transformer_engine {
@@ -18,261 +19,260 @@ using namespace transformer_engine;
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_kernel(
FwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using compute_t = typename Ktraits::compute_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Stats = typename Ktraits::Stats;
using stats_t = typename Stats::stats_t;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t r = bidm * ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
Wvec gamma[LDGS];
index_t idx = c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
gamma[it].load_from(params.gamma, idx);
idx += VEC_COLS_PER_LDG;
}
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
compute_t scale;
if (params.fp8_out) {
scale = *reinterpret_cast<compute_t *>(params.scale);
}
compute_t amax = 0;
for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
Ivec x[LDGS];
index_t idx = row * Ktraits::VEC_COLS + c;
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
for (int it = 0; it < LDGS; it++) {
x[it].load_from(params.x, idx);
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t x_ij = compute_t(x[it].data.elt[jt]);
xf[it * NUM_ELTS + jt] = x_ij;
}
idx += VEC_COLS_PER_LDG;
}
stats_t s = stats.compute(xf, rn);
compute_t mu = Get<0>::of<stats_t, compute_t>(s);
compute_t m2 = Get<1>::of<stats_t, compute_t>(s);
// reciprocal of root mean square
// we could optimize here to count mean square directly
compute_t rs = rsqrtf(rn * m2 + mu * mu + params.epsilon);
if (bidn == 0 && warp_n == 0 && lane == 0) {
rs_ptr[row] = rs;
}
Ovec z[LDGS];
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for (int it = 0; it < LDGS; it++) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t y_ij = rs * (xf[it * NUM_ELTS + jt]);
compute_t g_ij = gamma[it].data.elt[jt];
if (params.zero_centered_gamma) {
g_ij += 1;
}
compute_t temp_output = g_ij * y_ij;
if (params.fp8_out) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(temp_output));
temp_output = temp_output * scale;
}
z[it].data.elt[jt] = output_t(temp_output);
}
z[it].store_to(params.z, idx);
idx += VEC_COLS_PER_LDG;
}
}
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0 && threadIdx.y == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}
}
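In formulas, assuming the Stats helper returns the mean mu and the sum of squared deviations m2 over the n = COLS elements of a row (which the rn * m2 term suggests), the kernel computes

$$\frac{m_2}{n} + \mu^2 \;=\; \frac{1}{n}\sum_{i=1}^{n} x_i^2, \qquad rs \;=\; \frac{1}{\sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2 + \epsilon}}, \qquad z_i \;=\; \gamma_i \, rs \, x_i,$$

with gamma_i replaced by 1 + gamma_i when zero_centered_gamma is set; rs is also written out once per row for reuse in the backward pass.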
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_kernel(
FwdParams params) {
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
using input_t = typename Ktraits::input_t;
using weight_t = typename Ktraits::weight_t;
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using compute_t = typename Ktraits::compute_t;
using Ivec = typename Ktraits::Ivec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
const index_t tidx = threadIdx.x;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t bdimm = WARPS_M;
const index_t bdimn = WARPS_N * THREADS_PER_WARP;
const index_t bidm = blockIdx.x / params.ctas_per_row;
const index_t bidn = blockIdx.x % params.ctas_per_row;
const index_t gdimm = bdimm * params.ctas_per_col;
const index_t gdimn = bdimn * params.ctas_per_row;
const index_t gidm = bidm * bdimm + warp_m;
const index_t gidn = (bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP +
lane); // Order threads by warp x cta x lane
// Objects for stats reductions
using Reducer = DynamicReducer<compute_t, WARPS_M, WARPS_N>;
constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1;
__shared__ char smem_[SMEM_BYTES];
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_);
Sum<compute_t> sum;
const compute_t rn = 1.f / static_cast<compute_t>(params.cols);
// Load weights
Cvec gamma[LDGS];
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Wvec gamma_in;
gamma_in.load_from_elts(params.gamma, col, params.cols - col);
gamma_in.to(gamma[it]);
}
// fp8 factors
compute_t scale;
if (params.fp8_out) {
scale = *reinterpret_cast<compute_t *>(params.scale);
}
compute_t amax = 0;
for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
const int row = cta_row + warp_m;
// Load input
Cvec x[LDGS];
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Ivec x_in;
x_in.load_from_elts(params.x, row * params.cols + col, params.cols - col);
x_in.to(x[it]);
}
// Compute variance
compute_t sqsigma = 0.f;
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
if (col + jt < params.cols) {
compute_t diff = x[it].data.elt[jt];
sqsigma += diff * diff;
}
}
}
sqsigma = reducer.allreduce(sqsigma, sum) * rn;
compute_t rs = rsqrtf(sqsigma + params.epsilon);
// Write statistics
if (gidn == 0 && row < params.rows) {
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
rs_ptr[row] = rs;
}
// Compute output
#pragma unroll
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
// Compute output values
Cvec z;
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t y_ij = rs * (x[it].data.elt[jt]);
compute_t g_ij = gamma[it].data.elt[jt];
if (params.zero_centered_gamma) {
g_ij += 1;
}
z.data.elt[jt] = g_ij * y_ij;
}
// Apply fp8 factors
if (params.fp8_out) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
if (col + jt < params.cols) {
compute_t z_ij = z.data.elt[jt];
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(z_ij));
z.data.elt[jt] = z_ij * scale;
}
}
}
// Store output
Ovec z_out;
z.to(z_out);
z_out.store_to_elts(params.z, row * params.cols + col, params.cols - col);
}
}
// Finalize fp8 factors
if (params.fp8_out) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
}
}
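The epilogue above folds each thread's running amax into a CTA-wide maximum (reduce_max) and then into the global value via atomicMaxFloat. CUDA has no native floating-point atomicMax; one common implementation, valid here because amax >= 0 (which the __builtin_assume also encodes), reinterprets the bits, since non-negative IEEE-754 floats order the same as their bit patterns read as signed integers. A sketch of that idea, not necessarily the library's exact helper:

__device__ inline float atomic_max_nonneg_float(float *addr, float value) {
  // For x, y >= 0: x < y  <=>  __float_as_int(x) < __float_as_int(y),
  // so an integer atomicMax implements the float maximum.
  int old = atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));
  return __int_as_float(old);
}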
} // namespace rmsnorm
@@ -5,80 +5,65 @@
************************************************************************/
#include <transformer_engine/transformer_engine.h>
#include "common.h"
namespace transformer_engine {
size_t typeToSize(const transformer_engine::DType type) {
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
return TypeInfo<T>::size;); // NOLINT(*)
}
bool is_fp8_dtype(const transformer_engine::DType t) {
return t == transformer_engine::DType::kFloat8E4M3 || t == transformer_engine::DType::kFloat8E5M2;
}
void CheckInputTensor(const Tensor &t, const std::string &name) {
const DType type = t.data.dtype;
if (is_fp8_dtype(type)) {
// FP8 input needs to have scale_inv
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 input " + name + " must have inverse of scale.");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{1});
} else {
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input " + name + ".");
NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input " + name + ".");
NVTE_CHECK(t.scale_inv.dptr == nullptr,
"Scale_inv is not supported for non-FP8 input " + name + ".");
}
NVTE_CHECK(t.data.dptr != nullptr, "Input " + name + " is not allocated!");
}
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
const DType type = t.data.dtype;
if (is_fp8_dtype(type)) {
// FP8 output needs to have scale, amax and scale_inv
NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output " + name + " must have amax tensor.");
NVTE_CHECK(t.amax.dtype == DType::kFloat32);
NVTE_CHECK(t.amax.shape == std::vector<size_t>{1});
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 output " + name + " must have scale.");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{1});
NVTE_CHECK(t.scale.dptr != nullptr, "FP8 output " + name + " must have inverse of scale.");
NVTE_CHECK(t.scale.dtype == DType::kFloat32);
NVTE_CHECK(t.scale.shape == std::vector<size_t>{1});
} else {
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output " + name + ".");
NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output " + name + ".");
NVTE_CHECK(t.scale_inv.dptr == nullptr,
"Scale_inv is not supported for non-FP8 output " + name + ".");
}
if (!allow_empty) {
NVTE_CHECK(t.data.dptr != nullptr, "Output " + name + " is not allocated!");
}
}
} // namespace transformer_engine
NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType dtype, float *amax,
float *scale, float *scale_inv) {
transformer_engine::Tensor *ret = new transformer_engine::Tensor;
ret->data.dptr = dptr;
ret->data.shape = std::vector<size_t>(shape.data, shape.data + shape.ndim);
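A minimal usage sketch of this C API, assuming kNVTEFloat8E4M3 is the FP8 enum value and that d_data, d_amax, d_scale and d_scale_inv (names hypothetical) are device buffers already allocated by the caller; per the checks above, the three FP8 side tensors must be single-element float32 buffers:

size_t dims[2] = {4096, 1024};
NVTEShape shape;
shape.data = dims;
shape.ndim = 2;
NVTETensor y = nvte_create_tensor(d_data, shape, kNVTEFloat8E4M3, d_amax, d_scale, d_scale_inv);
// ... pass y to NVTE kernels ...
nvte_destroy_tensor(y);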
@@ -97,11 +82,11 @@ void nvte_destroy_tensor(NVTETensor tensor) {
NVTEDType nvte_tensor_type(const NVTETensor tensor) {
return static_cast<NVTEDType>(
reinterpret_cast<const transformer_engine::Tensor *>(tensor)->data.dtype);
}
NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTEShape ret;
ret.data = t.data.shape.data();
ret.ndim = t.data.shape.size();
@@ -109,40 +94,40 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
}
void *nvte_tensor_data(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
return t.data.dptr;
}
float *nvte_tensor_amax(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32,
"Tensor's amax must have Float32 type!");
return reinterpret_cast<float *>(t.amax.dptr);
}
float *nvte_tensor_scale(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTE_CHECK(t.scale.dtype == transformer_engine::DType::kFloat32,
"Tensor's scale must have Float32 type!");
return reinterpret_cast<float *>(t.scale.dptr);
}
float *nvte_tensor_scale_inv(const NVTETensor tensor) {
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
NVTE_CHECK(t.scale_inv.dtype == transformer_engine::DType::kFloat32,
"Tensor's inverse of scale must have Float32 type!");
return reinterpret_cast<float *>(t.scale_inv.dptr);
}
void nvte_tensor_pack_create(NVTETensorPack *pack) {
for (int i = 0; i < pack->MAX_SIZE; i++) {
pack->tensors[i] = reinterpret_cast<NVTETensor>(new transformer_engine::Tensor);
}
}
void nvte_tensor_pack_destroy(NVTETensorPack *pack) {
for (int i = 0; i < pack->MAX_SIZE; i++) {
auto *t = reinterpret_cast<transformer_engine::Tensor *>(pack->tensors[i]);
delete t;
}
}
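The pack helpers pair up as a simple lifecycle; a sketch, assuming NVTETensorPack exposes the tensors array and MAX_SIZE as used above:

NVTETensorPack pack;
nvte_tensor_pack_create(&pack);   // every slot now holds a fresh, empty Tensor
// ... an API call may populate some of pack.tensors ...
nvte_tensor_pack_destroy(&pack);  // frees the Tensor behind every slot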
@@ -4,13 +4,12 @@
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/transpose.h>
#include <algorithm>
#include "../common.h"
#include "../util/rtc.h"
#include "../util/string.h"
@@ -49,26 +48,18 @@ struct KernelConfig {
/* Elements per L1 cache store to transposed output */
size_t elements_per_store_t = 0;
KernelConfig(size_t row_length, size_t num_rows, size_t itype_size, size_t otype_size,
size_t load_size_, size_t store_size_)
: load_size{load_size_}, store_size{store_size_} {
// Check that tiles are correctly aligned
constexpr size_t cache_line_size = 128;
if (load_size % itype_size != 0 || store_size % otype_size != 0 ||
cache_line_size % itype_size != 0 || cache_line_size % otype_size != 0) {
return;
}
const size_t row_tile_elements = load_size * THREADS_PER_WARP / itype_size;
const size_t col_tile_elements = store_size * THREADS_PER_WARP / otype_size;
valid = (row_length % row_tile_elements == 0 && num_rows % col_tile_elements == 0);
if (!valid) {
return;
}
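Concretely, with illustrative sizes: a 4-byte input type, a 1-byte FP8 output type and load_size = store_size = 8 give row_tile_elements = 8 * 32 / 4 = 64 and col_tile_elements = 8 * 32 / 1 = 256, so that config is valid only when row_length is a multiple of 64 and num_rows is a multiple of 256; otherwise a config with smaller load/store sizes must win the comparison below.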
@@ -80,12 +71,9 @@ struct KernelConfig {
constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs
active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm),
static_cast<size_t>(cuda::sm_count()));
elements_per_load = (std::min(cache_line_size, row_tile_elements * itype_size) / itype_size);
elements_per_store_c = (std::min(cache_line_size, row_tile_elements * otype_size) / otype_size);
elements_per_store_t = (std::min(cache_line_size, col_tile_elements * otype_size) / otype_size);
}
/* Compare by estimated cost */
@@ -104,8 +92,8 @@
const auto &st2 = other.elements_per_store_t;
const auto &p2 = other.active_sm_count;
const auto scale = l1 * sc1 * st1 * p1 * l2 * sc2 * st2 * p2;
const auto cost1 = (scale / l1 + scale / sc1 + scale / st1) / p1;
const auto cost2 = (scale / l2 + scale / sc2 + scale / st2) / p2;
return cost1 < cost2;
} else {
return this->valid && !other.valid;
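Written out, the comparator orders candidate configs by the estimated cost

$$\mathrm{cost} \;=\; \frac{1}{p}\left(\frac{1}{l} + \frac{1}{s_c} + \frac{1}{s_t}\right),$$

where l, s_c and s_t are the elements per load, per contiguous store and per transposed store, and p is the active SM count; multiplying both costs by the common product (the scale variable) keeps the whole comparison in integer arithmetic.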
@@ -114,16 +102,14 @@
};
template <size_t load_size, size_t store_size, typename IType, typename OType>
__global__ void __launch_bounds__(block_size)
cast_transpose_general_kernel(const IType *__restrict__ const input,
const CType *__restrict__ const noop,
OType *__restrict__ const output_c,
OType *__restrict__ const output_t,
const CType *__restrict__ const scale_ptr,
CType *__restrict__ const amax_ptr, const size_t row_length,
const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes
@@ -165,16 +151,16 @@ cast_transpose_general_kernel(const IType * __restrict__ const input,
// Note: Each thread loads num_iterations subtiles, computes amax,
// casts type, and transposes in registers.
OVecT local_output_t[nvec_in][num_iterations];
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx;
#pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) {
const size_t row = tile_row + i1 * nvec_out + i2;
const size_t col = tile_col + j1 * nvec_in;
if (row < num_rows) {
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
if (col + j2 < row_length) {
const CType in = input[row * row_length + col + j2];
......@@ -190,24 +176,24 @@ cast_transpose_general_kernel(const IType * __restrict__ const input,
}
// Copy transposed output from registers to global memory
__shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx;
shared_output_t[j1][i1] = local_output_t[j2][iter];
}
__syncthreads();
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidx;
const size_t j1 = tidy + iter * bdimy;
const size_t row = tile_row + i1 * nvec_out;
const size_t col = tile_col + j1 * nvec_in + j2;
if (col < row_length) {
#pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) {
if (row + i2 < num_rows) {
output_t[col * num_rows + row + i2] = shared_output_t[j1][i1].data.elt[i2];
@@ -229,18 +215,15 @@ cast_transpose_general_kernel(const IType * __restrict__ const input,
} // namespace
void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output_,
Tensor *transposed_output_, cudaStream_t stream) {
Tensor &cast_output = *cast_output_;
Tensor &transposed_output = *transposed_output_;
// Check no-op flag
if (noop.data.dptr != nullptr) {
size_t numel = 1;
for (const auto &dim : noop.data.shape) {
numel *= dim;
}
NVTE_CHECK(numel == 1, "Expected 1 element, but found ", numel, ".");
@@ -254,16 +237,14 @@ void cast_transpose(const Tensor &input,
CheckOutputTensor(transposed_output, "transposed_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output.data.shape.size() == 2, "Cast output must have 2 dimensions.");
NVTE_CHECK(transposed_output.data.shape.size() == 2, "Transposed output must have 2 dimensions.");
const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0];
NVTE_CHECK(cast_output.data.shape[0] == num_rows, "Wrong dimension of cast output.");
NVTE_CHECK(cast_output.data.shape[1] == row_length, "Wrong dimension of cast output.");
NVTE_CHECK(transposed_output.data.shape[0] == row_length,
"Wrong dimension of transposed output.");
NVTE_CHECK(transposed_output.data.shape[1] == num_rows,
"Wrong dimension of transposed output.");
NVTE_CHECK(transposed_output.data.shape[1] == num_rows, "Wrong dimension of transposed output.");
// Check tensor pointers
NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated.");
@@ -276,118 +257,111 @@ void cast_transpose(const Tensor &input,
NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr,
"Cast and transposed outputs need to share scale tensor.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
cast_output.data.dtype, OutputType,
constexpr const char *itype_name = TypeInfo<InputType>::name;
constexpr const char *otype_name = TypeInfo<OutputType>::name;
constexpr size_t itype_size = sizeof(InputType);
constexpr size_t otype_size = sizeof(OutputType);
// Choose between runtime-compiled or statically-compiled kernel
const bool aligned =
(row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0);
if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel
// Pick kernel config
std::vector<KernelConfig> kernel_configs;
kernel_configs.reserve(16);
auto add_config = [&](size_t load_size, size_t store_size) {
kernel_configs.emplace_back(row_length, num_rows, itype_size, otype_size, load_size,
store_size);
};
add_config(8, 8);
add_config(4, 8);
add_config(8, 4);
add_config(4, 4);
add_config(2, 8);
add_config(8, 2);
add_config(2, 4);
add_config(4, 2);
add_config(2, 2);
add_config(1, 8);
add_config(8, 1);
add_config(1, 4);
add_config(4, 1);
add_config(1, 2);
add_config(2, 1);
add_config(1, 1);
const auto &kernel_config =
*std::min_element(kernel_configs.begin(), kernel_configs.end());
NVTE_CHECK(kernel_config.valid, "invalid kernel config");
const size_t load_size = kernel_config.load_size;
const size_t store_size = kernel_config.store_size;
const size_t num_blocks = kernel_config.num_blocks;
// Compile NVRTC kernel if needed and launch
auto &rtc_manager = rtc::KernelManager::instance();
const std::string kernel_label = concat_strings(
"cast_transpose"
",itype=",
itype_name, ",otype=", otype_name, ",load_size=", load_size,
",store_size=", store_size);
if (!rtc_manager.is_compiled(kernel_label)) {
std::string code = string_code_transpose_rtc_cast_transpose_cu;
code = regex_replace(code, "__ITYPE__", itype_name);
code = regex_replace(code, "__OTYPE__", otype_name);
code = regex_replace(code, "__LOAD_SIZE__", load_size);
code = regex_replace(code, "__STORE_SIZE__", store_size);
code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile);
code = regex_replace(code, "__BLOCK_SIZE__", block_size);
rtc_manager.compile(kernel_label, "cast_transpose_optimized_kernel", code,
"transformer_engine/common/transpose/rtc/cast_transpose.cu");
}
rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream,
static_cast<const InputType *>(input.data.dptr),
reinterpret_cast<const CType *>(noop.data.dptr),
static_cast<OutputType *>(cast_output.data.dptr),
static_cast<OutputType *>(transposed_output.data.dptr),
static_cast<const CType *>(cast_output.scale.dptr),
static_cast<CType *>(cast_output.amax.dptr), row_length, num_rows);
} else { // Statically-compiled general kernel
constexpr size_t load_size = 4;
constexpr size_t store_size = 4;
constexpr size_t row_tile_size = load_size / itype_size * THREADS_PER_WARP;
constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP;
const int num_blocks =
(DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size));
cast_transpose_general_kernel<load_size, store_size, InputType, OutputType>
<<<num_blocks, block_size, 0, stream>>>(
static_cast<const InputType *>(input.data.dptr),
reinterpret_cast<const CType *>(noop.data.dptr),
static_cast<OutputType *>(cast_output.data.dptr),
static_cast<OutputType *>(transposed_output.data.dptr),
static_cast<const CType *>(cast_output.scale.dptr),
static_cast<CType *>(cast_output.amax.dptr), row_length, num_rows);
}); // NOLINT(*)
); // NOLINT(*)
}
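// Note on the RTC path above (illustrative): the kernel_label both names the
// specialization and serves as the compilation-cache key. A bf16 -> fp8 cast-
// transpose with 8-byte vectors would yield something like
//   "cast_transpose,itype=bf16,otype=fp8e4m3,load_size=8,store_size=8"
// (hypothetical type-name strings), so each dtype/vector-width combination is
// compiled by NVRTC at most once per process and then reused from the cache.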
} // namespace transformer_engine
void nvte_cast_transpose(const NVTETensor input, NVTETensor cast_output,
NVTETensor transposed_output, cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose);
using namespace transformer_engine;
auto noop = Tensor();
cast_transpose(*reinterpret_cast<const Tensor *>(input), noop,
reinterpret_cast<Tensor *>(cast_output),
reinterpret_cast<Tensor *>(transposed_output), stream);
}
void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_with_noop);
using namespace transformer_engine;
cast_transpose(*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(noop),
reinterpret_cast<Tensor *>(cast_output),
reinterpret_cast<Tensor *>(transposed_output), stream);
}
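// Semantics note: `noop` is a device-side float flag passed through to the
// kernels above (see the reinterpret_cast to const CType * at the launch
// sites); when it is set, the kernels are expected to early-exit without
// writing, which lets callers such as CUDA-graph users disable the operation
// without rebuilding the graph. The plain nvte_cast_transpose entry point
// passes an empty Tensor, so no flag is consulted.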
/*************************************************************************
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <transformer_engine/transpose.h>
#include <cfloat>
#include <iostream>
#include <type_traits>
#include "../utils.cuh"
#include "../util/rtc.h"
#include "../util/string.h"
#include "../common.h"
#include "../util/math.h"
#include "../util/rtc.h"
#include "../util/string.h"
#include "../utils.cuh"
namespace transformer_engine {
namespace {
constexpr size_t n_warps_per_tile = 8;
constexpr size_t desired_load_size = 8;
constexpr size_t desired_store_size = 8;
constexpr size_t desired_load_size_dact = 4; // dAct fusion kernels use more registers
constexpr size_t desired_store_size_dact = 4;
constexpr size_t threads_per_warp = static_cast<size_t>(THREADS_PER_WARP);
static_assert(cast_transpose_num_threads <= max_threads_per_block);
/* Performance heuristics for optimized kernel parameters */
struct KernelConfig {
size_t load_size = 0; // Vector load size
size_t store_size = 0; // Vector store size to transposed output
bool valid = false; // Whether config is valid
bool is_dact = false; // Whether dact is used
size_t num_blocks = 0; // Number of CUDA blocks
size_t active_sm_count = 0; // Number of active SMs
size_t elements_per_load = 0; // Elements per L1 cache load
size_t elements_per_load_dact = 0; // Elements per L1 cache load dact
size_t elements_per_store_c = 0; // Elements per L1 cache store to cast output
size_t elements_per_store_t = 0; // Elements per L1 cache store to transposed output
KernelConfig(size_t row_length, size_t num_rows, size_t itype_size, size_t itype2_size,
size_t otype_size, size_t load_size_, size_t store_size_, bool is_dact_)
: load_size{load_size_}, store_size{store_size_}, is_dact{is_dact_} {
if (is_dact) {
if (load_size > desired_load_size_dact || store_size > desired_store_size_dact) {
return;
}
}
// Check that tiles are correctly aligned
constexpr size_t cache_line_size = 128;
if (load_size % itype_size != 0 || store_size % otype_size != 0 ||
cache_line_size % itype_size != 0 || cache_line_size % otype_size != 0) {
return;
}
/* row_tile_elements */
const size_t tile_size_x = (load_size * THREADS_PER_WARP) / itype_size;
/* col_tile_elements */
const size_t tile_size_y = (store_size * THREADS_PER_WARP) / otype_size;
const size_t num_tiles_x = row_length / tile_size_x;
const size_t num_tiles_y = num_rows / tile_size_y;
valid = (row_length % tile_size_x == 0 && num_rows % tile_size_y == 0);
if (!valid) {
return;
}
// Number of CUDA blocks
num_blocks = num_tiles_x * num_tiles_y;
// Parameters for performance model
constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs
active_sm_count = std::min(DIVUP(num_blocks * n_warps_per_tile, warps_per_sm),
static_cast<size_t>(cuda::sm_count()));
elements_per_load = (std::min(cache_line_size, tile_size_x * itype_size) / itype_size);
elements_per_load_dact = (std::min(cache_line_size, tile_size_x * itype2_size) / itype2_size);
elements_per_store_c = (std::min(cache_line_size, tile_size_x * otype_size) / otype_size);
elements_per_store_t = (std::min(cache_line_size, tile_size_y * otype_size) / otype_size);
}
/* Compare by estimated cost */
bool operator<(const KernelConfig &other) const {
if (this->valid && other.valid) {
// cost ~ (1/elements_per_load
// + 1/elements_per_load_dact
// + 1/elements_per_store_c
// + 1/elements_per_store_t) / active_sms
// Note: Integer arithmetic ensures stable ordering
const auto &l1 = this->elements_per_load;
const auto &la1 = this->elements_per_load_dact;
const auto &sc1 = this->elements_per_store_c;
const auto &st1 = this->elements_per_store_t;
const auto &p1 = this->active_sm_count;
const auto &l2 = other.elements_per_load;
const auto &la2 = other.elements_per_load_dact;
const auto &sc2 = other.elements_per_store_c;
const auto &st2 = other.elements_per_store_t;
const auto &p2 = other.active_sm_count;
const auto scale1 = l1 * sc1 * st1 * p1 * (is_dact ? la1 : 1);
const auto scale2 = l2 * sc2 * st2 * p2 * (is_dact ? la2 : 1);
const auto scale = scale1 * scale2;
const auto cost1 =
(scale / l1 + scale / sc1 + scale / st1 + (is_dact ? (scale / la1) : 0)) / p1;
const auto cost2 =
(scale / l2 + scale / sc2 + scale / st2 + (is_dact ? (scale / la2) : 0)) / p2;
return cost1 < cost2;
} else {
return this->valid && !other.valid;
}
}
};
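// Illustrative sketch (not part of the original file): how KernelConfig is
// meant to be consumed. Candidates for each (load_size, store_size) pair are
// built, invalid ones sort last via operator<, and std::min_element picks the
// cheapest. The 2-byte input / 1-byte output element sizes are hypothetical;
// <vector> and <algorithm> are assumed available, as the callers below
// already use std::vector and std::min_element.
inline KernelConfig pick_cheapest_config_sketch(size_t row_length, size_t num_rows) {
  std::vector<KernelConfig> candidates;
  for (size_t load : {8, 4, 2, 1}) {
    for (size_t store : {8, 4, 2, 1}) {
      candidates.emplace_back(row_length, num_rows, /*itype_size=*/2, /*itype2_size=*/2,
                              /*otype_size=*/1, load, store, /*is_dact=*/false);
    }
  }
  return *std::min_element(candidates.begin(), candidates.end());
}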
template <bool IS_DBIAS, bool IS_FULL_TILE, int nvec_in, int nvec_out, typename OVec, typename CVec,
typename CType>
inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out],
OVec (&out_trans)[nvec_in],
CVec &out_dbias, // NOLINT(*)
typename OVec::type *output_cast_tile,
const size_t current_place, const size_t stride,
const CType scale,
CType &amax, // NOLINT(*)
const int dbias_shfl_src_lane,
const bool valid_store) {
using OType = typename OVec::type;
using OVecC = Vec<OType, nvec_in>;
CVec step_dbias;
if constexpr (IS_DBIAS) {
step_dbias.clear();
}
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
OVecC out_cast;
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
const CType tmp = in[i].data.elt[j];
if constexpr (IS_DBIAS) {
step_dbias.data.elt[j] += tmp; // dbias: thread tile local accumulation
}
out_cast.data.elt[j] = static_cast<OType>(tmp * scale);
out_trans[j].data.elt[i] = static_cast<OType>(tmp * scale); // thread tile transpose
__builtin_assume(amax >= 0);
amax = fmaxf(fabsf(tmp), amax);
}
if (IS_FULL_TILE || valid_store) {
out_cast.store_to(output_cast_tile, current_place + stride * i);
}
  }
if constexpr (IS_DBIAS) {
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
CType elt = step_dbias.data.elt[j];
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp
out_dbias.data.elt[j] += elt;
}
}
}
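// Data-movement picture for the function above (descriptive comment): with
// nvec_in = 2 and nvec_out = 2, each thread owns a 2x2 element patch;
// in[i].data.elt[j] lands in out_trans[j].data.elt[i], so the transpose
// happens entirely in registers and both outputs are written with full-width
// vector stores.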
void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /*cast*/
Tensor *workspace, const int nvec_out) {
const size_t row_length = cast_output.data.shape[1];
const size_t num_rows = cast_output.data.shape[0];
const size_t tile_size_y = (nvec_out * THREADS_PER_WARP);
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y);
workspace->data.shape = {num_rows_partial_dbias, row_length};
workspace->data.dtype = DType::kFloat32;
}
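// Worked example for the shape logic above (hypothetical sizes): with
// nvec_out = 4 and THREADS_PER_WARP = 32, tile_size_y = 4 * 32 = 128, so a
// 2048 x 1024 cast output requests a DIVUP(2048, 128) x 1024 = 16 x 1024
// fp32 workspace of partial dbias rows.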
template <int nvec, typename ComputeType, typename OutputType>
__global__ void __launch_bounds__(reduce_dbias_num_threads)
reduce_dbias_kernel(OutputType *const dbias_output, const ComputeType *const dbias_partial,
const int row_length, const int num_rows) {
using ComputeVec = Vec<ComputeType, nvec>;
using OutputVec = Vec<OutputType, nvec>;
const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_id * nvec >= row_length) {
return;
}
const ComputeType *const thread_in_base = dbias_partial + thread_id * nvec;
OutputType *const thread_out_base = dbias_output + thread_id * nvec;
const int stride_in_vec = row_length / nvec;
ComputeVec ldg_vec;
ComputeVec acc_vec;
acc_vec.clear();
for (int i = 0; i < num_rows; ++i) {
ldg_vec.load_from(thread_in_base, i * stride_in_vec);
#pragma unroll
for (int e = 0; e < nvec; ++e) {
acc_vec.data.elt[e] += ldg_vec.data.elt[e];
}
}
OutputVec stg_vec;
#pragma unroll
for (int e = 0; e < nvec; ++e) {
stg_vec.data.elt[e] = OutputType(acc_vec.data.elt[e]);
}
stg_vec.store_to(thread_out_base, 0);
}
template <typename InputType>
void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_length,
const size_t num_rows, const int nvec_out, cudaStream_t stream) {
constexpr int reduce_dbias_store_bytes = 8; // stg.64
constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(InputType);
NVTE_CHECK(row_length % reduce_dbias_nvec == 0, "Unsupported shape.");
const size_t reduce_dbias_row_length = row_length;
const size_t reduce_dbias_num_rows =
DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
const size_t reduce_dbias_num_blocks =
DIVUP(row_length, reduce_dbias_num_threads * reduce_dbias_nvec);
using DbiasOutputType = fp32;
reduce_dbias_kernel<reduce_dbias_nvec, DbiasOutputType, InputType>
<<<reduce_dbias_num_blocks, reduce_dbias_num_threads, 0, stream>>>(
reinterpret_cast<InputType *>(dbias->data.dptr),
reinterpret_cast<const fp32 *>(workspace.data.dptr), reduce_dbias_row_length,
reduce_dbias_num_rows);
}
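// Worked example of the launch arithmetic above; a sketch only, where the
// 256-thread block size is an assumption standing in for
// reduce_dbias_num_threads (defined earlier in this file).
constexpr size_t reduce_dbias_blocks_sketch(size_t row_length, size_t itype_size) {
  const size_t nvec = 8 / itype_size;  // stg.64 -> 8 bytes stored per thread
  const size_t threads = 256;          // assumed block size
  return (row_length + threads * nvec - 1) / (threads * nvec);  // DIVUP by hand
}
static_assert(reduce_dbias_blocks_sketch(4096, 2) == 4,
              "2-byte partial-dbias outputs, 4096 columns -> 4 blocks");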
template <bool IS_DBIAS, bool IS_DACT, typename ComputeType, typename Param, int nvec_in,
int nvec_out, typename ParamOP, ComputeType (*OP)(ComputeType, const ParamOP &)>
__global__ void __launch_bounds__(cast_transpose_num_threads)
cast_transpose_fused_kernel_notaligned(const Param param, const size_t row_length,
const size_t num_rows, const size_t num_tiles) {
using IType = typename Param::InputType;
using IType2 = typename Param::InputType2;
using OType = typename Param::OutputType;
using CType = typename Param::ComputeType;
using IVec = Vec<IType, nvec_in>;
using IVec2 = Vec<IType2, nvec_in>;
using OVec = Vec<OType, nvec_out>;
using CVec = Vec<CType, nvec_in>;
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x =
(row_length + nvec_in * THREADS_PER_WARP - 1) / (nvec_in * THREADS_PER_WARP);
const size_t tile_id =
blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) {
return;
}
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const size_t tile_offset =
(tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP;
const size_t tile_offset_transp =
(tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP;
const IType *const my_input_tile = param.input + tile_offset;
const IType2 *const my_act_input_tile = param.act_input + tile_offset;
OType *const my_output_c_tile = param.output_c + tile_offset;
OType *const my_output_t_tile = param.output_t + tile_offset_transp;
CType *const my_partial_dbias_tile =
param.workspace + (tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length);
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP;
const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP;
const unsigned int tile_length =
row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_length_rest;
const unsigned int tile_height =
row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_height_rest;
OVec *const my_scratch =
reinterpret_cast<OVec *>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);
CVec *const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][nvec_out];
IVec2 act_in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space[n_iterations][nvec_in];
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
size_t current_row = (tile_id_y * THREADS_PER_WARP + warp_id_in_tile * n_iterations) * nvec_out;
unsigned int my_place =
(my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
CType amax = 0;
const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
CVec partial_dbias;
if constexpr (IS_DBIAS) {
partial_dbias.clear();
}
{
const bool valid_load = my_place < tile_length && warp_id_in_tile * n_iterations < tile_height;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
if (valid_load) {
const size_t ld_offset = current_stride + my_place + stride * i;
in[0][i].load_from(my_input_tile, ld_offset);
if constexpr (IS_DACT) {
act_in[0][i].load_from(my_act_input_tile, ld_offset);
}
} else {
in[0][i].clear();
if constexpr (IS_DACT) {
act_in[0][i].clear();
}
}
}
  }
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
const bool valid_load =
my_place_in < tile_length && warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
if (valid_load) {
const size_t ld_offset = current_stride + my_place_in + stride * (nvec_out + j);
in[current_in][j].load_from(my_input_tile, ld_offset);
if constexpr (IS_DACT) {
act_in[current_in][j].load_from(my_act_input_tile, ld_offset);
}
} else {
in[current_in][j].clear();
if constexpr (IS_DACT) {
act_in[current_in][j].clear();
}
}
        }
      }
CVec after_dact[nvec_out]; // NOLINT(*)
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
for (unsigned int k = 0; k < nvec_in; ++k) {
if constexpr (IS_DACT) {
after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
OP(act_in[current_in ^ 1][j].data.elt[k], {});
} else {
after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]);
}
}
}
const int dbias_shfl_src_lane =
(my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
constexpr bool IS_FULL_TILE = false;
const bool valid_store =
(my_place < tile_length) && (warp_id_in_tile * n_iterations + i < tile_height);
cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>(after_dact, out_space[i], partial_dbias,
my_output_c_tile, current_place, stride, scale,
amax, dbias_shfl_src_lane, valid_store);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += nvec_out * stride;
current_row += nvec_out;
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place =
(my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
const bool valid_store = my_place < tile_height;
if (valid_store) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
if constexpr (IS_DBIAS) {
my_dbias_scratch[threadIdx.x] = partial_dbias;
__syncthreads();
if (warp_id_in_tile == 0) {
#pragma unroll
for (unsigned int i = 1; i < n_warps_per_tile; ++i) {
CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
partial_dbias.data.elt[j] += tmp.data.elt[j];
}
}
if (my_id_in_warp < tile_length) {
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
}
}
}
/* warp tile amax reduce*/
amax = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(amax, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (param.amax != nullptr) {
atomicMaxFloat(param.amax, amax);
}
}
}
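// Pipelining note (descriptive): in[2][nvec_out] and act_in[2][nvec_out]
// double-buffer the global loads. Iteration i consumes buffer current_in ^ 1
// while prefetching iteration i + 1 into buffer current_in, hiding load
// latency behind the cast/transpose arithmetic.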
static const char *ActTypeToString[] = {
"NoAct", // 0
"Sigmoid", // 1
"GeLU", // 2
    "QGeLU",    // 3
    "SiLU",     // 4
    "ReLU",     // 5
"SReLU" // 6
};
template <typename ComputeType, typename ParamOP, ComputeType (*OP)(ComputeType, const ParamOP &)>
int get_dactivation_type() {
if (OP == &sigmoid<ComputeType, ComputeType>) {
return 1;
} else if (OP == &dgelu<ComputeType, ComputeType>) {
return 2;
} else if (OP == &dqgelu<ComputeType, ComputeType>) {
return 3;
} else if (OP == &dsilu<ComputeType, ComputeType>) {
return 4;
} else if (OP == &drelu<ComputeType, ComputeType>) {
return 5;
} else if (OP == &dsrelu<ComputeType, ComputeType>) {
return 6;
} else {
return 0;
}
}
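// Example mapping (illustrative): instantiating with OP = &dgelu<fp32, fp32>
// makes this function return 2, so the RTC label built below carries
// ",dactivationType=GeLU" and each dActivation flavor gets its own cached
// kernel specialization.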
template <bool IS_DBIAS, bool IS_DACT, typename ComputeType, typename ParamOP,
ComputeType (*OP)(ComputeType, const ParamOP &)>
void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor *cast_output,
Tensor *transposed_output, Tensor *dbias, Tensor *workspace,
cudaStream_t stream) {
CheckInputTensor(input, "cast_transpose_fused_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
NVTE_CHECK(input.data.shape == cast_output->data.shape,
"Input and C output must have the same shape.");
const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0];
NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
"C and T outputs need to have the same type.");
NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
"C and T outputs need to share amax tensor.");
NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
"C and T outputs need to share scale tensor.");
if constexpr (IS_DBIAS) {
CheckOutputTensor(*dbias, "dbias");
NVTE_CHECK(dbias->data.dtype == input.data.dtype,
"DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{ row_length },
"Wrong shape of DBias.");
}
if constexpr (IS_DACT) {
CheckInputTensor(act_input, "act_input");
NVTE_CHECK(input.data.dtype == act_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(input.data.shape == act_input.data.shape, "Shapes of both inputs must match.");
}
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
using InputType2 = InputType;
using Param = CTDBiasDActParam<InputType, InputType2, OutputType, ComputeType>;
constexpr int itype_size = sizeof(InputType);
constexpr int itype2_size = sizeof(InputType2);
constexpr int otype_size = sizeof(OutputType);
const bool aligned = (row_length % THREADS_PER_WARP == 0)
&& (num_rows % THREADS_PER_WARP == 0);
const bool jit_compiled = aligned && rtc::is_enabled();
size_t load_size = (IS_DACT ? desired_load_size_dact : desired_load_size);
size_t store_size = (IS_DACT ? desired_store_size_dact : desired_store_size);
size_t num_blocks;
if (jit_compiled) {
// Pick kernel config
std::vector<KernelConfig> kernel_configs;
kernel_configs.reserve(16);
auto add_config = [&](size_t load_size_config, size_t store_size_config) {
kernel_configs.emplace_back(row_length, num_rows,
itype_size, itype2_size, otype_size,
load_size_config, store_size_config,
IS_DACT);
};
add_config(8, 8);
add_config(4, 8); add_config(8, 4);
add_config(4, 4);
add_config(2, 8); add_config(8, 2);
add_config(2, 4); add_config(4, 2);
add_config(2, 2);
add_config(1, 8); add_config(8, 1);
add_config(1, 4); add_config(4, 1);
add_config(1, 2); add_config(2, 1);
add_config(1, 1);
// Select the kernel configuration with the lowest cost
const auto &kernel_config = *std::min_element(kernel_configs.begin(),
kernel_configs.end());
NVTE_CHECK(kernel_config.valid, "invalid kernel config");
load_size = kernel_config.load_size;
store_size = kernel_config.store_size;
num_blocks = kernel_config.num_blocks;
CheckInputTensor(input, "cast_transpose_fused_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
NVTE_CHECK(input.data.shape == cast_output->data.shape,
"Input and C output must have the same shape.");
const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0];
NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
"C and T outputs need to have the same type.");
NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
"C and T outputs need to share amax tensor.");
NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
"C and T outputs need to share scale tensor.");
if constexpr (IS_DBIAS) {
CheckOutputTensor(*dbias, "dbias");
NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{row_length}, "Wrong shape of DBias.");
}
if constexpr (IS_DACT) {
CheckInputTensor(act_input, "act_input");
NVTE_CHECK(input.data.dtype == act_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(input.data.shape == act_input.data.shape, "Shapes of both inputs must match.");
}
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
cast_output->data.dtype, OutputType, using InputType2 = InputType;
using Param = CTDBiasDActParam<InputType, InputType2, OutputType, ComputeType>;
constexpr int itype_size = sizeof(InputType);
constexpr int itype2_size = sizeof(InputType2);
constexpr int otype_size = sizeof(OutputType);
const bool aligned =
(row_length % THREADS_PER_WARP == 0) && (num_rows % THREADS_PER_WARP == 0);
const bool jit_compiled = aligned && rtc::is_enabled();
size_t load_size = (IS_DACT ? desired_load_size_dact : desired_load_size);
size_t store_size = (IS_DACT ? desired_store_size_dact : desired_store_size);
size_t num_blocks;
if (jit_compiled) {
// Pick kernel config
std::vector<KernelConfig> kernel_configs;
kernel_configs.reserve(16);
auto add_config = [&](size_t load_size_config, size_t store_size_config) {
kernel_configs.emplace_back(row_length, num_rows, itype_size, itype2_size, otype_size,
load_size_config, store_size_config, IS_DACT);
};
add_config(8, 8);
add_config(4, 8);
add_config(8, 4);
add_config(4, 4);
add_config(2, 8);
add_config(8, 2);
add_config(2, 4);
add_config(4, 2);
add_config(2, 2);
add_config(1, 8);
add_config(8, 1);
add_config(1, 4);
add_config(4, 1);
add_config(1, 2);
add_config(2, 1);
add_config(1, 1);
// Select the kernel configuration with the lowest cost
const auto &kernel_config =
*std::min_element(kernel_configs.begin(), kernel_configs.end());
NVTE_CHECK(kernel_config.valid, "invalid kernel config");
load_size = kernel_config.load_size;
store_size = kernel_config.store_size;
num_blocks = kernel_config.num_blocks;
}
const size_t nvec_in = load_size / itype_size;
const size_t nvec_out = store_size / otype_size;
const size_t tile_size_x = nvec_in * threads_per_warp;
const size_t tile_size_y = nvec_out * threads_per_warp;
const size_t num_tiles_x = DIVUP(row_length, tile_size_x);
const size_t num_tiles_y = DIVUP(num_rows, tile_size_y);
const size_t num_tiles = num_tiles_x * num_tiles_y;
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
if (!jit_compiled) {
num_blocks = DIVUP(num_tiles * n_warps_per_tile, n_warps_per_block);
} if constexpr (IS_DBIAS) {
if (workspace->data.dptr == nullptr) {
populate_cast_transpose_dbias_workspace_config(*cast_output, workspace, nvec_out);
return;
}
}
size_t VecOutputTypeSize;
switch (nvec_out) {
case 1:
VecOutputTypeSize = sizeof(Vec<OutputType, 1>);
break;
case 2:
VecOutputTypeSize = sizeof(Vec<OutputType, 2>);
break;
case 4:
VecOutputTypeSize = sizeof(Vec<OutputType, 4>);
break;
case 8:
VecOutputTypeSize = sizeof(Vec<OutputType, 8>);
break;
} size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile *
(threads_per_warp + 1) * VecOutputTypeSize;
if constexpr (IS_DBIAS) {
size_t VecComputeTypeSize;
switch (nvec_in) {
case 1:
VecComputeTypeSize = sizeof(Vec<ComputeType, 1>);
break;
case 2:
VecComputeTypeSize = sizeof(Vec<ComputeType, 2>);
break;
case 4:
VecComputeTypeSize = sizeof(Vec<ComputeType, 4>);
break;
case 8:
VecComputeTypeSize = sizeof(Vec<ComputeType, 8>);
break;
}
        const size_t shared_size_dbias = cast_transpose_num_threads * VecComputeTypeSize;
        if (shared_size_transpose < shared_size_dbias) {
          shared_size_transpose = shared_size_dbias;
        }
      }
      Param param;
      param.input = reinterpret_cast<const InputType *>(input.data.dptr);
      param.output_c = reinterpret_cast<OutputType *>(cast_output->data.dptr);
      param.output_t = reinterpret_cast<OutputType *>(transposed_output->data.dptr);
      param.scale_ptr = reinterpret_cast<const ComputeType *>(transposed_output->scale.dptr);
      param.amax = reinterpret_cast<ComputeType *>(transposed_output->amax.dptr);
      param.scale_inv = reinterpret_cast<ComputeType *>(cast_output->scale_inv.dptr);
      if constexpr (IS_DBIAS) {
        param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
      } if constexpr (IS_DACT) {
        param.act_input = reinterpret_cast<const InputType2 *>(act_input.data.dptr);
      }

      // Runtime-compiled tuned kernel
      if (jit_compiled) {
        constexpr const char *itype_name = TypeInfo<InputType>::name;
        constexpr const char *itype2_name = TypeInfo<InputType2>::name;
        constexpr const char *otype_name = TypeInfo<OutputType>::name;
        int dActType = 0;
        if constexpr (IS_DACT) {
          dActType = get_dactivation_type<ComputeType, ParamOP, OP>();
        }
// Compile NVRTC kernel if needed and launch
auto &rtc_manager = rtc::KernelManager::instance();
const std::string kernel_label = concat_strings(
"cast_transpose_fusion"
",itype=",
itype_name, ",itype2=", itype2_name, ",otype=", otype_name,
",load_size=", load_size, ",store_size=", store_size, ",IS_DBIAS=", IS_DBIAS,
",IS_DACT=", IS_DACT, ",dactivationType=", ActTypeToString[dActType]);
if (!rtc_manager.is_compiled(kernel_label)) {
std::string code = string_code_transpose_rtc_cast_transpose_fusion_cu;
code = regex_replace(code, "__ITYPE__", itype_name);
code = regex_replace(code, "__ITYPE2__", itype2_name);
code = regex_replace(code, "__OTYPE__", otype_name);
code = regex_replace(code, "__LOAD_SIZE__", load_size);
code = regex_replace(code, "__STORE_SIZE__", store_size);
code = regex_replace(code, "__WARPS_PER_TILE__", n_warps_per_tile);
code = regex_replace(code, "__BLOCK_SIZE__", cast_transpose_num_threads);
code = regex_replace(code, "__IS_DBIAS__", IS_DBIAS);
code = regex_replace(code, "__IS_DACT__", IS_DACT);
code = regex_replace(code, "__DACTIVATION_TYPE__", dActType);
rtc_manager.compile(
kernel_label, "cast_transpose_fusion_kernel_optimized", code,
"transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu");
}
rtc_manager.set_cache_config(kernel_label, CU_FUNC_CACHE_PREFER_SHARED);
rtc_manager.launch(kernel_label, num_blocks, cast_transpose_num_threads,
shared_size_transpose, stream, param, row_length, num_rows,
num_tiles);
} else { // Statically-compiled general kernel
constexpr size_t load_size = IS_DACT ? desired_load_size_dact : desired_load_size;
constexpr size_t store_size = IS_DACT ? desired_store_size_dact : desired_store_size;
constexpr size_t nvec_in = load_size / itype_size;
constexpr size_t nvec_out = store_size / otype_size;
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
cudaFuncSetAttribute(
cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, ComputeType, Param,
nvec_in, nvec_out, Empty, OP>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, ComputeType, Param, nvec_in,
nvec_out, Empty, OP>
<<<num_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
param, row_length, num_rows, num_tiles);
}
if constexpr (IS_DBIAS) {
reduce_dbias<InputType>(*workspace, dbias, row_length, num_rows, nvec_out, stream);
}); // NOLINT(*)
); // NOLINT(*)
}
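// Workspace protocol for the dbias path above (sketch in pseudocode comments):
//   Tensor ws;                                    // ws.data.dptr == nullptr
//   cast_transpose_fused<true, ...>(in, act, &c, &t, &dbias, &ws, stream);
//   // first call only fills ws.data.shape / ws.data.dtype, then returns
//   <allocate ws.data.dptr>                       // caller-side allocation
//   cast_transpose_fused<true, ...>(in, act, &c, &t, &dbias, &ws, stream);
// The second call launches the kernel and the reduce_dbias epilogue.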
template <int nvec_in, int nvec_out, typename CType, typename IType, typename OType,
typename ParamOP, CType (*OP1)(CType, const ParamOP &),
CType (*OP2)(CType, const ParamOP &)>
__global__ void __launch_bounds__(cast_transpose_num_threads)
dgated_act_cast_transpose_kernel(const IType *const input, const IType *const act_input,
OType *const output_c, OType *const output_t,
const CType *const scale_ptr, CType *const amax,
CType *const scale_inv, const size_t row_length,
const size_t num_rows, const size_t num_tiles) {
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
using CVec = Vec<CType, nvec_in>;
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
const size_t tile_id =
blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) {
return;
}
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType *const my_input_tile =
input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP;
const IType *const my_act_input_tile =
act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP;
const IType *const my_gate_input_tile =
act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP +
row_length;
OType *const my_output_c_tile_0 =
output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP;
OType *const my_output_c_tile_1 =
output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP +
row_length;
OType *const my_output_t_tile_0 =
output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP;
OType *const my_output_t_tile_1 =
output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP +
row_length * num_rows;
OVec *const my_scratch =
reinterpret_cast<OVec *>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);
IVec in[2][nvec_out];
IVec act_in[2][nvec_out];
IVec gate_in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space_0[n_iterations][nvec_in];
OVec out_space_1[n_iterations][nvec_in];
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place =
(my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
const size_t stride2 = 2 * row_length / nvec_in;
size_t current_stride2 = warp_id_in_tile * n_iterations * nvec_out * stride2;
CType max = 0;
const CType scale = scale_ptr != nullptr ? *scale_ptr : 1;
CVec partial_dbias;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
act_in[0][i].load_from(my_act_input_tile, current_stride2 + my_place + stride2 * i);
gate_in[0][i].load_from(my_gate_input_tile, current_stride2 + my_place + stride2 * i);
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride2 + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
act_in[current_in][j].load_from(my_act_input_tile,
current_stride2 + my_place_in + stride2 * (nvec_out + j));
gate_in[current_in][j].load_from(my_gate_input_tile,
current_stride2 + my_place_in + stride2 * (nvec_out + j));
}
}
CVec after_dact[nvec_out]; // NOLINT(*)
CVec after_dgate[nvec_out]; // NOLINT(*)
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
for (unsigned int k = 0; k < nvec_in; ++k) {
after_dact[j].data.elt[k] = OP1(act_in[current_in ^ 1][j].data.elt[k], {}) *
CType(in[current_in ^ 1][j].data.elt[k]) *
CType(gate_in[current_in ^ 1][j].data.elt[k]);
after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
OP2(act_in[current_in ^ 1][j].data.elt[k], {});
}
}
OVec out_trans_0[nvec_in]; // NOLINT(*)
OVec out_trans_1[nvec_in]; // NOLINT(*)
constexpr bool IS_DBIAS = false;
constexpr bool IS_FULL_TILE = true;
constexpr bool valid_store = true;
constexpr int dbias_shfl_src_lane = 0;
cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>(after_dact, out_trans_0, partial_dbias,
my_output_c_tile_0, current_place, stride2,
scale, max, dbias_shfl_src_lane, valid_store);
cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>(after_dgate, out_trans_1, partial_dbias,
my_output_c_tile_1, current_place, stride2,
scale, max, dbias_shfl_src_lane, valid_store);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space_0[i][j].data.vec = out_trans_0[j].data.vec;
out_space_1[i][j].data.vec = out_trans_1[j].data.vec;
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += nvec_out * stride;
current_stride2 += nvec_out * stride2;
}
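  // Transpose phase: each warp stages its registers through shared memory and
  // writes the transposed tiles. The (THREADS_PER_WARP + 1) row pitch of
  // my_scratch is the usual padding trick that keeps the transposed
  // shared-memory accesses from landing on a single bank.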
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP] = out_space_0[j][i];
}
__syncthreads();
my_place =
(my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_0,
current_stride + my_place);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP] = out_space_1[j][i];
}
__syncthreads();
my_place =
(my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_1,
current_stride + my_place);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
  /* warp tile amax reduce */
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (amax != nullptr) {
atomicMaxFloat(amax, max);
}
if (scale_inv != nullptr) {
reciprocal<float>(scale_inv, scale);
}
}
}
template <int nvec_in, int nvec_out, typename CType, typename IType, typename OType,
typename ParamOP, CType (*OP1)(CType, const ParamOP &),
CType (*OP2)(CType, const ParamOP &)>
__global__ void __launch_bounds__(cast_transpose_num_threads)
dgated_act_cast_transpose_kernel_notaligned(const IType *const input,
const IType *const act_input, OType *const output_c,
OType *const output_t, const CType *const scale_ptr,
CType *const amax, CType *const scale_inv,
const size_t row_length, const size_t num_rows,
const size_t num_tiles) {
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
using CVec = Vec<CType, nvec_in>;
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x =
(row_length + nvec_in * THREADS_PER_WARP - 1) / (nvec_in * THREADS_PER_WARP);
const size_t tile_id =
blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType *const my_input_tile =
input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP;
const IType *const my_act_input_tile =
act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP;
const IType *const my_gate_input_tile =
act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP +
row_length;
OType *const my_output_c_tile_0 =
output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP;
OType *const my_output_c_tile_1 =
output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP +
row_length;
OType *const my_output_t_tile_0 =
output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP;
OType *const my_output_t_tile_1 =
output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP +
row_length * num_rows;
const size_t stride = row_length / nvec_in;
const size_t stride2 = 2 * row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP;
const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP;
const unsigned int tile_length =
row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_length_rest;
const unsigned int tile_height =
row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_height_rest;
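  // This tile may sit on the ragged right/bottom edge of the tensor:
  // tile_length and tile_height clamp the tile extents so that every load and
  // store below can be predicated with valid_load / valid_store.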
OVec *const my_scratch =
reinterpret_cast<OVec *>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);
IVec in[2][nvec_out];
IVec act_in[2][nvec_out];
IVec gate_in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space_0[n_iterations][nvec_in];
OVec out_space_1[n_iterations][nvec_in];
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
size_t current_stride2 = warp_id_in_tile * n_iterations * nvec_out * stride2;
unsigned int my_place =
(my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
CType max = 0;
const CType scale = scale_ptr != nullptr ? *scale_ptr : 1;
CVec partial_dbias;
{
const bool valid_load = my_place < tile_length && warp_id_in_tile * n_iterations < tile_height;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
if (valid_load) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
act_in[0][i].load_from(my_act_input_tile, current_stride2 + my_place + stride2 * i);
gate_in[0][i].load_from(my_gate_input_tile, current_stride2 + my_place + stride2 * i);
} else {
in[0][i].clear();
act_in[0][i].clear();
gate_in[0][i].clear();
}
}
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride2 + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
{
const bool valid_load =
my_place_in < tile_length && warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
if (valid_load) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
act_in[current_in][j].load_from(
my_act_input_tile, current_stride2 + my_place_in + stride2 * (nvec_out + j));
gate_in[current_in][j].load_from(
my_gate_input_tile, current_stride2 + my_place_in + stride2 * (nvec_out + j));
} else {
in[current_in][j].clear();
act_in[current_in][j].clear();
gate_in[current_in][j].clear();
}
}
}
}
CVec after_dact[nvec_out]; // NOLINT(*)
CVec after_dgate[nvec_out]; // NOLINT(*)
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
for (unsigned int k = 0; k < nvec_in; ++k) {
after_dact[j].data.elt[k] = OP1(act_in[current_in ^ 1][j].data.elt[k], {}) *
CType(in[current_in ^ 1][j].data.elt[k]) *
CType(gate_in[current_in ^ 1][j].data.elt[k]);
after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
OP2(act_in[current_in ^ 1][j].data.elt[k], {});
}
}
OVec out_trans_0[nvec_in]; // NOLINT(*)
OVec out_trans_1[nvec_in]; // NOLINT(*)
constexpr bool IS_DBIAS = false;
constexpr bool IS_FULL_TILE = false;
constexpr int dbias_shfl_src_lane = 0;
const bool valid_store =
(my_place < tile_length) && (warp_id_in_tile * n_iterations + i < tile_height);
cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>(after_dact, out_trans_0, partial_dbias,
my_output_c_tile_0, current_place, stride2,
scale, max, dbias_shfl_src_lane, valid_store);
cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>(after_dgate, out_trans_1, partial_dbias,
my_output_c_tile_1, current_place, stride2,
scale, max, dbias_shfl_src_lane, valid_store);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space_0[i][j].data.vec = out_trans_0[j].data.vec;
out_space_1[i][j].data.vec = out_trans_1[j].data.vec;
}
    my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
    current_stride += nvec_out * stride;
    current_stride2 += nvec_out * stride2;
  }
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP] = out_space_0[j][i];
}
__syncthreads();
my_place =
(my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
const bool valid_store = my_place < tile_height;
if (valid_store) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_0,
current_stride + my_place);
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP] = out_space_1[j][i];
}
__syncthreads();
my_place =
(my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
const bool valid_store = my_place < tile_height;
if (valid_store) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_1,
current_stride + my_place);
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
  /* warp tile amax reduce */
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (amax != nullptr) {
atomicMaxFloat(amax, max);
}
if (scale_inv != nullptr) {
reciprocal<float>(scale_inv, scale);
}
}
}
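/* Host-side launcher. It dispatches to the fully vectorized kernel when the
   tensor tiles evenly and to the *_notaligned variant otherwise. For example,
   with THREADS_PER_WARP == 32, fp16 input (nvec_in = 4 / sizeof(fp16) = 2) and
   fp8 output (nvec_out = 4 / 1 = 4), full_tile requires row_length % 64 == 0
   and num_rows % 128 == 0. */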
template <typename ComputeType, typename ParamOP, ComputeType (*OP1)(ComputeType, const ParamOP &),
ComputeType (*OP2)(ComputeType, const ParamOP &)>
void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input,
Tensor *cast_output, Tensor *transposed_output,
cudaStream_t stream) {
CheckInputTensor(input, "dgated_act_cast_transpose_input");
CheckInputTensor(gated_act_input, "dgated_act_cast_transpose_gated_act_input");
CheckOutputTensor(*cast_output, "dgated_act_cast_transpose_cast_output");
CheckOutputTensor(*transposed_output, "dgated_act_cast_transpose_transposed_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(gated_act_input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0];
NVTE_CHECK(gated_act_input.data.shape[0] == num_rows, "Wrong dimension of output.");
NVTE_CHECK(gated_act_input.data.shape[1] == row_length * 2, "Wrong dimension of output.");
NVTE_CHECK(cast_output->data.shape[0] == num_rows, "Wrong dimension of output.");
NVTE_CHECK(cast_output->data.shape[1] == row_length * 2, "Wrong dimension of output.");
NVTE_CHECK(transposed_output->data.shape[0] == row_length * 2, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(input.data.dtype == gated_act_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
"C and T outputs need to have the same type.");
NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
"C and T outputs need to share amax tensor.");
NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
"C and T outputs need to share scale tensor.");
NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr,
"C and T outputs need to share scale inverse tensor.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
cast_output->data.dtype, OutputType, using InputType2 = InputType;
/* dact fusion kernel uses more registers */
constexpr int desired_load_size_dact = 4;
constexpr int desired_store_size_dact = 4; constexpr int itype_size = sizeof(InputType);
constexpr int otype_size = sizeof(OutputType);
constexpr int nvec_in = desired_load_size_dact / itype_size;
constexpr int nvec_out = desired_store_size_dact / otype_size;
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t n_tiles =
DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
num_rows % (nvec_out * THREADS_PER_WARP) == 0;
const size_t shmem_size = cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>);
if (full_tile) {
cudaFuncSetAttribute(
dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType,
OutputType, Empty, OP1, OP2>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType, OutputType,
Empty, OP1, OP2>
<<<n_blocks, cast_transpose_num_threads, shmem_size, stream>>>(
reinterpret_cast<const InputType *>(input.data.dptr),
reinterpret_cast<const InputType *>(gated_act_input.data.dptr),
reinterpret_cast<OutputType *>(cast_output->data.dptr),
reinterpret_cast<OutputType *>(transposed_output->data.dptr),
reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
reinterpret_cast<fp32 *>(cast_output->amax.dptr),
reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr), row_length, num_rows,
n_tiles);
} else {
cudaFuncSetAttribute(
dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType,
InputType, OutputType, Empty, OP1, OP2>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType, InputType,
OutputType, Empty, OP1, OP2>
<<<n_blocks, cast_transpose_num_threads, shmem_size, stream>>>(
reinterpret_cast<const InputType *>(input.data.dptr),
reinterpret_cast<const InputType *>(gated_act_input.data.dptr),
reinterpret_cast<OutputType *>(cast_output->data.dptr),
reinterpret_cast<OutputType *>(transposed_output->data.dptr),
reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
reinterpret_cast<fp32 *>(cast_output->amax.dptr),
reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr), row_length, num_rows,
n_tiles);
}); // NOLINT(*)
); // NOLINT(*)
}
} // namespace
} // namespace transformer_engine
using ComputeType = typename transformer_engine::fp32;
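/* All nvte_* entry points below instantiate the fused kernels with fp32 as the
   compute type. */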
void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor cast_output,
NVTETensor transposed_output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = false;
constexpr const NVTETensor activation_input = nullptr;
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, nullptr>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(activation_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
}
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr auto dActivation = &dgelu<fp32, fp32>;
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(act_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
}
void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor silu_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dsilu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr auto dActivation = &dsilu<fp32, fp32>;
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(silu_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
}
void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor relu_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_drelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr auto dActivation = &drelu<fp32, fp32>;
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(relu_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
}
void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor srelu_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dsrelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr auto dActivation = &dsrelu<fp32, fp32>;
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(srelu_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
}
void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor qgelu_input,
NVTETensor cast_output, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_cast_transpose_dbias_dqgelu);
using namespace transformer_engine;
constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;
constexpr auto dActivation = &dqgelu<fp32, fp32>;
cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(qgelu_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
}
void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu_cast_transpose);
using namespace transformer_engine;
constexpr auto dActivation = &dgelu<fp32, fp32>;
constexpr auto Activation = &gelu<fp32, fp32>;
dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
stream);
}
void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dswiglu_cast_transpose);
using namespace transformer_engine;
constexpr auto dActivation = &dsilu<fp32, fp32>;
constexpr auto Activation = &silu<fp32, fp32>;
dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(swiglu_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
stream);
}
void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dreglu_cast_transpose);
using namespace transformer_engine;
constexpr auto dActivation = &drelu<fp32, fp32>;
constexpr auto Activation = &relu<fp32, fp32>;
dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
stream);
}
void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsreglu_cast_transpose);
using namespace transformer_engine;
constexpr auto dActivation = &dsrelu<fp32, fp32>;
constexpr auto Activation = &srelu<fp32, fp32>;
dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
stream);
}
void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
NVTETensor cast_output, NVTETensor transposed_output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgeglu_cast_transpose);
using namespace transformer_engine;
constexpr auto dActivation = &dqgelu<fp32, fp32>;
constexpr auto Activation = &qgelu<fp32, fp32>;
dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
stream);
}
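/* Usage sketch (illustrative only; the tensor handles named here are
   hypothetical): for SwiGLU backward, `grad` holds the incoming gradient of
   shape [N, H], `gated_input` the cached forward input of shape [N, 2H], and
   the outputs are the cast gradient [N, 2H] plus its transpose [2H, N],
   matching the shape checks in dgated_act_cast_transpose:

     nvte_dswiglu_cast_transpose(grad, gated_input, grad_cast, grad_transpose,
                                 stream);
*/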
......@@ -4,13 +4,15 @@
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <transformer_engine/transpose.h>
#include <cfloat>
#include <iostream>
#include <vector>
#include "../utils.cuh"
#include "../common.h"
#include "../utils.cuh"
namespace transformer_engine {
......@@ -40,21 +42,14 @@ struct MultiCastTransposeArgs {
int row_length_list[kMaxTensorsPerKernel];
// Prefix sum (with leading zero) of CUDA blocks needed for each
// tensor
int block_range[kMaxTensorsPerKernel + 1];
// Number of tensors being processed by kernel
int num_tensors;
};
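// The argument struct is passed to the kernel by value, so a single launch
// covers at most kMaxTensorsPerKernel tensors; the host code below flushes a
// full struct with a launch and then refills it.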
template <int nvec_in, int nvec_out, bool aligned, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(threads_per_block)
multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
using IVec = Vec<IType, nvec_in>;
using OVecC = Vec<OType, nvec_in>;
using OVecT = Vec<OType, nvec_out>;
......@@ -79,7 +74,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
// Find tensor corresponding to block
int tensor_id = 0;
while (args.block_range[tensor_id + 1] <= bid) {
++tensor_id;
}
const IType* input = reinterpret_cast<const IType*>(args.input_list[tensor_id]);
......@@ -104,11 +99,11 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
// type, and transposes in registers.
OVecT local_output_t[nvec_in][n_iterations];
CType local_amax = 0;
#pragma unroll
for (int iter = 0; iter < n_iterations; ++iter) {
const int i1 = tidy + iter * bdimy;
const int j1 = tidx;
#pragma unroll
for (int i2 = 0; i2 < nvec_out; ++i2) {
const int row = tile_row + i1 * nvec_out + i2;
const int col = tile_col + j1 * nvec_in;
......@@ -119,7 +114,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
} else {
local_input.clear();
if (row < num_rows) {
#pragma unroll
for (int j2 = 0; j2 < nvec_in; ++j2) {
if (col + j2 < row_length) {
local_input.data.elt[j2] = input[row * row_length + col + j2];
......@@ -127,7 +122,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
}
}
}
#pragma unroll
for (int j2 = 0; j2 < nvec_in; ++j2) {
const CType x = CType(local_input.data.elt[j2]);
const OType y = OType(scale * x);
......@@ -140,7 +135,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
local_output_c.store_to(&output_c[row * row_length + col]);
} else {
if (row < num_rows) {
#pragma unroll
for (int j2 = 0; j2 < nvec_in; ++j2) {
if (col + j2 < row_length) {
output_c[row * row_length + col + j2] = local_output_c.data.elt[j2];
......@@ -152,17 +147,17 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
}
// Copy transposed output from registers to global memory
__shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#pragma unroll
for (int j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll
for (int iter = 0; iter < n_iterations; ++iter) {
const int i1 = tidy + iter * bdimy;
const int j1 = tidx;
shared_output_t[j1][i1] = local_output_t[j2][iter];
}
__syncthreads();
#pragma unroll
for (int iter = 0; iter < n_iterations; ++iter) {
const int i1 = tidx;
const int j1 = tidy + iter * bdimy;
......@@ -172,7 +167,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
shared_output_t[j1][i1].store_to(&output_t[col * num_rows + row]);
} else {
if (col < row_length) {
#pragma unroll
for (int i2 = 0; i2 < nvec_out; ++i2) {
if (row + i2 < num_rows) {
output_t[col * num_rows + row + i2] = shared_output_t[j1][i1].data.elt[i2];
......@@ -196,8 +191,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
void multi_cast_transpose(const std::vector<Tensor*> input_list,
std::vector<Tensor*> cast_output_list,
std::vector<Tensor*> transposed_output_list, cudaStream_t stream) {
// Check that number of tensors is valid
NVTE_CHECK(cast_output_list.size() == input_list.size(),
"Number of input and C output tensors must match");
......@@ -218,15 +212,11 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
CheckInputTensor(cast_output, "multi_cast_output_" + std::to_string(tensor_id));
CheckInputTensor(transposed_output, "multi_transpose_output_" + std::to_string(tensor_id));
NVTE_CHECK(input.data.dtype == itype, "Input tensor types do not match.");
NVTE_CHECK(cast_output.data.dtype == otype, "C output tensor types do not match.");
NVTE_CHECK(transposed_output.data.dtype == otype, "T output tensor types do not match.");
NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions.");
NVTE_CHECK(cast_output.data.shape == input.data.shape,
"C output tensor shape does not match input tensor.");
NVTE_CHECK(transposed_output.data.shape.size() == 2,
......@@ -251,27 +241,28 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
// Launch kernel if argument struct is full
if (kernel_args_aligned.num_tensors == kMaxTensorsPerKernel) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
itype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType);
constexpr int nvec_out = desired_store_size / sizeof(OutputType);
const int n_blocks = kernel_args_aligned.block_range[kernel_args_aligned.num_tensors];
multi_cast_transpose_kernel<nvec_in, nvec_out, true, fp32, InputType, OutputType>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_aligned);); // NOLINT(*)
); // NOLINT(*)
kernel_args_aligned.num_tensors = 0;
}
if (kernel_args_unaligned.num_tensors == kMaxTensorsPerKernel) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
itype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType);
constexpr int nvec_out = desired_store_size / sizeof(OutputType);
const int n_blocks =
kernel_args_unaligned.block_range[kernel_args_unaligned.num_tensors];
multi_cast_transpose_kernel<nvec_in, nvec_out, false, fp32, InputType, OutputType>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_unaligned);); // NOLINT(*)
); // NOLINT(*)
kernel_args_unaligned.num_tensors = 0;
}
......@@ -283,8 +274,8 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
const int num_tiles = num_tiles_m * num_tiles_n;
// Figure out whether to use aligned or unaligned kernel
const bool aligned = ((num_tiles_m * tile_dim_m == num_rows)
&& (num_tiles_n * tile_dim_n == row_length));
const bool aligned =
((num_tiles_m * tile_dim_m == num_rows) && (num_tiles_n * tile_dim_n == row_length));
auto& kernel_args = aligned ? kernel_args_aligned : kernel_args_unaligned;
// Add tensor to kernel argument struct
......@@ -296,53 +287,48 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr;
kernel_args.num_rows_list[pos] = num_rows;
kernel_args.row_length_list[pos] = row_length;
kernel_args.block_range[pos+1] = kernel_args.block_range[pos] + num_tiles;
kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles;
kernel_args.num_tensors++;
}
// Launch kernel
if (kernel_args_aligned.num_tensors > 0) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(itype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(otype, OutputType,
constexpr int nvec_in = desired_load_size / sizeof(InputType);
constexpr int nvec_out = desired_store_size / sizeof(OutputType);
const int n_blocks = kernel_args_aligned.block_range[kernel_args_aligned.num_tensors];
multi_cast_transpose_kernel<nvec_in, nvec_out, true, fp32, InputType, OutputType>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_aligned);
); // NOLINT(*)
); // NOLINT(*)
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
itype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType);
constexpr int nvec_out = desired_store_size / sizeof(OutputType);
const int n_blocks = kernel_args_aligned.block_range[kernel_args_aligned.num_tensors];
multi_cast_transpose_kernel<nvec_in, nvec_out, true, fp32, InputType, OutputType>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_aligned);); // NOLINT(*)
); // NOLINT(*)
}
if (kernel_args_unaligned.num_tensors > 0) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(itype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(otype, OutputType,
constexpr int nvec_in = desired_load_size / sizeof(InputType);
constexpr int nvec_out = desired_store_size / sizeof(OutputType);
const int n_blocks = kernel_args_unaligned.block_range[kernel_args_unaligned.num_tensors];
multi_cast_transpose_kernel<nvec_in, nvec_out, false, fp32, InputType, OutputType>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_unaligned);
); // NOLINT(*)
); // NOLINT(*)
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
itype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType);
constexpr int nvec_out = desired_store_size / sizeof(OutputType);
const int n_blocks =
kernel_args_unaligned.block_range[kernel_args_unaligned.num_tensors];
multi_cast_transpose_kernel<nvec_in, nvec_out, false, fp32, InputType, OutputType>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_unaligned);); // NOLINT(*)
); // NOLINT(*)
}
}
} // namespace transformer_engine
void nvte_multi_cast_transpose(size_t num_tensors,
const NVTETensor* input_list,
NVTETensor* cast_output_list,
NVTETensor* transposed_output_list,
void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list,
NVTETensor* cast_output_list, NVTETensor* transposed_output_list,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_cast_transpose);
using namespace transformer_engine;
std::vector<Tensor*> input_list_,
cast_output_list_, transposed_output_list_;
std::vector<Tensor*> input_list_, cast_output_list_, transposed_output_list_;
for (size_t i = 0; i < num_tensors; ++i) {
input_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(input_list[i])));
cast_output_list_.push_back(reinterpret_cast<Tensor*>(cast_output_list[i]));
transposed_output_list_.push_back(reinterpret_cast<Tensor*>(transposed_output_list[i]));
}
multi_cast_transpose(input_list_,
cast_output_list_,
transposed_output_list_,
stream);
multi_cast_transpose(input_list_, cast_output_list_, transposed_output_list_, stream);
}
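Note: the multi-tensor entry point above batches tensors into a fixed-size argument struct and flushes it with a single kernel launch whenever it fills, then once more at the end for the remainder. A minimal host-side sketch of that pattern; Args, kMaxPerLaunch, and launch_batch are illustrative stand-ins, not the TE types:

// Sketch of the flush-when-full batching used by multi_cast_transpose.
#include <cstddef>
#include <vector>

struct Args {
  static constexpr int kMaxPerLaunch = 64;   // plays the role of kMaxTensorsPerKernel
  int num_tensors = 0;
  const void* input_list[kMaxPerLaunch];
  int block_range[kMaxPerLaunch + 1] = {0};  // running prefix sum of tiles per tensor
};

void launch_batch(const Args& args);  // stands in for the templated kernel launch

void process(const std::vector<const void*>& tensors,
             const std::vector<int>& tiles_per_tensor) {
  Args args;
  for (std::size_t i = 0; i < tensors.size(); ++i) {
    if (args.num_tensors == Args::kMaxPerLaunch) {  // struct full: flush
      launch_batch(args);
      args.num_tensors = 0;  // block_range[0] stays 0, so the prefix sum restarts cleanly
    }
    const int pos = args.num_tensors;
    args.input_list[pos] = tensors[i];
    args.block_range[pos + 1] = args.block_range[pos] + tiles_per_tensor[i];
    ++args.num_tensors;
  }
  if (args.num_tensors > 0) launch_batch(args);  // flush the partial batch
}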
......@@ -21,16 +21,11 @@ constexpr size_t block_size = __BLOCK_SIZE__;
} // namespace
__global__ void
__launch_bounds__(block_size)
cast_transpose_optimized_kernel(const IType * __restrict__ const input,
const CType * __restrict__ const noop,
OType * __restrict__ const output_c,
OType * __restrict__ const output_t,
const CType * __restrict__ const scale_ptr,
CType * __restrict__ const amax_ptr,
const size_t row_length,
const size_t num_rows) {
__global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel(
const IType* __restrict__ const input, const CType* __restrict__ const noop,
OType* __restrict__ const output_c, OType* __restrict__ const output_t,
const CType* __restrict__ const scale_ptr, CType* __restrict__ const amax_ptr,
const size_t row_length, const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes
......@@ -73,18 +68,18 @@ cast_transpose_optimized_kernel(const IType * __restrict__ const input,
// Note: Each thread loads num_iterations subtiles, computes amax,
// casts type, and transposes in registers.
OVecT local_output_t[nvec_in][num_iterations];
#pragma unroll
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx;
#pragma unroll
#pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) {
const size_t row = tile_row + i1 * nvec_out + i2;
const size_t col = tile_col + j1 * nvec_in;
IVec local_input;
OVecC local_output_c;
local_input.load_from(&input[row * row_length + col]);
#pragma unroll
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
const CType in = static_cast<CType>(local_input.data.elt[j2]);
const OType out = OType(in * scale);
......@@ -98,17 +93,17 @@ cast_transpose_optimized_kernel(const IType * __restrict__ const input,
}
// Copy from registers to shared memory to global memory
__shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP+1];
#pragma unroll
__shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx;
shared_output_t[j1][i1] = local_output_t[j2][iter];
}
__syncthreads();
#pragma unroll
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidx;
const size_t j1 = tidy + iter * bdimy;
......
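The subtile loop above does three things per element: promote to the compute type, track the running amax of the unscaled value, and write the scaled cast both in row order (output_c) and transposed in registers (local_output_t). The core cast-and-amax step, reduced to a sketch with stand-in scalar arrays and assuming CType is float as in this file:

// Per-element core of a scaled cast that also tracks amax.
template <typename CType, typename OType, int N>
__device__ void cast_and_track_amax(const CType (&in)[N], OType (&out)[N],
                                    CType scale, CType& amax) {
#pragma unroll
  for (int j = 0; j < N; ++j) {
    out[j] = static_cast<OType>(in[j] * scale);  // quantize with the FP8 scale
    amax = fmaxf(fabsf(in[j]), amax);            // amax is over the unscaled input
  }
}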
......@@ -4,25 +4,25 @@
* See LICENSE for license information.
************************************************************************/
#include "utils.cuh"
#include "util/math.h"
#include "utils.cuh"
using namespace transformer_engine;
namespace {
// Parameters
using CType = float;
using IType = __ITYPE__;
using CType = float;
using IType = __ITYPE__;
using IType2 = __ITYPE2__;
using OType = __OTYPE__;
constexpr size_t LOAD_SIZE = __LOAD_SIZE__;
constexpr size_t STORE_SIZE = __STORE_SIZE__;
constexpr size_t WARPS_PER_TILE = __WARPS_PER_TILE__;
constexpr size_t BLOCK_SIZE = __BLOCK_SIZE__;
constexpr bool IS_DBIAS = __IS_DBIAS__;
constexpr bool IS_DACT = __IS_DACT__;
constexpr size_t DACT_TYPE = __DACTIVATION_TYPE__;
using OType = __OTYPE__;
constexpr size_t LOAD_SIZE = __LOAD_SIZE__;
constexpr size_t STORE_SIZE = __STORE_SIZE__;
constexpr size_t WARPS_PER_TILE = __WARPS_PER_TILE__;
constexpr size_t BLOCK_SIZE = __BLOCK_SIZE__;
constexpr bool IS_DBIAS = __IS_DBIAS__;
constexpr bool IS_DACT = __IS_DACT__;
constexpr size_t DACT_TYPE = __DACTIVATION_TYPE__;
constexpr size_t NVEC_IN = LOAD_SIZE / sizeof(IType);
constexpr size_t NVEC_OUT = STORE_SIZE / sizeof(OType);
......@@ -32,218 +32,209 @@ using IVec2 = Vec<IType2, NVEC_IN>;
using OVec = Vec<OType, NVEC_OUT>;
using Param = CTDBiasDActParam<IType, IType2, OType, CType>;
using OP = CType (*)(const CType, const Empty&);
using OP = CType (*)(const CType, const Empty &);
constexpr OP Activation[] = {
nullptr, // 0
&dsigmoid<CType, CType>, // 1
&dgelu<CType, CType>, // 2
&dqgelu<CType, CType>, // 3
&dsilu<CType, CType>, // 4
&drelu<CType, CType>, // 5
&dsrelu<CType, CType> // 6
nullptr, // 0
&dsigmoid<CType, CType>, // 1
&dgelu<CType, CType>, // 2
&dqgelu<CType, CType>, // 3
&dsilu<CType, CType>, // 4
&drelu<CType, CType>, // 5
&dsrelu<CType, CType> // 6
};
} // namespace
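Since DACT_TYPE is a compile-time constant baked in by the RTC pipeline, indexing this constexpr table selects exactly one derivative at compile time, so the indirect call can be inlined. A sketch of the dispatch, assuming the surrounding definitions (Activation, DACT_TYPE, CType, Empty) are in scope:

// Illustrative only: mirrors how the kernel below applies Activation[DACT_TYPE].
__device__ CType apply_dact(CType grad, CType act_in) {
  if constexpr (DACT_TYPE != 0) {
    // Constexpr table lookup: the compiler sees a fixed function (e.g. dgelu
    // for DACT_TYPE == 2) rather than a runtime function pointer.
    return grad * Activation[DACT_TYPE](act_in, {});
  }
  return grad;  // entry 0 is nullptr; the IS_DACT flag gates this path
}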
inline __device__ void
cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_OUT],
OVec (&out_trans)[NVEC_IN],
CVec &out_dbias, // NOLINT(*)
typename OVec::type *output_cast_tile,
const size_t current_place,
const size_t stride,
const CType scale,
CType &amax, // NOLINT(*)
const int dbias_shfl_src_lane) {
using OVecC = Vec<OType, NVEC_IN>;
CVec step_dbias;
if constexpr (IS_DBIAS) {
step_dbias.clear();
}
#pragma unroll
for (unsigned int i = 0; i < NVEC_OUT; ++i) {
OVecC out_cast;
#pragma unroll
for (unsigned int j = 0; j < NVEC_IN; ++j) {
const CType tmp = in[i].data.elt[j];
if constexpr (IS_DBIAS) {
step_dbias.data.elt[j] += tmp; // dbias: thread tile local accumulation
}
out_cast.data.elt[j] = static_cast<OType>(tmp * scale);
out_trans[j].data.elt[i] = static_cast<OType>(tmp * scale); // thread tile transpose
__builtin_assume(amax >= 0);
amax = fmaxf(fabsf(tmp), amax);
}
out_cast.store_to(output_cast_tile, current_place + stride * i);
inline __device__ void cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_OUT],
OVec (&out_trans)[NVEC_IN],
CVec &out_dbias, // NOLINT(*)
typename OVec::type *output_cast_tile,
const size_t current_place,
const size_t stride, const CType scale,
CType &amax, // NOLINT(*)
const int dbias_shfl_src_lane) {
using OVecC = Vec<OType, NVEC_IN>;
CVec step_dbias;
if constexpr (IS_DBIAS) {
step_dbias.clear();
}
#pragma unroll
for (unsigned int i = 0; i < NVEC_OUT; ++i) {
OVecC out_cast;
#pragma unroll
for (unsigned int j = 0; j < NVEC_IN; ++j) {
const CType tmp = in[i].data.elt[j];
if constexpr (IS_DBIAS) {
step_dbias.data.elt[j] += tmp; // dbias: thread tile local accumulation
}
out_cast.data.elt[j] = static_cast<OType>(tmp * scale);
out_trans[j].data.elt[i] = static_cast<OType>(tmp * scale); // thread tile transpose
__builtin_assume(amax >= 0);
amax = fmaxf(fabsf(tmp), amax);
}
if constexpr (IS_DBIAS) {
#pragma unroll
for (unsigned int j = 0; j < NVEC_IN; ++j) {
CType elt = step_dbias.data.elt[j];
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp
out_dbias.data.elt[j] += elt;
}
out_cast.store_to(output_cast_tile, current_place + stride * i);
}
if constexpr (IS_DBIAS) {
#pragma unroll
for (unsigned int j = 0; j < NVEC_IN; ++j) {
CType elt = step_dbias.data.elt[j];
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp
out_dbias.data.elt[j] += elt;
}
}
}
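The dbias_shfl_src_lane shuffle above redistributes per-thread partial sums so each lane accumulates the column slice matching its store layout. In isolation the exchange is just:

// Standalone form of the __shfl_sync exchange used for dbias: every lane
// reads the partial sum held by src_lane and accumulates it locally.
__device__ float gather_partial(float my_partial, int src_lane) {
  // 0xffffffff: all 32 lanes participate; this is safe here because the
  // surrounding loop bounds are uniform across the warp.
  return __shfl_sync(0xffffffffu, my_partial, src_lane);
}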
__global__ void
__launch_bounds__(BLOCK_SIZE)
cast_transpose_fusion_kernel_optimized(const Param param,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = row_length / (NVEC_IN * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * WARPS_PER_TILE)
+ warp_id / WARPS_PER_TILE;
if (tile_id >= num_tiles) {
return;
__global__ void __launch_bounds__(BLOCK_SIZE)
cast_transpose_fusion_kernel_optimized(const Param param, const size_t row_length,
const size_t num_rows, const size_t num_tiles) {
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = row_length / (NVEC_IN * THREADS_PER_WARP);
const size_t tile_id =
blockIdx.x * blockDim.x / (THREADS_PER_WARP * WARPS_PER_TILE) + warp_id / WARPS_PER_TILE;
if (tile_id >= num_tiles) {
return;
}
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const size_t tile_offset =
(tile_id_x * NVEC_IN + tile_id_y * row_length * NVEC_OUT) * THREADS_PER_WARP;
const size_t tile_offset_transp =
(tile_id_y * NVEC_OUT + tile_id_x * num_rows * NVEC_IN) * THREADS_PER_WARP;
const IType *const my_input_tile = param.input + tile_offset;
const IType2 *const my_act_input_tile = param.act_input + tile_offset;
OType *const my_output_c_tile = param.output_c + tile_offset;
OType *const my_output_t_tile = param.output_t + tile_offset_transp;
CType *const my_partial_dbias_tile =
param.workspace + (tile_id_x * (NVEC_IN * THREADS_PER_WARP) + tile_id_y * row_length);
OVec *const my_scratch =
reinterpret_cast<OVec *>(scratch) +
(my_id_in_warp + warp_id / WARPS_PER_TILE * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);
CVec *const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][NVEC_OUT];
IVec2 act_in[2][NVEC_OUT];
const unsigned int warp_id_in_tile = warp_id % WARPS_PER_TILE;
constexpr unsigned int n_iterations = THREADS_PER_WARP / WARPS_PER_TILE;
OVec out_space[n_iterations][NVEC_IN];
const size_t stride = row_length / NVEC_IN;
const size_t output_stride = num_rows / NVEC_OUT;
size_t current_stride = warp_id_in_tile * n_iterations * NVEC_OUT * stride;
size_t current_row = (tile_id_y * THREADS_PER_WARP + warp_id_in_tile * n_iterations) * NVEC_OUT;
unsigned int my_place =
(my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
CType amax = 0.0f;
const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
CVec partial_dbias;
if constexpr (IS_DBIAS) {
partial_dbias.clear();
}
#pragma unroll
for (unsigned int i = 0; i < NVEC_OUT; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
if constexpr (IS_DACT) {
act_in[0][i].load_from(my_act_input_tile, current_stride + my_place + stride * i);
}
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const size_t tile_offset = (tile_id_x * NVEC_IN + tile_id_y * row_length * NVEC_OUT)
* THREADS_PER_WARP;
const size_t tile_offset_transp = (tile_id_y * NVEC_OUT + tile_id_x * num_rows * NVEC_IN)
* THREADS_PER_WARP;
const IType * const my_input_tile = param.input + tile_offset;
const IType2 * const my_act_input_tile = param.act_input + tile_offset;
OType * const my_output_c_tile = param.output_c + tile_offset;
OType * const my_output_t_tile = param.output_t + tile_offset_transp;
CType * const my_partial_dbias_tile = param.workspace
+ (tile_id_x * (NVEC_IN * THREADS_PER_WARP)
+ tile_id_y * row_length);
OVec * const my_scratch = reinterpret_cast<OVec *>(scratch)
+ (my_id_in_warp + warp_id / WARPS_PER_TILE * THREADS_PER_WARP)
* (THREADS_PER_WARP + 1);
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][NVEC_OUT];
IVec2 act_in[2][NVEC_OUT];
const unsigned int warp_id_in_tile = warp_id % WARPS_PER_TILE;
constexpr unsigned int n_iterations = THREADS_PER_WARP / WARPS_PER_TILE;
OVec out_space[n_iterations][NVEC_IN];
const size_t stride = row_length / NVEC_IN;
const size_t output_stride = num_rows / NVEC_OUT;
size_t current_stride = warp_id_in_tile * n_iterations * NVEC_OUT * stride;
size_t current_row = (tile_id_y * THREADS_PER_WARP + warp_id_in_tile * n_iterations) * NVEC_OUT;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations)
% THREADS_PER_WARP;
CType amax = 0.0f;
const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
CVec partial_dbias;
if constexpr (IS_DBIAS) {
partial_dbias.clear();
}
#pragma unroll
for (unsigned int i = 0; i < NVEC_OUT; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
#pragma unroll
for (unsigned int j = 0; j < NVEC_OUT; ++j) {
const size_t ld_offset = current_stride + my_place_in + stride * (NVEC_OUT + j);
in[current_in][j].load_from(my_input_tile, ld_offset);
if constexpr (IS_DACT) {
act_in[0][i].load_from(my_act_input_tile, current_stride + my_place + stride * i);
act_in[current_in][j].load_from(my_act_input_tile, ld_offset);
}
}
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
#pragma unroll
for (unsigned int j = 0; j < NVEC_OUT; ++j) {
const size_t ld_offset = current_stride + my_place_in + stride * (NVEC_OUT + j);
in[current_in][j].load_from(my_input_tile, ld_offset);
if constexpr (IS_DACT) {
act_in[current_in][j].load_from(my_act_input_tile, ld_offset);
}
}
}
CVec in_cast_fp32[NVEC_OUT]; // NOLINT(*)
#pragma unroll
for (unsigned int j = 0; j < NVEC_OUT; ++j) {
#pragma unroll
for (unsigned int k = 0; k < NVEC_IN; ++k) {
if constexpr (IS_DACT) {
in_cast_fp32[j].data.elt[k] =
static_cast<CType>(in[current_in ^ 1][j].data.elt[k])
* Activation[DACT_TYPE](act_in[current_in ^ 1][j].data.elt[k], {});
} else {
in_cast_fp32[j].data.elt[k] =
static_cast<CType>(in[current_in ^ 1][j].data.elt[k]);
}
}
CVec in_cast_fp32[NVEC_OUT]; // NOLINT(*)
#pragma unroll
for (unsigned int j = 0; j < NVEC_OUT; ++j) {
#pragma unroll
for (unsigned int k = 0; k < NVEC_IN; ++k) {
if constexpr (IS_DACT) {
in_cast_fp32[j].data.elt[k] =
static_cast<CType>(in[current_in ^ 1][j].data.elt[k]) *
Activation[DACT_TYPE](act_in[current_in ^ 1][j].data.elt[k], {});
} else {
in_cast_fp32[j].data.elt[k] = static_cast<CType>(in[current_in ^ 1][j].data.elt[k]);
}
}
}
const int dbias_shfl_src_lane = (my_id_in_warp + i + warp_id_in_tile * n_iterations)
% THREADS_PER_WARP;
const int dbias_shfl_src_lane =
(my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
cast_and_transpose_regs_optimized(in_cast_fp32, out_space[i], partial_dbias,
my_output_c_tile, current_place,
stride, scale, amax, dbias_shfl_src_lane);
cast_and_transpose_regs_optimized(in_cast_fp32, out_space[i], partial_dbias, my_output_c_tile,
current_place, stride, scale, amax, dbias_shfl_src_lane);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += NVEC_OUT * stride;
current_row += NVEC_OUT;
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += NVEC_OUT * stride;
current_row += NVEC_OUT;
}
#pragma unroll
for (unsigned int i = 0; i < NVEC_IN; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP
- j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations)
% THREADS_PER_WARP;
current_stride = i * output_stride
+ warp_id_in_tile * n_iterations * output_stride * NVEC_IN;
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * NVEC_IN;
}
__syncthreads();
#pragma unroll
for (unsigned int i = 0; i < NVEC_IN; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP] = out_space[j][i];
}
if constexpr (IS_DBIAS) {
my_dbias_scratch[threadIdx.x] = partial_dbias;
__syncthreads();
if (warp_id_in_tile == 0) {
#pragma unroll
for (unsigned int i = 1; i < WARPS_PER_TILE; ++i) {
CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
#pragma unroll
for (unsigned int j = 0; j < NVEC_IN; ++j) {
partial_dbias.data.elt[j] += tmp.data.elt[j];
}
}
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
__syncthreads();
my_place =
(my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * NVEC_IN;
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * NVEC_IN;
}
__syncthreads();
}
if constexpr (IS_DBIAS) {
my_dbias_scratch[threadIdx.x] = partial_dbias;
__syncthreads();
if (warp_id_in_tile == 0) {
#pragma unroll
for (unsigned int i = 1; i < WARPS_PER_TILE; ++i) {
CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
#pragma unroll
for (unsigned int j = 0; j < NVEC_IN; ++j) {
partial_dbias.data.elt[j] += tmp.data.elt[j];
}
}
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
}
}
// warp tile amax reduce
const CType max_block = reduce_max<BLOCK_SIZE/THREADS_PER_WARP>(amax, warp_id);
// warp tile amax reduce
const CType max_block = reduce_max<BLOCK_SIZE / THREADS_PER_WARP>(amax, warp_id);
if (threadIdx.x == 0) {
if (param.amax != nullptr) {
atomicMaxFloat(param.amax, max_block);
}
if (threadIdx.x == 0) {
if (param.amax != nullptr) {
atomicMaxFloat(param.amax, max_block);
}
}
}
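The epilogue reduces amax within the block and then lets thread 0 merge it into the global value with atomicMaxFloat. CUDA has no native float atomicMax; one common trick (a sketch only, TE's utils.cuh provides its own helper) exploits that non-negative IEEE-754 floats order like integers, which is valid here because amax >= 0 by construction:

// Assumption: *addr and value are non-negative, as amax always is.
__device__ float atomic_max_nonneg_float(float* addr, float value) {
  int old = atomicMax(reinterpret_cast<int*>(addr), __float_as_int(value));
  return __int_as_float(old);
}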
......@@ -19,13 +19,10 @@ constexpr size_t block_size = __BLOCK_SIZE__;
} // namespace
__global__ void
__launch_bounds__(block_size)
transpose_optimized_kernel(const Type * __restrict__ const input,
const float * const noop,
Type * __restrict__ const output,
const size_t row_length,
const size_t num_rows) {
__global__ void __launch_bounds__(block_size)
transpose_optimized_kernel(const Type* __restrict__ const input, const float* const noop,
Type* __restrict__ const output, const size_t row_length,
const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes
......@@ -63,17 +60,17 @@ transpose_optimized_kernel(const Type * __restrict__ const input,
// Note: Each thread loads num_iterations subtiles and transposes in
// registers.
OVec local_output[nvec_in][num_iterations];
#pragma unroll
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx;
#pragma unroll
#pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) {
const size_t row = tile_row + i1 * nvec_out + i2;
const size_t col = tile_col + j1 * nvec_in;
IVec local_input;
local_input.load_from(&input[row * row_length + col]);
#pragma unroll
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
local_output[j2][iter].data.elt[i2] = local_input.data.elt[j2];
}
......@@ -81,17 +78,17 @@ transpose_optimized_kernel(const Type * __restrict__ const input,
}
// Copy from registers to shared memory to global memory
__shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP+1];
#pragma unroll
__shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx;
shared_output[j1][i1] = local_output[j2][iter];
}
__syncthreads();
#pragma unroll
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidx;
const size_t j1 = tidy + iter * bdimy;
......
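The shared tile in these kernels is deliberately THREADS_PER_WARP x (THREADS_PER_WARP + 1): the extra padding column skews consecutive rows across different shared-memory banks, so the column-wise reads after the transpose are conflict-free. The same idea in a minimal standalone form (warp width of 32 assumed, launched with a 32x32 thread block):

template <typename T>
__device__ void tile_transpose(const T* in, T* out, int ld_in, int ld_out) {
  __shared__ T tile[32][32 + 1];  // +1 pad: column reads hit 32 distinct banks
  tile[threadIdx.y][threadIdx.x] = in[threadIdx.y * ld_in + threadIdx.x];
  __syncthreads();
  out[threadIdx.y * ld_out + threadIdx.x] = tile[threadIdx.x][threadIdx.y];
}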
......@@ -4,13 +4,12 @@
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/transpose.h>
#include <algorithm>
#include <cuda_runtime.h>
#include "../common.h"
#include "../util/rtc.h"
#include "../util/string.h"
......@@ -46,24 +45,18 @@ struct KernelConfig {
/* Elements per L1 cache store */
size_t elements_per_store = 0;
KernelConfig(size_t row_length,
size_t num_rows,
size_t type_size,
size_t load_size_,
KernelConfig(size_t row_length, size_t num_rows, size_t type_size, size_t load_size_,
size_t store_size_)
: load_size{load_size_}
, store_size{store_size_} {
: load_size{load_size_}, store_size{store_size_} {
// Check that tiles are correctly aligned
constexpr size_t cache_line_size = 128;
if (load_size % type_size != 0
|| store_size % type_size != 0
|| cache_line_size % type_size != 0) {
if (load_size % type_size != 0 || store_size % type_size != 0 ||
cache_line_size % type_size != 0) {
return;
}
const size_t row_tile_elements = load_size * THREADS_PER_WARP / type_size;
const size_t col_tile_elements = store_size * THREADS_PER_WARP / type_size;
valid = (row_length % row_tile_elements == 0
&& num_rows % col_tile_elements == 0);
valid = (row_length % row_tile_elements == 0 && num_rows % col_tile_elements == 0);
if (!valid) {
return;
}
......@@ -75,10 +68,8 @@ struct KernelConfig {
constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs
active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm),
static_cast<size_t>(cuda::sm_count()));
elements_per_load = (std::min(cache_line_size, row_tile_elements * type_size)
/ type_size);
elements_per_store = (std::min(cache_line_size, col_tile_elements * type_size)
/ type_size);
elements_per_load = (std::min(cache_line_size, row_tile_elements * type_size) / type_size);
elements_per_store = (std::min(cache_line_size, col_tile_elements * type_size) / type_size);
}
/* Compare by estimated cost */
......@@ -93,8 +84,8 @@ struct KernelConfig {
const auto &s2 = other.elements_per_store;
const auto &p2 = other.active_sm_count;
const auto scale = l1 * s1 * p1 * l2 * s2 * p2;
const auto cost1 = (scale/l1 + scale/s1) / p1;
const auto cost2 = (scale/l2 + scale/s2) / p2;
const auto cost1 = (scale / l1 + scale / s1) / p1;
const auto cost2 = (scale / l2 + scale / s2) / p2;
return cost1 < cost2;
} else {
return this->valid && !other.valid;
......@@ -103,13 +94,10 @@ struct KernelConfig {
};
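operator< above ranks configs by estimated cost, roughly (1/elements_per_load + 1/elements_per_store) / active_sm_count, computed with a common scale factor to stay in integer arithmetic; invalid configs always lose. The selection step that consumes it is then a plain min_element over the candidates, sketched here assuming KernelConfig as defined above:

#include <algorithm>
#include <vector>

KernelConfig pick_config(size_t row_length, size_t num_rows, size_t type_size) {
  std::vector<KernelConfig> configs;
  for (size_t load : {8, 4, 2, 1})        // same 16 candidates as add_config below
    for (size_t store : {8, 4, 2, 1})
      configs.emplace_back(row_length, num_rows, type_size, load, store);
  // operator< prefers wider loads/stores and higher SM occupancy.
  return *std::min_element(configs.begin(), configs.end());
}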
template <size_t load_size, size_t store_size, typename Type>
__global__ void
__launch_bounds__(block_size)
transpose_general_kernel(const Type * __restrict__ const input,
const fp32 * const noop,
Type * __restrict__ const output,
const size_t row_length,
const size_t num_rows) {
__global__ void __launch_bounds__(block_size)
transpose_general_kernel(const Type *__restrict__ const input, const fp32 *const noop,
Type *__restrict__ const output, const size_t row_length,
const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return;
// Vectorized load/store sizes
......@@ -147,25 +135,25 @@ transpose_general_kernel(const Type * __restrict__ const input,
// Note: Each thread loads num_iterations subtiles and transposes in
// registers.
OVec local_output[nvec_in][num_iterations];
#pragma unroll
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx;
#pragma unroll
#pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) {
const size_t row = tile_row + i1 * nvec_out + i2;
const size_t col = tile_col + j1 * nvec_in;
IVec local_input;
local_input.clear();
if (row < num_rows) {
#pragma unroll
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
if (col + j2 < row_length) {
local_input.data.elt[j2] = input[row * row_length + col + j2];
}
}
}
#pragma unroll
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
local_output[j2][iter].data.elt[i2] = local_input.data.elt[j2];
}
......@@ -173,24 +161,24 @@ transpose_general_kernel(const Type * __restrict__ const input,
}
// Copy transposed output from registers to global memory
__shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP+1];
#pragma unroll
__shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx;
shared_output[j1][i1] = local_output[j2][iter];
}
__syncthreads();
#pragma unroll
#pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidx;
const size_t j1 = tidy + iter * bdimy;
const size_t row = tile_row + i1 * nvec_out;
const size_t col = tile_col + j1 * nvec_in + j2;
if (col < row_length) {
#pragma unroll
#pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) {
if (row + i2 < num_rows) {
output[col * num_rows + row + i2] = shared_output[j1][i1].data.elt[i2];
......@@ -204,10 +192,7 @@ transpose_general_kernel(const Type * __restrict__ const input,
} // namespace
void transpose(const Tensor &input,
const Tensor &noop,
Tensor *output_,
cudaStream_t stream) {
void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream) {
Tensor &output = *output_;
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output.data.shape.size() == 2, "Output must have 2 dimensions.");
......@@ -219,121 +204,106 @@ void transpose(const Tensor &input,
NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated.");
NVTE_CHECK(input.data.dtype == output.data.dtype,
"Input and output type must match.");
NVTE_CHECK(input.data.dtype == output.data.dtype, "Input and output type must match.");
// Number of elements in tensor
auto numel = [] (const Tensor &tensor) -> size_t {
auto numel = [](const Tensor &tensor) -> size_t {
size_t acc = 1;
for (const auto& dim : tensor.data.shape) {
for (const auto &dim : tensor.data.shape) {
acc *= dim;
}
return acc;
};
if (noop.data.dptr != nullptr) {
NVTE_CHECK(numel(noop) == 1,
"Expected 1 element, ",
"but found ", numel(noop), ".");
NVTE_CHECK(numel(noop) == 1, "Expected 1 element, ", "but found ", numel(noop), ".");
NVTE_CHECK(noop.data.dtype == DType::kFloat32);
NVTE_CHECK(noop.data.dptr != nullptr);
}
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(input.data.dtype, Type,
constexpr const char *type_name = TypeInfo<Type>::name;
constexpr size_t type_size = sizeof(Type);
// Choose between runtime-compiled or statically-compiled kernel
const bool aligned = (row_length % THREADS_PER_WARP == 0
&& num_rows % THREADS_PER_WARP == 0);
if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel
// Pick kernel config
std::vector<KernelConfig> kernel_configs;
kernel_configs.reserve(16);
auto add_config = [&](size_t load_size, size_t store_size) {
kernel_configs.emplace_back(row_length, num_rows, type_size,
load_size, store_size);
};
add_config(8, 8);
add_config(4, 8); add_config(8, 4);
add_config(4, 4);
add_config(2, 8); add_config(8, 2);
add_config(2, 4); add_config(4, 2);
add_config(2, 2);
add_config(1, 8); add_config(8, 1);
add_config(1, 4); add_config(4, 1);
add_config(1, 2); add_config(2, 1);
add_config(1, 1);
const auto &kernel_config = *std::min_element(kernel_configs.begin(),
kernel_configs.end());
NVTE_CHECK(kernel_config.valid, "invalid kernel config");
const size_t load_size = kernel_config.load_size;
const size_t store_size = kernel_config.store_size;
const size_t num_blocks = kernel_config.num_blocks;
// Compile NVRTC kernel if needed and launch
auto& rtc_manager = rtc::KernelManager::instance();
const std::string kernel_label = concat_strings("transpose"
",type=", type_name,
",load_size=", load_size,
",store_size=", store_size);
if (!rtc_manager.is_compiled(kernel_label)) {
std::string code = string_code_transpose_rtc_transpose_cu;
code = regex_replace(code, "__TYPE__", type_name);
code = regex_replace(code, "__LOAD_SIZE__", load_size);
code = regex_replace(code, "__STORE_SIZE__", store_size);
code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile);
code = regex_replace(code, "__BLOCK_SIZE__", block_size);
rtc_manager.compile(kernel_label,
"transpose_optimized_kernel",
code,
"transformer_engine/common/transpose/rtc/transpose.cu");
}
rtc_manager.launch(kernel_label,
num_blocks, block_size, 0, stream,
static_cast<const Type *>(input.data.dptr),
static_cast<const fp32 *>(noop.data.dptr),
static_cast<Type*>(output.data.dptr),
row_length, num_rows);
} else { // Statically-compiled general kernel
constexpr size_t load_size = 4;
constexpr size_t store_size = 4;
constexpr size_t row_tile_size = load_size / type_size * THREADS_PER_WARP;
constexpr size_t col_tile_size = store_size / type_size * THREADS_PER_WARP;
const int num_blocks = (DIVUP(row_length, row_tile_size)
* DIVUP(num_rows, col_tile_size));
transpose_general_kernel<load_size, store_size, Type><<<num_blocks, block_size, 0, stream>>>(
static_cast<const Type *>(input.data.dptr),
static_cast<const fp32 *>(noop.data.dptr),
static_cast<Type *>(output.data.dptr),
row_length, num_rows);
}
); // NOLINT(*)
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
input.data.dtype, Type, constexpr const char *type_name = TypeInfo<Type>::name;
constexpr size_t type_size = sizeof(Type);
// Choose between runtime-compiled or statically-compiled kernel
const bool aligned = (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0);
if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel
// Pick kernel config
std::vector<KernelConfig> kernel_configs;
kernel_configs.reserve(16);
auto add_config = [&](size_t load_size, size_t store_size) {
kernel_configs.emplace_back(row_length, num_rows, type_size, load_size, store_size);
};
add_config(8, 8);
add_config(4, 8);
add_config(8, 4);
add_config(4, 4);
add_config(2, 8);
add_config(8, 2);
add_config(2, 4);
add_config(4, 2);
add_config(2, 2);
add_config(1, 8);
add_config(8, 1);
add_config(1, 4);
add_config(4, 1);
add_config(1, 2);
add_config(2, 1);
add_config(1, 1);
const auto &kernel_config = *std::min_element(kernel_configs.begin(), kernel_configs.end());
NVTE_CHECK(kernel_config.valid, "invalid kernel config");
const size_t load_size = kernel_config.load_size;
const size_t store_size = kernel_config.store_size;
const size_t num_blocks = kernel_config.num_blocks;
// Compile NVRTC kernel if needed and launch
auto &rtc_manager = rtc::KernelManager::instance();
const std::string kernel_label = concat_strings(
"transpose"
",type=",
type_name, ",load_size=", load_size, ",store_size=", store_size);
if (!rtc_manager.is_compiled(kernel_label)) {
std::string code = string_code_transpose_rtc_transpose_cu;
code = regex_replace(code, "__TYPE__", type_name);
code = regex_replace(code, "__LOAD_SIZE__", load_size);
code = regex_replace(code, "__STORE_SIZE__", store_size);
code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile);
code = regex_replace(code, "__BLOCK_SIZE__", block_size);
rtc_manager.compile(kernel_label, "transpose_optimized_kernel", code,
"transformer_engine/common/transpose/rtc/transpose.cu");
}
rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream,
static_cast<const Type *>(input.data.dptr),
static_cast<const fp32 *>(noop.data.dptr),
static_cast<Type *>(output.data.dptr), row_length, num_rows);
} else { // Statically-compiled general kernel
constexpr size_t load_size = 4;
constexpr size_t store_size = 4;
constexpr size_t row_tile_size = load_size / type_size * THREADS_PER_WARP;
constexpr size_t col_tile_size = store_size / type_size * THREADS_PER_WARP;
const int num_blocks = (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size));
transpose_general_kernel<load_size, store_size, Type>
<<<num_blocks, block_size, 0, stream>>>(static_cast<const Type *>(input.data.dptr),
static_cast<const fp32 *>(noop.data.dptr),
static_cast<Type *>(output.data.dptr),
row_length, num_rows);
}); // NOLINT(*)
}
} // namespace transformer_engine
void nvte_transpose(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
void nvte_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose);
using namespace transformer_engine;
auto noop = Tensor();
transpose(*reinterpret_cast<const Tensor*>(input),
noop,
reinterpret_cast<Tensor*>(output),
transpose(*reinterpret_cast<const Tensor *>(input), noop, reinterpret_cast<Tensor *>(output),
stream);
}
void nvte_transpose_with_noop(const NVTETensor input,
const NVTETensor noop,
NVTETensor output,
void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose_with_noop);
using namespace transformer_engine;
transpose(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(noop),
reinterpret_cast<Tensor*>(output),
stream);
transpose(*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(noop),
reinterpret_cast<Tensor *>(output), stream);
}
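For the RTC path, transpose() stamps concrete values into the kernel-source template before compiling: each __PLACEHOLDER__ token is replaced by the chosen type name or tile parameter. A minimal sketch of that substitution using std::regex_replace; TE's regex_replace wrapper in util/string.h also stringifies non-string values, which this sketch does by hand:

#include <regex>
#include <string>

std::string instantiate(std::string code, const std::string& type_name,
                        size_t load_size, size_t store_size) {
  code = std::regex_replace(code, std::regex("__TYPE__"), type_name);
  code = std::regex_replace(code, std::regex("__LOAD_SIZE__"),
                            std::to_string(load_size));
  code = std::regex_replace(code, std::regex("__STORE_SIZE__"),
                            std::to_string(store_size));
  return code;  // ready to hand to the NVRTC kernel manager
}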
......@@ -4,18 +4,19 @@
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/transpose.h>
#include <cuda_runtime.h>
#include <transformer_engine/transpose.h>
#include <cfloat>
#include <iostream>
#include <type_traits>
#include "../utils.cuh"
#include "../common.h"
#include "../utils.cuh"
namespace transformer_engine {
template <int nvec_in, int nvec_out,
typename IVec, typename OVec, typename CVec, typename CType>
template <int nvec_in, int nvec_out, typename IVec, typename OVec, typename CVec, typename CType>
inline __device__ void transpose_regs_partial_dbias(const IVec (&in)[nvec_out],
OVec (&out_trans)[nvec_in],
CVec &out_dbias, // NOLINT(*)
......@@ -24,7 +25,8 @@ inline __device__ void transpose_regs_partial_dbias(const IVec (&in)[nvec_out],
using T = typename OVec::type;
using OVecC = Vec<T, nvec_in>;
CVec step_dbias; step_dbias.clear();
CVec step_dbias;
step_dbias.clear();
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
......@@ -61,24 +63,21 @@ namespace {
template <typename IType, typename OType, typename CType>
struct TDBiasParam {
using InputType = IType;
using OutputType = OType;
using ComputeType = CType;
const IType *input;
OType *output_t;
const CType *scale_inv;
CType *workspace;
using InputType = IType;
using OutputType = OType;
using ComputeType = CType;
const IType *input;
OType *output_t;
const CType *scale_inv;
CType *workspace;
};
} // namespace
template <int nvec_in, int nvec_out, typename Param>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
transpose_dbias_kernel(const Param param,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
__global__ void __launch_bounds__(cast_transpose_num_threads)
transpose_dbias_kernel(const Param param, const size_t row_length, const size_t num_rows,
const size_t num_tiles) {
using IType = typename Param::InputType;
using OType = typename Param::OutputType;
using CType = typename Param::ComputeType;
......@@ -92,27 +91,24 @@ transpose_dbias_kernel(const Param param,
const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
// const size_t num_tiles_y = num_rows / (nvec * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
const size_t tile_id =
blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = param.input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_t_tile = param.output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
CType * const my_partial_dbias_tile = param.workspace +
(tile_id_x * (nvec_in * THREADS_PER_WARP) +
tile_id_y * row_length);
const IType *const my_input_tile =
param.input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP;
OType *const my_output_t_tile =
param.output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP;
CType *const my_partial_dbias_tile =
param.workspace + (tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length);
OVec * const my_scratch = reinterpret_cast<OVec *>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
(THREADS_PER_WARP + 1);
OVec *const my_scratch =
reinterpret_cast<OVec *>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
CVec *const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
......@@ -123,9 +119,8 @@ transpose_dbias_kernel(const Param param,
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
unsigned int my_place =
(my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
const CType scale_inv = param.scale_inv != nullptr ? *param.scale_inv : 1;
partial_dbias.clear();
......@@ -147,11 +142,8 @@ transpose_dbias_kernel(const Param param,
}
OVec out_trans[nvec_in]; // NOLINT(*)
transpose_regs_partial_dbias(
in[current_in ^ 1],
out_trans,
partial_dbias,
scale_inv,
(my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP);
in[current_in ^ 1], out_trans, partial_dbias, scale_inv,
(my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
......@@ -164,14 +156,13 @@ transpose_dbias_kernel(const Param param,
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
my_place =
(my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
......@@ -199,12 +190,9 @@ transpose_dbias_kernel(const Param param,
}
template <int nvec_in, int nvec_out, typename Param>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
transpose_dbias_kernel_notaligned(const Param param,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
__global__ void __launch_bounds__(cast_transpose_num_threads)
transpose_dbias_kernel_notaligned(const Param param, const size_t row_length,
const size_t num_rows, const size_t num_tiles) {
using IType = typename Param::InputType;
using OType = typename Param::OutputType;
using CType = typename Param::ComputeType;
......@@ -216,38 +204,35 @@ transpose_dbias_kernel_notaligned(const Param param,
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) /
(nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
const size_t num_tiles_x =
(row_length + nvec_in * THREADS_PER_WARP - 1) / (nvec_in * THREADS_PER_WARP);
const size_t tile_id =
blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = param.input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_t_tile = param.output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
CType * const my_partial_dbias_tile = param.workspace +
(tile_id_x * (nvec_in * THREADS_PER_WARP) +
tile_id_y * row_length);
const IType *const my_input_tile =
param.input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP;
OType *const my_output_t_tile =
param.output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP;
CType *const my_partial_dbias_tile =
param.workspace + (tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length);
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP;
const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP;
const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP
: row_length_rest;
const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP
: row_height_rest;
const unsigned int tile_length =
row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_length_rest;
const unsigned int tile_height =
row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_height_rest;
OVec * const my_scratch = reinterpret_cast<OVec *>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
(THREADS_PER_WARP + 1);
OVec *const my_scratch =
reinterpret_cast<OVec *>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
CVec *const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
......@@ -256,16 +241,14 @@ transpose_dbias_kernel_notaligned(const Param param,
CVec partial_dbias;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
unsigned int my_place =
(my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
const CType scale_inv = param.scale_inv != nullptr ? *param.scale_inv : 1;
partial_dbias.clear();
{
const bool valid_load = my_place < tile_length &&
warp_id_in_tile * n_iterations < tile_height;
const bool valid_load = my_place < tile_length && warp_id_in_tile * n_iterations < tile_height;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
if (valid_load) {
......@@ -280,8 +263,8 @@ transpose_dbias_kernel_notaligned(const Param param,
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
const bool valid_load = my_place_in < tile_length &&
warp_id_in_tile * n_iterations + i + 1 < tile_height;
const bool valid_load =
my_place_in < tile_length && warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
if (valid_load) {
......@@ -294,11 +277,8 @@ transpose_dbias_kernel_notaligned(const Param param,
}
OVec out_trans[nvec_in]; // NOLINT(*)
transpose_regs_partial_dbias(
in[current_in ^ 1],
out_trans,
partial_dbias,
scale_inv,
(my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP);
in[current_in ^ 1], out_trans, partial_dbias, scale_inv,
(my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
......@@ -311,14 +291,13 @@ transpose_dbias_kernel_notaligned(const Param param,
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
my_place =
(my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
const bool valid_store = my_place < tile_height;
if (valid_store) {
......@@ -352,27 +331,25 @@ transpose_dbias_kernel_notaligned(const Param param,
constexpr size_t reduce_dbias_num_threads = 256;
template<int nvec, typename ComputeType, typename OutputType>
__global__ void
__launch_bounds__(reduce_dbias_num_threads)
reduce_dbias_kernel(OutputType* const dbias_output,
const ComputeType* const dbias_partial,
const int row_length,
const int num_rows) {
template <int nvec, typename ComputeType, typename OutputType>
__global__ void __launch_bounds__(reduce_dbias_num_threads)
reduce_dbias_kernel(OutputType *const dbias_output, const ComputeType *const dbias_partial,
const int row_length, const int num_rows) {
using ComputeVec = Vec<ComputeType, nvec>;
using OutputVec = Vec<OutputType, nvec>;
using OutputVec = Vec<OutputType, nvec>;
const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_id * nvec >= row_length) return;
const ComputeType* const thread_in_base = dbias_partial + thread_id * nvec;
OutputType* const thread_out_base = dbias_output + thread_id * nvec;
const ComputeType *const thread_in_base = dbias_partial + thread_id * nvec;
OutputType *const thread_out_base = dbias_output + thread_id * nvec;
const int stride_in_vec = row_length / nvec;
ComputeVec ldg_vec;
ComputeVec acc_vec; acc_vec.clear();
ComputeVec acc_vec;
acc_vec.clear();
for (int i = 0; i < num_rows; ++i) {
ldg_vec.load_from(thread_in_base, i * stride_in_vec);
#pragma unroll
......@@ -381,7 +358,7 @@ reduce_dbias_kernel(OutputType* const dbias_output,
}
}
OutputVec stg_vec;
OutputVec stg_vec;
#pragma unroll
for (int e = 0; e < nvec; ++e) {
stg_vec.data.elt[e] = OutputType(acc_vec.data.elt[e]);
......@@ -390,10 +367,9 @@ reduce_dbias_kernel(OutputType* const dbias_output,
}
void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/
Tensor* workspace,
const int nvec_out) {
Tensor *workspace, const int nvec_out) {
const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0];
const size_t num_rows = input.data.shape[0];
const size_t tile_size_y = (nvec_out * THREADS_PER_WARP);
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
......@@ -405,37 +381,28 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/
}
template <typename BiasType>
void reduce_dbias(const Tensor &workspace, Tensor *dbias,
const size_t row_length, const size_t num_rows, const int nvec_out,
cudaStream_t stream) {
constexpr int reduce_dbias_store_bytes = 8; // stg.64
constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(BiasType);
void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_length,
const size_t num_rows, const int nvec_out, cudaStream_t stream) {
constexpr int reduce_dbias_store_bytes = 8; // stg.64
constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(BiasType);
NVTE_CHECK(row_length % reduce_dbias_nvec == 0, "Unsupported shape.");
const size_t reduce_dbias_row_length = row_length;
const size_t reduce_dbias_num_rows = DIVUP(num_rows,
static_cast<size_t>(nvec_out *
THREADS_PER_WARP));
const size_t reduce_dbias_num_blocks = DIVUP(row_length,
reduce_dbias_num_threads * reduce_dbias_nvec);
const size_t reduce_dbias_num_rows =
DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
const size_t reduce_dbias_num_blocks =
DIVUP(row_length, reduce_dbias_num_threads * reduce_dbias_nvec);
reduce_dbias_kernel<reduce_dbias_nvec, fp32, BiasType>
<<<reduce_dbias_num_blocks,
reduce_dbias_num_threads,
0,
stream>>>(
reinterpret_cast<BiasType *>(dbias->data.dptr),
reinterpret_cast<const fp32 *>(workspace.data.dptr),
reduce_dbias_row_length,
reduce_dbias_num_rows);
<<<reduce_dbias_num_blocks, reduce_dbias_num_threads, 0, stream>>>(
reinterpret_cast<BiasType *>(dbias->data.dptr),
reinterpret_cast<const fp32 *>(workspace.data.dptr), reduce_dbias_row_length,
reduce_dbias_num_rows);
}
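dbias therefore runs in two stages: the transpose kernels write one fp32 partial row per tile row into the workspace, and reduce_dbias_kernel then sums those rows column-wise and down-casts once to the bias type. The reduction stage, reduced to a sketch:

// Workspace is (num_partial_rows x row_length) in fp32; each thread owns
// nvec adjacent columns and accumulates down the rows.
template <int nvec, typename OutT>
__global__ void column_sum(OutT* out, const float* partial, int row_length,
                           int num_partial_rows) {
  const int col = (blockIdx.x * blockDim.x + threadIdx.x) * nvec;
  if (col >= row_length) return;
  float acc[nvec] = {};
  for (int r = 0; r < num_partial_rows; ++r) {
#pragma unroll
    for (int e = 0; e < nvec; ++e) {
      acc[e] += partial[r * row_length + col + e];
    }
  }
#pragma unroll
  for (int e = 0; e < nvec; ++e) {
    out[col + e] = OutT(acc[e]);  // single down-cast at the very end
  }
}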
void fp8_transpose_dbias(const Tensor &input,
Tensor *transposed_output,
Tensor *dbias,
Tensor *workspace,
cudaStream_t stream) {
void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor *dbias,
Tensor *workspace, cudaStream_t stream) {
CheckInputTensor(input, "fp8_transpose_dbias_input");
CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias");
......@@ -449,82 +416,71 @@ void fp8_transpose_dbias(const Tensor &input,
NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->data.dtype == input.data.dtype,
"T output must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dbias->data.dtype, BiasType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(input.data.dtype, Type,
constexpr int type_size = sizeof(Type);
constexpr int nvec_in = desired_load_size / type_size;
constexpr int nvec_out = desired_store_size / type_size;
if (workspace->data.dptr == nullptr) {
populate_transpose_dbias_workspace_config(input, workspace, nvec_out);
return;
}
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
num_rows % (nvec_out * THREADS_PER_WARP) == 0;
using ComputeType = fp32;
constexpr size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) *
sizeof(Vec<Type, nvec_out>);
constexpr size_t shared_size_dbias = cast_transpose_num_threads *
sizeof(Vec<ComputeType, nvec_in>);
static_assert(shared_size_transpose >= shared_size_dbias);
using Param = TDBiasParam<Type, Type, ComputeType>;
Param param;
param.input = reinterpret_cast<const Type *>(input.data.dptr);
param.output_t = reinterpret_cast<Type *>(transposed_output->data.dptr);
param.scale_inv = reinterpret_cast<const ComputeType *>(transposed_output->scale_inv.dptr);
param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
if (full_tile) {
cudaFuncSetAttribute(transpose_dbias_kernel<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
transpose_dbias_kernel<nvec_in, nvec_out, Param>
<<<n_blocks,
cast_transpose_num_threads,
shared_size_transpose,
stream>>>(param, row_length, num_rows, n_tiles);
} else {
cudaFuncSetAttribute(transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>
<<<n_blocks,
cast_transpose_num_threads,
shared_size_transpose,
stream>>>(param, row_length, num_rows, n_tiles);
}
reduce_dbias<BiasType>(*workspace, dbias, row_length, num_rows, nvec_out, stream);
); // NOLINT(*)
); // NOLINT(*)
"T output must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{row_length}, "Wrong shape of DBias.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
dbias->data.dtype, BiasType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
input.data.dtype, Type, constexpr int type_size = sizeof(Type);
constexpr int nvec_in = desired_load_size / type_size;
constexpr int nvec_out = desired_store_size / type_size;
if (workspace->data.dptr == nullptr) {
populate_transpose_dbias_workspace_config(input, workspace, nvec_out);
return;
}
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t n_tiles =
DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
num_rows % (nvec_out * THREADS_PER_WARP) == 0;
using ComputeType = fp32; constexpr size_t shared_size_transpose =
cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) * sizeof(Vec<Type, nvec_out>);
constexpr size_t shared_size_dbias =
cast_transpose_num_threads * sizeof(Vec<ComputeType, nvec_in>);
static_assert(shared_size_transpose >= shared_size_dbias);
using Param = TDBiasParam<Type, Type, ComputeType>; Param param;
param.input = reinterpret_cast<const Type *>(input.data.dptr);
param.output_t = reinterpret_cast<Type *>(transposed_output->data.dptr);
param.scale_inv =
reinterpret_cast<const ComputeType *>(transposed_output->scale_inv.dptr);
param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
if (full_tile) {
cudaFuncSetAttribute(transpose_dbias_kernel<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
transpose_dbias_kernel<nvec_in, nvec_out, Param>
<<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
param, row_length, num_rows, n_tiles);
} else {
cudaFuncSetAttribute(transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>
<<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
param, row_length, num_rows, n_tiles);
}
reduce_dbias<BiasType>(*workspace, dbias, row_length, num_rows, nvec_out,
stream);); // NOLINT(*)
); // NOLINT(*)
}
} // namespace transformer_engine
void nvte_fp8_transpose_dbias(const NVTETensor input,
NVTETensor transposed_output,
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream) {
void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_transpose_dbias);
using namespace transformer_engine;
fp8_transpose_dbias(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(transposed_output),
reinterpret_cast<Tensor*>(dbias),
reinterpret_cast<Tensor*>(workspace),
stream);
fp8_transpose_dbias(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
}
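// A minimal host-side sketch of the two-pass workspace protocol implemented above:
// the first call sees workspace->data.dptr == nullptr, fills in the required
// workspace shape/dtype via populate_transpose_dbias_workspace_config, and returns;
// the caller then allocates that buffer and calls again to run the kernels.
// make_workspace_tensor and device_alloc are illustrative helpers, not real API.
//
//   NVTETensor ws = make_workspace_tensor(/*dptr=*/nullptr);
//   nvte_fp8_transpose_dbias(input, transposed_output, dbias, ws, stream);  // query pass
//   device_alloc(ws);  // allocate per the shape/dtype written into ws
//   nvte_fp8_transpose_dbias(input, transposed_output, dbias, ws, stream);  // compute pass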
......@@ -5,9 +5,10 @@
************************************************************************/
#include <transformer_engine/cast.h>
#include "../common.h"
#include "../utils.cuh"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
namespace transformer_engine {
......@@ -15,9 +16,7 @@ namespace detail {
struct Empty {};
__device__ inline fp32 identity(fp32 value, const Empty&) {
return value;
}
__device__ inline fp32 identity(fp32 value, const Empty &) { return value; }
struct DequantizeParam {
const fp32 *scale_inv;
......@@ -29,83 +28,63 @@ __device__ inline fp32 dequantize_func(fp32 value, const DequantizeParam &param)
} // namespace detail
void fp8_quantize(const Tensor &input,
Tensor *output,
cudaStream_t stream) {
void fp8_quantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output");
NVTE_CHECK(!is_fp8_dtype(input.data.dtype),
"Input must be in higher precision.");
NVTE_CHECK(!is_fp8_dtype(input.data.dtype), "Input must be in higher precision.");
NVTE_CHECK(is_fp8_dtype(output->data.dtype),
"Output must have FP8 type.");
NVTE_CHECK(is_fp8_dtype(output->data.dtype), "Output must have FP8 type.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, detail::Empty, detail::identity>(
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
N,
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, detail::Empty, detail::identity>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr), N, {},
stream);); // NOLINT(*)
); // NOLINT(*)
}
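// Per element, the launched kernel amounts to a scale-then-cast with an amax
// reduction; a scalar sketch of the intended effect (the vectorized kernel body
// is not shown here, so this is an assumption about its semantics):
//
//   float x = static_cast<float>(input[i]);
//   local_amax = fmaxf(local_amax, fabsf(x));     // reduced into *amax
//   output[i] = static_cast<OType>(x * (*scale));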
void fp8_dequantize(const Tensor &input,
Tensor *output,
cudaStream_t stream) {
void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output");
NVTE_CHECK(is_fp8_dtype(input.data.dtype),
"Input must have FP8 type.");
NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype),
"Output must be in higher precision.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(OType);
detail::DequantizeParam p;
p.scale_inv = reinterpret_cast<const fp32*>(input.scale_inv.dptr);
VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
nullptr,
nullptr,
N,
p,
stream);
); // NOLINT(*)
); // NOLINT(*)
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output->data.dtype, OType, constexpr int nvec = 32 / sizeof(OType);
detail::DequantizeParam p;
p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType *>(output->data.dptr), nullptr, nullptr, N, p,
stream);); // NOLINT(*)
); // NOLINT(*)
}
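// Dequantization is the inverse map: widen the FP8 value and multiply by the
// stored inverse scale. A scalar sketch of dequantize_func applied per element
// (assuming the usual multiply-by-scale_inv semantics):
//
//   output[i] = static_cast<OType>(static_cast<float>(input[i]) * (*p.scale_inv));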
} // namespace transformer_engine
void nvte_fp8_quantize(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
void nvte_fp8_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_quantize);
using namespace transformer_engine;
fp8_quantize(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
fp8_quantize(*reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(output),
stream);
}
void nvte_fp8_dequantize(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
void nvte_fp8_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_dequantize);
using namespace transformer_engine;
fp8_dequantize(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
fp8_dequantize(*reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(output),
stream);
}
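// Round-trip sketch with the public C API, assuming `hi_prec` is a BF16/FP32
// tensor and `fp8` an FP8 tensor whose scale/amax/scale_inv the caller manages
// (`restored` is an illustrative higher-precision output tensor):
//
//   nvte_fp8_quantize(hi_prec, fp8, stream);      // cast down, record amax
//   nvte_fp8_dequantize(fp8, restored, stream);   // cast back up via scale_inv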
......@@ -5,6 +5,7 @@
************************************************************************/
#include <dlfcn.h>
#include <filesystem>
#include "../common.h"
......@@ -40,27 +41,21 @@ class Library {
#endif // _WIN32 or _WIN64 or __WINDOW__
}
Library(const Library&) = delete; // move-only
Library(const Library &) = delete; // move-only
Library(Library&& other) noexcept {
swap(*this, other);
}
Library(Library &&other) noexcept { swap(*this, other); }
Library& operator=(Library other) noexcept {
Library &operator=(Library other) noexcept {
// Copy-and-swap idiom
swap(*this, other);
return *this;
}
friend void swap(Library& first, Library& second) noexcept;
friend void swap(Library &first, Library &second) noexcept;
void *get() noexcept {
return handle_;
}
void *get() noexcept { return handle_; }
const void *get() const noexcept {
return handle_;
}
const void *get() const noexcept { return handle_; }
/*! \brief Get pointer corresponding to symbol in shared library */
void *get_symbol(const char *symbol) {
......@@ -78,13 +73,13 @@ class Library {
void *handle_ = nullptr;
};
void swap(Library& first, Library& second) noexcept {
void swap(Library &first, Library &second) noexcept {
using std::swap;
swap(first.handle_, second.handle_);
}
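// Together with the unified assignment operator above, this non-member swap
// completes the copy-and-swap idiom for the move-only type; a sketch:
//
//   Library a{...};           // owns a dlopen'd handle_
//   Library b{std::move(a)};  // b takes the handle; a is left holding nullptr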
/*! \brief Lazily-initialized shared library for CUDA driver */
Library& cuda_driver_lib() {
Library &cuda_driver_lib() {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
constexpr char lib_name[] = "nvcuda.dll";
#else
......@@ -98,9 +93,7 @@ Library& cuda_driver_lib() {
namespace cuda_driver {
void *get_symbol(const char *symbol) {
return cuda_driver_lib().get_symbol(symbol);
}
void *get_symbol(const char *symbol) { return cuda_driver_lib().get_symbol(symbol); }
} // namespace cuda_driver
......
......@@ -7,10 +7,10 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_
#include <string>
#include <cuda.h>
#include <string>
#include "../common.h"
#include "../util/string.h"
......@@ -35,7 +35,7 @@ void *get_symbol(const char *symbol);
template <typename... ArgTs>
inline CUresult call(const char *symbol, ArgTs... args) {
using FuncT = CUresult(ArgTs...);
FuncT *func = reinterpret_cast<FuncT*>(get_symbol(symbol));
FuncT *func = reinterpret_cast<FuncT *>(get_symbol(symbol));
return (*func)(args...);
}
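// Usage sketch: the caller spells out the driver-API signature explicitly and
// the symbol is resolved through the lazily loaded library, e.g.:
//
//   CUdevice dev;
//   CUresult status = transformer_engine::cuda_driver::call("cuDeviceGet", &dev, 0);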
......@@ -43,23 +43,20 @@ inline CUresult call(const char *symbol, ArgTs... args) {
} // namespace transformer_engine
#define NVTE_CHECK_CUDA_DRIVER(expr) \
do { \
const CUresult status_NVTE_CHECK_CUDA_DRIVER = (expr); \
if (status_NVTE_CHECK_CUDA_DRIVER != CUDA_SUCCESS) { \
const char *desc_NVTE_CHECK_CUDA_DRIVER; \
::transformer_engine::cuda_driver::call( \
"cuGetErrorString", \
status_NVTE_CHECK_CUDA_DRIVER, \
&desc_NVTE_CHECK_CUDA_DRIVER); \
NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER); \
} \
#define NVTE_CHECK_CUDA_DRIVER(expr) \
do { \
const CUresult status_NVTE_CHECK_CUDA_DRIVER = (expr); \
if (status_NVTE_CHECK_CUDA_DRIVER != CUDA_SUCCESS) { \
const char *desc_NVTE_CHECK_CUDA_DRIVER; \
::transformer_engine::cuda_driver::call("cuGetErrorString", status_NVTE_CHECK_CUDA_DRIVER, \
&desc_NVTE_CHECK_CUDA_DRIVER); \
NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER); \
} \
} while (false)
#define NVTE_CALL_CHECK_CUDA_DRIVER(symbol, ...) \
do { \
NVTE_CHECK_CUDA_DRIVER( \
::transformer_engine::cuda_driver::call(#symbol, __VA_ARGS__)); \
#define NVTE_CALL_CHECK_CUDA_DRIVER(symbol, ...) \
do { \
NVTE_CHECK_CUDA_DRIVER(::transformer_engine::cuda_driver::call(#symbol, __VA_ARGS__)); \
} while (false)
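// Example use (as in rtc.cu further below): the symbol is stringified, resolved
// at call time, and any non-CUDA_SUCCESS result becomes an NVTE_ERROR:
//
//   NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context);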
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_
......@@ -4,12 +4,13 @@
* See LICENSE for license information.
************************************************************************/
#include "../util/cuda_runtime.h"
#include <filesystem>
#include <mutex>
#include "../common.h"
#include "../util/cuda_driver.h"
#include "../util/cuda_runtime.h"
#include "../util/system.h"
namespace transformer_engine {
......@@ -24,7 +25,7 @@ namespace {
} // namespace
int num_devices() {
auto query_num_devices = [] () -> int {
auto query_num_devices = []() -> int {
int count;
NVTE_CHECK_CUDA(cudaGetDeviceCount(&count));
return count;
......@@ -54,10 +55,10 @@ int sm_arch(int device_id) {
device_id = current_device();
}
NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
auto init = [&] () {
auto init = [&]() {
cudaDeviceProp prop;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id));
cache[device_id] = 10*prop.major + prop.minor;
cache[device_id] = 10 * prop.major + prop.minor;
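// e.g. prop.major == 9, prop.minor == 0 (Hopper) is encoded as 90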
};
std::call_once(flags[device_id], init);
return cache[device_id];
......@@ -70,7 +71,7 @@ int sm_count(int device_id) {
device_id = current_device();
}
NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
auto init = [&] () {
auto init = [&]() {
cudaDeviceProp prop;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id));
cache[device_id] = prop.multiProcessorCount;
......@@ -90,12 +91,11 @@ const std::string &include_directory(bool required) {
if (need_to_check_env) {
// Search for CUDA headers in common paths
using Path = std::filesystem::path;
std::vector<std::pair<std::string, Path>> search_paths = {
{"NVTE_CUDA_INCLUDE_DIR", ""},
{"CUDA_HOME", ""},
{"CUDA_DIR", ""},
{"", string_path_cuda_include},
{"", "/usr/local/cuda"}};
std::vector<std::pair<std::string, Path>> search_paths = {{"NVTE_CUDA_INCLUDE_DIR", ""},
{"CUDA_HOME", ""},
{"CUDA_DIR", ""},
{"", string_path_cuda_include},
{"", "/usr/local/cuda"}};
for (auto &[env, p] : search_paths) {
if (p.empty()) {
p = getenv<Path>(env.c_str());
......@@ -131,10 +131,11 @@ const std::string &include_directory(bool required) {
message += p;
}
}
message += (". "
"Specify path to CUDA Toolkit headers "
"with NVTE_CUDA_INCLUDE_DIR "
"or disable NVRTC support with NVTE_DISABLE_NVRTC=1.");
message +=
(". "
"Specify path to CUDA Toolkit headers "
"with NVTE_CUDA_INCLUDE_DIR "
"or disable NVRTC support with NVTE_DISABLE_NVRTC=1.");
NVTE_ERROR(message);
}
need_to_check_env = false;
......
......@@ -8,6 +8,7 @@
#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_
#include <cuda_runtime_api.h>
#include <string>
namespace transformer_engine {
......
......@@ -7,83 +7,76 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
#include <stdexcept>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cudnn.h>
#include <nvrtc.h>
#include <stdexcept>
#include "../util/string.h"
#define NVTE_ERROR(...) \
do { \
throw ::std::runtime_error( \
::transformer_engine::concat_strings( \
__FILE__ ":", __LINE__, \
" in function ", __func__, ": ", \
::transformer_engine::concat_strings(__VA_ARGS__))); \
#define NVTE_ERROR(...) \
do { \
throw ::std::runtime_error(::transformer_engine::concat_strings( \
__FILE__ ":", __LINE__, " in function ", __func__, ": ", \
::transformer_engine::concat_strings(__VA_ARGS__))); \
} while (false)
#define NVTE_CHECK(expr, ...) \
do { \
if (!(expr)) { \
NVTE_ERROR("Assertion failed: " #expr ". ", \
::transformer_engine::concat_strings(__VA_ARGS__)); \
} \
#define NVTE_CHECK(expr, ...) \
do { \
if (!(expr)) { \
NVTE_ERROR("Assertion failed: " #expr ". ", \
::transformer_engine::concat_strings(__VA_ARGS__)); \
} \
} while (false)
#define NVTE_CHECK_CUDA(expr) \
do { \
const cudaError_t status_NVTE_CHECK_CUDA = (expr); \
if (status_NVTE_CHECK_CUDA != cudaSuccess) { \
NVTE_ERROR("CUDA Error: ", \
cudaGetErrorString(status_NVTE_CHECK_CUDA)); \
} \
#define NVTE_CHECK_CUDA(expr) \
do { \
const cudaError_t status_NVTE_CHECK_CUDA = (expr); \
if (status_NVTE_CHECK_CUDA != cudaSuccess) { \
NVTE_ERROR("CUDA Error: ", cudaGetErrorString(status_NVTE_CHECK_CUDA)); \
} \
} while (false)
#define NVTE_CHECK_CUBLAS(expr) \
do { \
const cublasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \
if (status_NVTE_CHECK_CUBLAS != CUBLAS_STATUS_SUCCESS) { \
NVTE_ERROR("cuBLAS Error: ", \
cublasGetStatusString(status_NVTE_CHECK_CUBLAS)); \
} \
#define NVTE_CHECK_CUBLAS(expr) \
do { \
const cublasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \
if (status_NVTE_CHECK_CUBLAS != CUBLAS_STATUS_SUCCESS) { \
NVTE_ERROR("cuBLAS Error: ", cublasGetStatusString(status_NVTE_CHECK_CUBLAS)); \
} \
} while (false)
#define NVTE_CHECK_CUDNN(expr) \
do { \
const cudnnStatus_t status_NVTE_CHECK_CUDNN = (expr); \
if (status_NVTE_CHECK_CUDNN != CUDNN_STATUS_SUCCESS) { \
NVTE_ERROR("cuDNN Error: ", \
cudnnGetErrorString(status_NVTE_CHECK_CUDNN), \
". " \
"For more information, enable cuDNN error logging " \
"by setting CUDNN_LOGERR_DBG=1 and " \
"CUDNN_LOGDEST_DBG=stderr in the environment."); \
} \
#define NVTE_CHECK_CUDNN(expr) \
do { \
const cudnnStatus_t status_NVTE_CHECK_CUDNN = (expr); \
if (status_NVTE_CHECK_CUDNN != CUDNN_STATUS_SUCCESS) { \
NVTE_ERROR("cuDNN Error: ", cudnnGetErrorString(status_NVTE_CHECK_CUDNN), \
". " \
"For more information, enable cuDNN error logging " \
"by setting CUDNN_LOGERR_DBG=1 and " \
"CUDNN_LOGDEST_DBG=stderr in the environment."); \
} \
} while (false)
#define NVTE_CHECK_CUDNN_FE(expr) \
do { \
const auto error = (expr); \
if (error.is_bad()) { \
NVTE_ERROR("cuDNN Error: ", \
error.err_msg, \
". " \
"For more information, enable cuDNN error logging " \
"by setting CUDNN_LOGERR_DBG=1 and " \
"CUDNN_LOGDEST_DBG=stderr in the environment."); \
} \
#define NVTE_CHECK_CUDNN_FE(expr) \
do { \
const auto error = (expr); \
if (error.is_bad()) { \
NVTE_ERROR("cuDNN Error: ", error.err_msg, \
". " \
"For more information, enable cuDNN error logging " \
"by setting CUDNN_LOGERR_DBG=1 and " \
"CUDNN_LOGDEST_DBG=stderr in the environment."); \
} \
} while (false)
#define NVTE_CHECK_NVRTC(expr) \
do { \
const nvrtcResult status_NVTE_CHECK_NVRTC = (expr); \
if (status_NVTE_CHECK_NVRTC != NVRTC_SUCCESS) { \
NVTE_ERROR("NVRTC Error: ", \
nvrtcGetErrorString(status_NVTE_CHECK_NVRTC)); \
} \
#define NVTE_CHECK_NVRTC(expr) \
do { \
const nvrtcResult status_NVTE_CHECK_NVRTC = (expr); \
if (status_NVTE_CHECK_NVRTC != NVRTC_SUCCESS) { \
NVTE_ERROR("NVRTC Error: ", nvrtcGetErrorString(status_NVTE_CHECK_NVRTC)); \
} \
} while (false)
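// Hedged usage sketch for the checking macros above (condition and message
// arguments are illustrative):
//
//   NVTE_CHECK(rows % 2 == 0, "expected even rows, got ", rows);
//   NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));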
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
......@@ -13,75 +13,73 @@ struct Empty {};
template <typename OType, typename IType>
__device__ inline OType gelu(const IType val, const Empty&) {
const float cval = val;
return cval * (0.5F + 0.5F * tanhf(cval * (0.79788456F + 0.03567741F * cval * cval)));
const float cval = val;
return cval * (0.5F + 0.5F * tanhf(cval * (0.79788456F + 0.03567741F * cval * cval)));
}
template <typename OType, typename IType>
__device__ inline OType dgelu(const IType val, const Empty&) {
const float cval = val;
const float tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
return 0.5f * cval * ((1.f - tanh_out * tanh_out) *
(0.79788456f + 0.1070322243f * cval * cval)) +
0.5f * (1.f + tanh_out);
const float cval = val;
const float tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
return 0.5f * cval * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * cval * cval)) +
0.5f * (1.f + tanh_out);
}
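// Both implement the tanh approximation gelu(x) ~= 0.5 x (1 + tanh(sqrt(2/pi) (x + 0.044715 x^3))):
// sqrt(2/pi) ~= 0.79788456, 0.044715 * sqrt(2/pi) ~= 0.03567741, and the dgelu constant
// 0.1070322243 ~= 3 * 0.044715 * sqrt(2/pi) comes from differentiating the cubic term.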
template <typename OType, typename IType>
__device__ inline OType sigmoid(const IType val, const Empty&) {
const float cval = val;
return 1.f / (1.f + expf(-cval));
const float cval = val;
return 1.f / (1.f + expf(-cval));
}
template <typename OType, typename IType>
__device__ inline OType dsigmoid(const IType val, const Empty& e) {
const float cval = val;
const float s = sigmoid<float, float>(cval, e);
return s * (1.f - s);
const float cval = val;
const float s = sigmoid<float, float>(cval, e);
return s * (1.f - s);
}
template <typename OType, typename IType>
__device__ inline OType qgelu(const IType val, const Empty& e) {
const float cval = val;
return cval * sigmoid<float, float>(1.702f * cval, e);
const float cval = val;
return cval * sigmoid<float, float>(1.702f * cval, e);
}
template <typename OType, typename IType>
__device__ inline OType dqgelu(const IType val, const Empty& e) {
const float cval = val;
return cval * dsigmoid<float, float>(1.702f * cval, e) +
sigmoid<float, float>(1.702f * cval, e);
const float cval = val;
return cval * dsigmoid<float, float>(1.702f * cval, e) + sigmoid<float, float>(1.702f * cval, e);
}
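// qgelu uses the sigmoid approximation gelu(x) ~= x * sigmoid(1.702 x); dqgelu is its
// product-rule derivative.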
template <typename OType, typename IType>
__device__ inline OType silu(const IType val, const Empty& e) {
const float cval = val;
return cval * sigmoid<float, float>(cval, e);
const float cval = val;
return cval * sigmoid<float, float>(cval, e);
}
template <typename OType, typename IType>
__device__ inline OType dsilu(const IType val, const Empty& e) {
const float cval = val;
return cval * dsigmoid<float, float>(cval, e) + sigmoid<float, float>(cval, e);
const float cval = val;
return cval * dsigmoid<float, float>(cval, e) + sigmoid<float, float>(cval, e);
}
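// silu(x) = x * sigmoid(x) (a.k.a. swish); dsilu follows from the product rule:
// sigmoid(x) + x * sigmoid'(x).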
template <typename OType, typename IType>
__device__ inline OType relu(IType value, const Empty &) {
return fmaxf(value, 0.f);
__device__ inline OType relu(IType value, const Empty&) {
return fmaxf(value, 0.f);
}
template <typename OType, typename IType>
__device__ inline OType drelu(IType value, const Empty &) {
return value > 0.f ? 1.f : 0.f;
__device__ inline OType drelu(IType value, const Empty&) {
return value > 0.f ? 1.f : 0.f;
}
template <typename OType, typename IType>
__device__ inline OType srelu(IType value, const Empty &) {
return value > 0 ? value * value : 0.f;
__device__ inline OType srelu(IType value, const Empty&) {
return value > 0 ? value * value : 0.f;
}
template <typename OType, typename IType>
__device__ inline OType dsrelu(IType value, const Empty &) {
return fmaxf(2.f * value, 0.f);
__device__ inline OType dsrelu(IType value, const Empty&) {
return fmaxf(2.f * value, 0.f);
}
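// srelu squares positive inputs, so its derivative is 2x for x > 0 and 0 otherwise,
// which dsrelu expresses as fmaxf(2.f * value, 0.f).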
} // namespace transformer_engine
......
......@@ -4,6 +4,8 @@
* See LICENSE for license information.
************************************************************************/
#include "../util/rtc.h"
#include <cstdlib>
#include <iostream>
#include <utility>
......@@ -13,8 +15,6 @@
#include "../util/string.h"
#include "../util/system.h"
#include "../util/rtc.h"
namespace transformer_engine {
namespace rtc {
......@@ -22,8 +22,8 @@ namespace rtc {
namespace {
// Strings with headers for RTC kernels
#include "string_code_utils_cuh.h"
#include "string_code_util_math_h.h"
#include "string_code_utils_cuh.h"
/*! \brief Latest compute capability that NVRTC supports
*
......@@ -56,29 +56,25 @@ bool is_enabled() {
}
Kernel::Kernel(std::string mangled_name, std::string compiled_code)
: mangled_name_{std::move(mangled_name)}
, compiled_code_{std::move(compiled_code)}
, modules_(cuda::num_devices(), null_module)
, functions_(cuda::num_devices(), null_function)
, init_flags_{std::make_unique<std::vector<std::once_flag>>(cuda::num_devices())} {
}
: mangled_name_{std::move(mangled_name)},
compiled_code_{std::move(compiled_code)},
modules_(cuda::num_devices(), null_module),
functions_(cuda::num_devices(), null_function),
init_flags_{std::make_unique<std::vector<std::once_flag>>(cuda::num_devices())} {}
Kernel::~Kernel() {
for (int device_id=0; device_id<static_cast<int>(modules_.size()); ++device_id) {
for (int device_id = 0; device_id < static_cast<int>(modules_.size()); ++device_id) {
// Unload CUDA modules if needed
if (modules_[device_id] != null_module) {
CUdevice device;
CUcontext context;
if (cuda_driver::call("cuDeviceGet", &device, device_id)
!= CUDA_SUCCESS) {
if (cuda_driver::call("cuDeviceGet", &device, device_id) != CUDA_SUCCESS) {
continue;
}
if (cuda_driver::call("cuDevicePrimaryCtxRetain", &context, device)
!= CUDA_SUCCESS) {
if (cuda_driver::call("cuDevicePrimaryCtxRetain", &context, device) != CUDA_SUCCESS) {
continue;
}
if (cuda_driver::call("cuCtxSetCurrent", context)
!= CUDA_SUCCESS) {
if (cuda_driver::call("cuCtxSetCurrent", context) != CUDA_SUCCESS) {
continue;
}
cuda_driver::call("cuModuleUnload", modules_[device_id]);
......@@ -87,9 +83,7 @@ Kernel::~Kernel() {
}
}
Kernel::Kernel(Kernel&& other) noexcept {
swap(*this, other);
}
Kernel::Kernel(Kernel&& other) noexcept { swap(*this, other); }
Kernel& Kernel::operator=(Kernel other) noexcept {
// Copy-and-swap idiom
......@@ -108,7 +102,7 @@ void swap(Kernel& first, Kernel& second) noexcept {
CUfunction Kernel::get_function(int device_id) {
// Load kernel on device if needed
auto load_on_device = [&] () {
auto load_on_device = [&]() {
// Set driver context to proper device
CUdevice device;
CUcontext context;
......@@ -117,15 +111,11 @@ CUfunction Kernel::get_function(int device_id) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context);
// Load function into driver context
NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleLoadDataEx,
&modules_[device_id],
compiled_code_.c_str(),
0, // numOptions
nullptr, // options
nullptr); // optionValues
NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleGetFunction,
&functions_[device_id],
modules_[device_id],
NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleLoadDataEx, &modules_[device_id], compiled_code_.c_str(),
0, // numOptions
nullptr, // options
nullptr); // optionValues
NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleGetFunction, &functions_[device_id], modules_[device_id],
mangled_name_.c_str());
// Reset driver context
......@@ -147,10 +137,8 @@ KernelManager& KernelManager::instance() {
return instance_;
}
void KernelManager::compile(const std::string &kernel_label,
const std::string &kernel_name,
const std::string &code,
const std::string &filename) {
void KernelManager::compile(const std::string& kernel_label, const std::string& kernel_name,
const std::string& code, const std::string& filename) {
std::lock_guard<std::mutex> lock_guard_(lock_);
// Choose whether to compile to PTX or cubin
......@@ -162,9 +150,9 @@ void KernelManager::compile(const std::string &kernel_label,
// Compilation flags
std::vector<std::string> opts = {
#if NDEBUG == 0
"-G",
"-G",
#endif
"--std=c++17"};
"--std=c++17"};
if (compile_ptx) {
opts.push_back(concat_strings("--gpu-architecture=compute_", compile_sm_arch));
} else {
......@@ -181,20 +169,14 @@ void KernelManager::compile(const std::string &kernel_label,
constexpr int num_headers = 2;
constexpr const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h};
constexpr const char* include_names[num_headers] = {"utils.cuh", "util/math.h"};
NVTE_CHECK_NVRTC(nvrtcCreateProgram(&program,
code.c_str(),
filename.c_str(),
num_headers,
headers,
include_names));
NVTE_CHECK_NVRTC(nvrtcCreateProgram(&program, code.c_str(), filename.c_str(), num_headers,
headers, include_names));
NVTE_CHECK_NVRTC(nvrtcAddNameExpression(program, kernel_name.c_str()));
const nvrtcResult compile_result = nvrtcCompileProgram(program,
opts_ptrs.size(),
opts_ptrs.data());
const nvrtcResult compile_result =
nvrtcCompileProgram(program, opts_ptrs.size(), opts_ptrs.data());
if (compile_result != NVRTC_SUCCESS) {
// Display log if compilation failed
std::string log = concat_strings("NVRTC compilation log for ",
filename, ":\n");
std::string log = concat_strings("NVRTC compilation log for ", filename, ":\n");
const size_t log_offset = log.size();
size_t log_size;
NVTE_CHECK_NVRTC(nvrtcGetProgramLogSize(program, &log_size));
......@@ -206,10 +188,8 @@ void KernelManager::compile(const std::string &kernel_label,
}
// Get mangled function name
const char *mangled_name;
NVTE_CHECK_NVRTC(nvrtcGetLoweredName(program,
kernel_name.c_str(),
&mangled_name));
const char* mangled_name;
NVTE_CHECK_NVRTC(nvrtcGetLoweredName(program, kernel_name.c_str(), &mangled_name));
// Get compiled code
std::string compiled_code;
......@@ -234,20 +214,19 @@ void KernelManager::compile(const std::string &kernel_label,
NVTE_CHECK_NVRTC(nvrtcDestroyProgram(&program));
}
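// Hedged sketch of the compile-once flow (label, name, and source strings are
// illustrative; is_compiled is shown below taking an explicit device_id):
//
//   auto &km = KernelManager::instance();
//   if (!km.is_compiled("my_label", cuda::current_device())) {
//     km.compile("my_label", "my_kernel", code_string, "my_kernel.cu");
//   }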
void KernelManager::set_cache_config(const std::string &kernel_label, CUfunc_cache cache_config) {
void KernelManager::set_cache_config(const std::string& kernel_label, CUfunc_cache cache_config) {
const int device_id = cuda::current_device();
const auto key = get_kernel_cache_key(kernel_label, device_id);
NVTE_CHECK(kernel_cache_.count(key) > 0,
"Attempted to configure RTC kernel before compilation");
NVTE_CHECK(kernel_cache_.count(key) > 0, "Attempted to configure RTC kernel before compilation");
kernel_cache_.at(key).set_function_cache_config(device_id, cache_config);
}
bool KernelManager::is_compiled(const std::string &kernel_label, int device_id) const {
bool KernelManager::is_compiled(const std::string& kernel_label, int device_id) const {
const auto key = get_kernel_cache_key(kernel_label, device_id);
return kernel_cache_.count(key) > 0;
}
std::string KernelManager::get_kernel_cache_key(const std::string &kernel_label,
std::string KernelManager::get_kernel_cache_key(const std::string& kernel_label,
int device_id) const {
return concat_strings("sm=", cuda::sm_arch(device_id), ",", kernel_label);
}
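// e.g. on an SM90 device with kernel_label "cast_transpose" (illustrative label),
// the cache key becomes "sm=90,cast_transpose".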
......