Unverified commit 9416519d, authored by Kirthi Shankar Sivamani and committed by GitHub

Apply formatting (#929)



* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
@@ -13,7 +13,8 @@ using namespace transformer_engine::rmsnorm;
 template <typename weight_t, typename input_t, typename output_t, typename compute_t,
           typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
           int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL>
-void launch_tuned_(LaunchParams<BwdParams> &launch_params, const bool configure_params) {  // NOLINT(*)
+void launch_tuned_(LaunchParams<BwdParams> &launch_params,
+                   const bool configure_params) {  // NOLINT(*)
   using Kernel_traits =
       rmsnorm::Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
                              CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
@@ -30,8 +31,8 @@ void launch_tuned_(LaunchParams<BwdParams> &launch_params, const bool configure_
     launch_params.workspace_bytes = 0;
     if (Kernel_traits::CTAS_PER_ROW > 1) {
       launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
-      launch_params.workspace_bytes = launch_params.params.ctas_per_col *
-                                      Kernel_traits::WARPS_M * Kernel_traits::CTAS_PER_ROW *
+      launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M *
+                                      Kernel_traits::CTAS_PER_ROW *
                                       sizeof(typename Kernel_traits::reduce_t) * 2;
     }
     return;
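Only the line wrapping changes in the hunk above; the sizing logic is untouched. As a reading aid, the same arithmetic pulled out into a standalone helper (illustrative only, not part of the patch; the Kernel_traits-style values are passed in as plain integers, and the trailing * 2 simply mirrors the factor in barrier_size = 2 * launch_params.params.ctas_per_col):

    // Sketch of the multi-CTA workspace sizing re-wrapped above (not from this commit).
    // One reduce_t element per (column CTA, WARPS_M, CTAS_PER_ROW) slot, times 2.
    size_t bwd_tuned_workspace_bytes(size_t ctas_per_col, size_t warps_m, size_t ctas_per_row,
                                     size_t reduce_t_bytes) {
      return ctas_per_col * warps_m * ctas_per_row * reduce_t_bytes * 2;
    }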
@@ -70,12 +71,13 @@ void launch_tuned_(LaunchParams<BwdParams> &launch_params, const bool configure_
 template <typename weight_t, typename input_t, typename output_t, typename compute_t,
           typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN,
           int BYTES_PER_LDG_FINAL>
-void launch_general_(LaunchParams<BwdParams> &launch_params, const bool configure_params) {  // NOLINT(*)
+void launch_general_(LaunchParams<BwdParams> &launch_params,
+                     const bool configure_params) {  // NOLINT(*)
   auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
 
   // Instantiate kernel
-  using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t,
-                                      HIDDEN_SIZE, 1, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
+  using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
+                                      1, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
   auto kernel = &rmsnorm_bwd_general_kernel<Kernel_traits>;
 
   // Configure kernel params
@@ -97,8 +99,8 @@ void launch_general_(LaunchParams<BwdParams> &launch_params, const bool configur
     launch_params.workspace_bytes = 0;
     if (launch_params.params.ctas_per_row > 1) {
       launch_params.barrier_size = 2 * ctas_per_col;
-      launch_params.workspace_bytes = (ctas_per_col * WARPS_M * ctas_per_row *
-                                       sizeof(typename Kernel_traits::reduce_t) * 2);
+      launch_params.workspace_bytes =
+          (ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2);
     }
     return;
   }
@@ -130,46 +132,24 @@ void launch_general_(LaunchParams<BwdParams> &launch_params, const bool configur
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-#define REGISTER_BWD_TUNED_LAUNCHER( \
-    HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, \
-    BYTES_PER_LDG_FINALIZE) \
+#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \
+                                    WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
   void rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
-      LaunchParams<BwdParams> \
-          &launch_params, \
-      const bool configure_params) { \
-    launch_tuned_<WTYPE, \
-                  ITYPE, \
-                  OTYPE, \
-                  CTYPE, \
-                  uint32_t, \
-                  HIDDEN_SIZE, \
-                  CTAS_PER_ROW, \
-                  WARPS_M, \
-                  WARPS_N, \
-                  BYTES_PER_LDG, \
-                  BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
+      LaunchParams<BwdParams> &launch_params, const bool configure_params) { \
+    launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, \
+                  WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE>(launch_params, \
+                                                                  configure_params); \
   } \
   static BwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
       reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
          rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
 
-#define REGISTER_BWD_GENERAL_LAUNCHER( \
-    HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG, \
-    BYTES_PER_LDG_FINALIZE) \
+#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \
+                                      BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
   void rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
-      LaunchParams<BwdParams> \
-          &launch_params, \
-      const bool configure_params) { \
-    launch_general_<WTYPE, \
-                    ITYPE, \
-                    OTYPE, \
-                    CTYPE, \
-                    uint32_t, \
-                    HIDDEN_SIZE, \
-                    WARPS_M, \
-                    WARPS_N, \
-                    BYTES_PER_LDG, \
-                    BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
+      LaunchParams<BwdParams> &launch_params, const bool configure_params) { \
+    launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, WARPS_M, WARPS_N, \
+                    BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
   } \
   static BwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
       reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
 ...
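The two registration macros above only have their parameter lists and call sites re-wrapped to the new column limit. As a reading aid, a hypothetical expansion of REGISTER_BWD_GENERAL_LAUNCHER would produce roughly the following; all argument values below are made up for illustration and are not taken from this diff:

    // Hypothetical expansion of
    //   REGISTER_BWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
    // i.e. HIDDEN_SIZE=2048, all types fp32, WARPS_M=4, WARPS_N=1, BYTES_PER_LDG=16,
    // BYTES_PER_LDG_FINALIZE=4 -- values chosen only for illustration.
    void rmsnorm_bwd_general_2048_fp32_fp32_fp32_fp32(LaunchParams<BwdParams> &launch_params,
                                                      const bool configure_params) {
      launch_general_<fp32, fp32, fp32, fp32, uint32_t, 2048, 4, 1, 16, 4>(launch_params,
                                                                           configure_params);
    }
    static BwdGeneralRegistrar<fp32, fp32, fp32, fp32, 2048>
        reg_general_2048_fp32_fp32_fp32_fp32(rmsnorm_bwd_general_2048_fp32_fp32_fp32_fp32);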
@@ -13,9 +13,10 @@ using namespace transformer_engine::rmsnorm;
 template <typename weight_t, typename input_t, typename output_t, typename compute_t,
           typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
           int BYTES_PER_LDG>
-void launch_tuned_(LaunchParams<FwdParams> &launch_params, const bool configure_params) {  // NOLINT(*)
-  using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t,
-                                      HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>;
+void launch_tuned_(LaunchParams<FwdParams> &launch_params,
+                   const bool configure_params) {  // NOLINT(*)
+  using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
+                                      CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>;
   auto kernel = &rmsnorm_fwd_tuned_kernel<Kernel_traits>;
 
   if (configure_params) {
@@ -29,8 +30,8 @@ void launch_tuned_(LaunchParams<FwdParams> &launch_params, const bool configure_
     launch_params.workspace_bytes = 0;
     if (Kernel_traits::CTAS_PER_ROW > 1) {
       launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
-      launch_params.workspace_bytes = launch_params.params.ctas_per_col *
-                                      Kernel_traits::WARPS_M * Kernel_traits::CTAS_PER_ROW *
+      launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M *
+                                      Kernel_traits::CTAS_PER_ROW *
                                       sizeof(typename Kernel_traits::Stats::stats_t) * 2;
     }
     return;
@@ -45,8 +46,8 @@ void launch_tuned_(LaunchParams<FwdParams> &launch_params, const bool configure_
   auto ctas_per_row = launch_params.params.ctas_per_row;
 
   if (ctas_per_row == 1) {
-    kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD,
-             stream>>>(launch_params.params);
+    kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(
+        launch_params.params);
   } else {
     dim3 grid(ctas_per_row * ctas_per_col);
     dim3 block(Kernel_traits::THREADS_PER_CTA);
@@ -58,9 +59,10 @@ void launch_tuned_(LaunchParams<FwdParams> &launch_params, const bool configure_
 template <typename weight_t, typename input_t, typename output_t, typename compute_t,
           typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG>
-void launch_general_(LaunchParams<FwdParams> &launch_params, const bool configure_params) {  // NOLINT(*)
-  using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t,
-                                      HIDDEN_SIZE, 1, WARPS_M, WARPS_N, BYTES_PER_LDG>;
+void launch_general_(LaunchParams<FwdParams> &launch_params,
+                     const bool configure_params) {  // NOLINT(*)
+  using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
+                                      1, WARPS_M, WARPS_N, BYTES_PER_LDG>;
   auto kernel = &rmsnorm_fwd_general_kernel<Kernel_traits>;
   auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
@@ -104,27 +106,23 @@ void launch_general_(LaunchParams<FwdParams> &launch_params, const bool configur
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
-                                    CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
+#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \
+                                    WARPS_M, WARPS_N, BYTES_PER_LDG) \
   void rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
-      LaunchParams<FwdParams> &launch_params, \
-      const bool configure_params) { \
-    launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, \
-                  WARPS_M, WARPS_N, BYTES_PER_LDG>( \
-        launch_params, configure_params); \
+      LaunchParams<FwdParams> &launch_params, const bool configure_params) { \
+    launch_tuned_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, \
+                  WARPS_N, BYTES_PER_LDG>(launch_params, configure_params); \
   } \
   static FwdTunedRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
       reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
          rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE)
 
-#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
-                                      WARPS_M, WARPS_N, BYTES_PER_LDG) \
+#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \
+                                      BYTES_PER_LDG) \
   void rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
-      LaunchParams<FwdParams> &launch_params, \
-      const bool configure_params) { \
-    launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, \
-                    WARPS_M, WARPS_N, BYTES_PER_LDG>( \
-        launch_params, configure_params); \
+      LaunchParams<FwdParams> &launch_params, const bool configure_params) { \
+    launch_general_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, WARPS_M, WARPS_N, \
+                    BYTES_PER_LDG>(launch_params, configure_params); \
   } \
   static FwdGeneralRegistrar<WTYPE, ITYPE, OTYPE, CTYPE, HIDDEN_SIZE> \
       reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
 ...
@@ -9,6 +9,7 @@
 #include <cfloat>
 #include <cstdio>
+
 #include "../utils.cuh"
 
 namespace transformer_engine {
@@ -164,8 +165,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
   const index_t gdimm = bdimm * params.ctas_per_col;
   const index_t gdimn = bdimn * params.ctas_per_row;
   const index_t gidm = bidm * bdimm + warp_m;
-  const index_t gidn =
-      (bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP +
-       lane);  // Order threads by warp x cta x lane
+  const index_t gidn = (bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP +
+                        lane);  // Order threads by warp x cta x lane
 
   // Objects for stats reductions
 ...
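The gidn expression re-wrapped in the hunk above encodes the "warp x cta x lane" thread ordering noted in the trailing comment. A small compile-time check of that ordering, with invented values (purely illustrative, not taken from the kernel):

    // Illustrative only; the constants below are made up.
    constexpr unsigned kThreadsPerWarp = 32;
    constexpr unsigned ctas_per_row = 2, bidn = 1, warp_n = 1, lane = 5;
    constexpr unsigned gidn =
        bidn * kThreadsPerWarp + warp_n * ctas_per_row * kThreadsPerWarp + lane;
    static_assert(gidn == 101, "1*32 + 1*2*32 + 5");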
@@ -5,80 +5,65 @@
  * See LICENSE for license information.
  ************************************************************************/
 
 #include <transformer_engine/transformer_engine.h>
 
 #include "common.h"
 
 namespace transformer_engine {
 
 size_t typeToSize(const transformer_engine::DType type) {
   TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
-    return TypeInfo<T>::size;
-  );  // NOLINT(*)
+                                     return TypeInfo<T>::size;);  // NOLINT(*)
 }
 
 bool is_fp8_dtype(const transformer_engine::DType t) {
-  return t == transformer_engine::DType::kFloat8E4M3 ||
-         t == transformer_engine::DType::kFloat8E5M2;
+  return t == transformer_engine::DType::kFloat8E4M3 || t == transformer_engine::DType::kFloat8E5M2;
 }
 
 void CheckInputTensor(const Tensor &t, const std::string &name) {
   const DType type = t.data.dtype;
   if (is_fp8_dtype(type)) {
     // FP8 input needs to have scale_inv
-    NVTE_CHECK(t.scale_inv.dptr != nullptr,
-               "FP8 input " + name + " must have inverse of scale.");
+    NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 input " + name + " must have inverse of scale.");
     NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
-    NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{ 1 });
+    NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{1});
   } else {
-    NVTE_CHECK(t.scale.dptr == nullptr,
-               "Scale is not supported for non-FP8 input " + name + ".");
-    NVTE_CHECK(t.amax.dptr == nullptr,
-               "Amax is not supported for non-FP8 input " + name + ".");
+    NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input " + name + ".");
+    NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input " + name + ".");
     NVTE_CHECK(t.scale_inv.dptr == nullptr,
                "Scale_inv is not supported for non-FP8 input " + name + ".");
   }
-  NVTE_CHECK(t.data.dptr != nullptr,
-             "Input " + name + " is not allocated!");
+  NVTE_CHECK(t.data.dptr != nullptr, "Input " + name + " is not allocated!");
 }
 
 void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
   const DType type = t.data.dtype;
   if (is_fp8_dtype(type)) {
     // FP8 output needs to have scale, amax and scale_inv
-    NVTE_CHECK(t.amax.dptr != nullptr,
-               "FP8 output " + name + " must have amax tensor.");
+    NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output " + name + " must have amax tensor.");
     NVTE_CHECK(t.amax.dtype == DType::kFloat32);
-    NVTE_CHECK(t.amax.shape == std::vector<size_t>{ 1 });
-    NVTE_CHECK(t.scale_inv.dptr != nullptr,
-               "FP8 output " + name + " must have scale.");
+    NVTE_CHECK(t.amax.shape == std::vector<size_t>{1});
+    NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 output " + name + " must have scale.");
     NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
-    NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{ 1 });
-    NVTE_CHECK(t.scale.dptr != nullptr,
-               "FP8 output " + name + " must have inverse of scale.");
+    NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{1});
+    NVTE_CHECK(t.scale.dptr != nullptr, "FP8 output " + name + " must have inverse of scale.");
     NVTE_CHECK(t.scale.dtype == DType::kFloat32);
-    NVTE_CHECK(t.scale.shape == std::vector<size_t>{ 1 });
+    NVTE_CHECK(t.scale.shape == std::vector<size_t>{1});
   } else {
-    NVTE_CHECK(t.scale.dptr == nullptr,
-               "Scale is not supported for non-FP8 output " + name + ".");
-    NVTE_CHECK(t.amax.dptr == nullptr,
-               "Amax is not supported for non-FP8 output " + name + ".");
+    NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output " + name + ".");
+    NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output " + name + ".");
    NVTE_CHECK(t.scale_inv.dptr == nullptr,
                "Scale_inv is not supported for non-FP8 output " + name + ".");
   }
   if (!allow_empty) {
-    NVTE_CHECK(t.data.dptr != nullptr,
-               "Output " + name + " is not allocated!");
+    NVTE_CHECK(t.data.dptr != nullptr, "Output " + name + " is not allocated!");
   }
 }
 
 }  // namespace transformer_engine
 
-NVTETensor nvte_create_tensor(void *dptr,
-                              const NVTEShape shape,
-                              const NVTEDType dtype,
-                              float *amax,
-                              float *scale,
-                              float *scale_inv) {
+NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType dtype, float *amax,
+                              float *scale, float *scale_inv) {
   transformer_engine::Tensor *ret = new transformer_engine::Tensor;
   ret->data.dptr = dptr;
   ret->data.shape = std::vector<size_t>(shape.data, shape.data + shape.ndim);
@@ -97,11 +82,11 @@ void nvte_destroy_tensor(NVTETensor tensor) {
 
 NVTEDType nvte_tensor_type(const NVTETensor tensor) {
   return static_cast<NVTEDType>(
-      reinterpret_cast<const transformer_engine::Tensor*>(tensor)->data.dtype);
+      reinterpret_cast<const transformer_engine::Tensor *>(tensor)->data.dtype);
 }
 
 NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
-  const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor);
+  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
   NVTEShape ret;
   ret.data = t.data.shape.data();
   ret.ndim = t.data.shape.size();
@@ -109,40 +94,40 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
 }
 
 void *nvte_tensor_data(const NVTETensor tensor) {
-  const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor);
+  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
   return t.data.dptr;
 }
 
 float *nvte_tensor_amax(const NVTETensor tensor) {
-  const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor);
+  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
   NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32,
              "Tensor's amax must have Float32 type!");
-  return reinterpret_cast<float*>(t.amax.dptr);
+  return reinterpret_cast<float *>(t.amax.dptr);
 }
 
 float *nvte_tensor_scale(const NVTETensor tensor) {
-  const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor);
+  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
   NVTE_CHECK(t.scale.dtype == transformer_engine::DType::kFloat32,
              "Tensor's scale must have Float32 type!");
-  return reinterpret_cast<float*>(t.scale.dptr);
+  return reinterpret_cast<float *>(t.scale.dptr);
 }
 
 float *nvte_tensor_scale_inv(const NVTETensor tensor) {
-  const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor);
+  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
   NVTE_CHECK(t.scale_inv.dtype == transformer_engine::DType::kFloat32,
              "Tensor's inverse of scale must have Float32 type!");
-  return reinterpret_cast<float*>(t.scale_inv.dptr);
+  return reinterpret_cast<float *>(t.scale_inv.dptr);
 }
 
-void nvte_tensor_pack_create(NVTETensorPack* pack) {
+void nvte_tensor_pack_create(NVTETensorPack *pack) {
   for (int i = 0; i < pack->MAX_SIZE; i++) {
     pack->tensors[i] = reinterpret_cast<NVTETensor>(new transformer_engine::Tensor);
   }
 }
 
-void nvte_tensor_pack_destroy(NVTETensorPack* pack) {
+void nvte_tensor_pack_destroy(NVTETensorPack *pack) {
   for (int i = 0; i < pack->MAX_SIZE; i++) {
-    auto *t = reinterpret_cast<transformer_engine::Tensor*>(pack->tensors[i]);
+    auto *t = reinterpret_cast<transformer_engine::Tensor *>(pack->tensors[i]);
     delete t;
   }
 }
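The checks above spell out the FP8 tensor contract: an FP8 output must carry amax, scale, and scale_inv, each an FP32 scalar of shape {1}, while non-FP8 tensors must leave those fields null. A hedged usage sketch of the C API reformatted above; the device allocations and the dtype enumerator kNVTEFloat8E4M3 are assumptions, not shown in this diff:

    // Sketch only: out_dev, amax_dev, scale_dev, scale_inv_dev are assumed to be
    // pre-allocated device pointers (an FP8 buffer plus three FP32 scalars).
    size_t dims[2] = {128, 256};
    NVTEShape shape;
    shape.data = dims;
    shape.ndim = 2;
    NVTETensor out = nvte_create_tensor(out_dev, shape, kNVTEFloat8E4M3,
                                        amax_dev, scale_dev, scale_inv_dev);
    // ... launch kernels that write the FP8 result and update amax ...
    nvte_destroy_tensor(out);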
@@ -4,13 +4,12 @@
  * See LICENSE for license information.
  ************************************************************************/
 
+#include <cuda_runtime.h>
 #include <transformer_engine/cast_transpose_noop.h>
 #include <transformer_engine/transpose.h>
 
 #include <algorithm>
-#include <cuda_runtime.h>
 
 #include "../common.h"
 #include "../util/rtc.h"
 #include "../util/string.h"
@@ -49,26 +48,18 @@ struct KernelConfig {
   /* Elements per L1 cache store to transposed output */
   size_t elements_per_store_t = 0;
 
-  KernelConfig(size_t row_length,
-               size_t num_rows,
-               size_t itype_size,
-               size_t otype_size,
-               size_t load_size_,
-               size_t store_size_)
-    : load_size{load_size_}
-    , store_size{store_size_} {
+  KernelConfig(size_t row_length, size_t num_rows, size_t itype_size, size_t otype_size,
+               size_t load_size_, size_t store_size_)
+      : load_size{load_size_}, store_size{store_size_} {
     // Check that tiles are correctly aligned
     constexpr size_t cache_line_size = 128;
-    if (load_size % itype_size != 0
-        || store_size % otype_size != 0
-        || cache_line_size % itype_size != 0
-        || cache_line_size % otype_size != 0) {
+    if (load_size % itype_size != 0 || store_size % otype_size != 0 ||
+        cache_line_size % itype_size != 0 || cache_line_size % otype_size != 0) {
      return;
     }
     const size_t row_tile_elements = load_size * THREADS_PER_WARP / itype_size;
     const size_t col_tile_elements = store_size * THREADS_PER_WARP / otype_size;
-    valid = (row_length % row_tile_elements == 0
-             && num_rows % col_tile_elements == 0);
+    valid = (row_length % row_tile_elements == 0 && num_rows % col_tile_elements == 0);
     if (!valid) {
       return;
     }
@@ -80,12 +71,9 @@ struct KernelConfig {
     constexpr size_t warps_per_sm = 16;  // Rough estimate for saturated SMs
     active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm),
                                static_cast<size_t>(cuda::sm_count()));
-    elements_per_load = (std::min(cache_line_size, row_tile_elements * itype_size)
-                         / itype_size);
-    elements_per_store_c = (std::min(cache_line_size, row_tile_elements * otype_size)
-                            / otype_size);
-    elements_per_store_t = (std::min(cache_line_size, col_tile_elements * otype_size)
-                            / otype_size);
+    elements_per_load = (std::min(cache_line_size, row_tile_elements * itype_size) / itype_size);
+    elements_per_store_c = (std::min(cache_line_size, row_tile_elements * otype_size) / otype_size);
+    elements_per_store_t = (std::min(cache_line_size, col_tile_elements * otype_size) / otype_size);
   }
 
   /* Compare by estimated cost */
@@ -104,8 +92,8 @@ struct KernelConfig {
       const auto &st2 = other.elements_per_store_t;
       const auto &p2 = other.active_sm_count;
       const auto scale = l1 * sc1 * st1 * p1 * l2 * sc2 * st2 * p2;
-      const auto cost1 = (scale/l1 + scale/sc1 + scale/st1) / p1;
-      const auto cost2 = (scale/l2 + scale/sc2 + scale/st2) / p2;
+      const auto cost1 = (scale / l1 + scale / sc1 + scale / st1) / p1;
+      const auto cost2 = (scale / l2 + scale / sc2 + scale / st2) / p2;
       return cost1 < cost2;
     } else {
       return this->valid && !other.valid;
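The operator< re-wrapped above ranks candidate kernel configs by an estimated cost while staying in integer arithmetic: both sides are multiplied by the common product `scale`, so each `scale / x` term plays the role of `1 / x`. A stripped-down sketch of the same idea, with shortened field names (illustrative only, not from the patch):

    // Reduced form of the cost comparison above.
    struct Cfg {
      size_t load, store_c, store_t, sms;  // elements per load / store_c / store_t, active SMs
    };
    bool cheaper(const Cfg &a, const Cfg &b) {
      // Compare (1/load + 1/store_c + 1/store_t) / sms without floating point:
      // scale both costs by the common product of all eight quantities.
      const size_t scale = a.load * a.store_c * a.store_t * a.sms *
                           b.load * b.store_c * b.store_t * b.sms;
      const size_t cost_a = (scale / a.load + scale / a.store_c + scale / a.store_t) / a.sms;
      const size_t cost_b = (scale / b.load + scale / b.store_c + scale / b.store_t) / b.sms;
      return cost_a < cost_b;
    }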
@@ -114,15 +102,13 @@ struct KernelConfig {
 };
 
 template <size_t load_size, size_t store_size, typename IType, typename OType>
-__global__ void
-__launch_bounds__(block_size)
-cast_transpose_general_kernel(const IType * __restrict__ const input,
-                              const CType * __restrict__ const noop,
-                              OType * __restrict__ const output_c,
-                              OType * __restrict__ const output_t,
-                              const CType * __restrict__ const scale_ptr,
-                              CType * __restrict__ const amax_ptr,
-                              const size_t row_length,
-                              const size_t num_rows) {
+__global__ void __launch_bounds__(block_size)
+    cast_transpose_general_kernel(const IType *__restrict__ const input,
+                                  const CType *__restrict__ const noop,
+                                  OType *__restrict__ const output_c,
+                                  OType *__restrict__ const output_t,
+                                  const CType *__restrict__ const scale_ptr,
+                                  CType *__restrict__ const amax_ptr, const size_t row_length,
+                                  const size_t num_rows) {
   if (noop != nullptr && noop[0] == 1.0f) return;
@@ -165,16 +151,16 @@ cast_transpose_general_kernel(const IType * __restrict__ const input,
   // Note: Each thread loads num_iterations subtiles, computes amax,
   // casts type, and transposes in registers.
   OVecT local_output_t[nvec_in][num_iterations];
 #pragma unroll
   for (size_t iter = 0; iter < num_iterations; ++iter) {
     const size_t i1 = tidy + iter * bdimy;
     const size_t j1 = tidx;
 #pragma unroll
     for (size_t i2 = 0; i2 < nvec_out; ++i2) {
       const size_t row = tile_row + i1 * nvec_out + i2;
       const size_t col = tile_col + j1 * nvec_in;
       if (row < num_rows) {
 #pragma unroll
         for (size_t j2 = 0; j2 < nvec_in; ++j2) {
           if (col + j2 < row_length) {
             const CType in = input[row * row_length + col + j2];
@@ -190,24 +176,24 @@ cast_transpose_general_kernel(const IType * __restrict__ const input,
   }
 
   // Copy transposed output from registers to global memory
-  __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP+1];
+  __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1];
 #pragma unroll
   for (size_t j2 = 0; j2 < nvec_in; ++j2) {
 #pragma unroll
     for (size_t iter = 0; iter < num_iterations; ++iter) {
       const size_t i1 = tidy + iter * bdimy;
       const size_t j1 = tidx;
       shared_output_t[j1][i1] = local_output_t[j2][iter];
     }
     __syncthreads();
 #pragma unroll
     for (size_t iter = 0; iter < num_iterations; ++iter) {
       const size_t i1 = tidx;
       const size_t j1 = tidy + iter * bdimy;
       const size_t row = tile_row + i1 * nvec_out;
       const size_t col = tile_col + j1 * nvec_in + j2;
       if (col < row_length) {
 #pragma unroll
         for (size_t i2 = 0; i2 < nvec_out; ++i2) {
           if (row + i2 < num_rows) {
             output_t[col * num_rows + row + i2] = shared_output_t[j1][i1].data.elt[i2];
@@ -229,18 +215,15 @@ cast_transpose_general_kernel(const IType * __restrict__ const input,
 }  // namespace
 
-void cast_transpose(const Tensor &input,
-                    const Tensor &noop,
-                    Tensor *cast_output_,
-                    Tensor *transposed_output_,
-                    cudaStream_t stream) {
+void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output_,
+                    Tensor *transposed_output_, cudaStream_t stream) {
   Tensor &cast_output = *cast_output_;
   Tensor &transposed_output = *transposed_output_;
 
   // Check no-op flag
   if (noop.data.dptr != nullptr) {
     size_t numel = 1;
-    for (const auto& dim : noop.data.shape) {
+    for (const auto &dim : noop.data.shape) {
       numel *= dim;
     }
     NVTE_CHECK(numel == 1, "Expected 1 element, but found ", numel, ".");
@@ -254,16 +237,14 @@ void cast_transpose(const Tensor &input,
   CheckOutputTensor(transposed_output, "transposed_output");
 
   NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
   NVTE_CHECK(cast_output.data.shape.size() == 2, "Cast output must have 2 dimensions.");
-  NVTE_CHECK(transposed_output.data.shape.size() == 2,
-             "Transposed output must have 2 dimensions.");
+  NVTE_CHECK(transposed_output.data.shape.size() == 2, "Transposed output must have 2 dimensions.");
   const size_t row_length = input.data.shape[1];
   const size_t num_rows = input.data.shape[0];
 
   NVTE_CHECK(cast_output.data.shape[0] == num_rows, "Wrong dimension of cast output.");
   NVTE_CHECK(cast_output.data.shape[1] == row_length, "Wrong dimension of cast output.");
   NVTE_CHECK(transposed_output.data.shape[0] == row_length,
              "Wrong dimension of transposed output.");
-  NVTE_CHECK(transposed_output.data.shape[1] == num_rows,
-             "Wrong dimension of transposed output.");
+  NVTE_CHECK(transposed_output.data.shape[1] == num_rows, "Wrong dimension of transposed output.");
 
   // Check tensor pointers
   NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated.");
@@ -276,48 +257,55 @@ void cast_transpose(const Tensor &input,
   NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr,
              "Cast and transposed outputs need to share scale tensor.");
 
-  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
-    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output.data.dtype, OutputType,
+  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
+      input.data.dtype, InputType,
+      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
+          cast_output.data.dtype, OutputType,
 
       constexpr const char *itype_name = TypeInfo<InputType>::name;
       constexpr const char *otype_name = TypeInfo<OutputType>::name;
       constexpr size_t itype_size = sizeof(InputType);
       constexpr size_t otype_size = sizeof(OutputType);
 
       // Choose between runtime-compiled or statically-compiled kernel
-      const bool aligned = (row_length % THREADS_PER_WARP == 0
-                            && num_rows % THREADS_PER_WARP == 0);
+      const bool aligned =
+          (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0);
      if (aligned && rtc::is_enabled()) {  // Runtime-compiled tuned kernel
        // Pick kernel config
        std::vector<KernelConfig> kernel_configs;
        kernel_configs.reserve(16);
        auto add_config = [&](size_t load_size, size_t store_size) {
-         kernel_configs.emplace_back(row_length, num_rows,
-                                     itype_size, otype_size,
-                                     load_size, store_size);
+         kernel_configs.emplace_back(row_length, num_rows, itype_size, otype_size, load_size,
                                      store_size);
        };
        add_config(8, 8);
-       add_config(4, 8); add_config(8, 4);
+       add_config(4, 8);
+       add_config(8, 4);
        add_config(4, 4);
-       add_config(2, 8); add_config(8, 2);
-       add_config(2, 4); add_config(4, 2);
+       add_config(2, 8);
+       add_config(8, 2);
+       add_config(2, 4);
+       add_config(4, 2);
        add_config(2, 2);
-       add_config(1, 8); add_config(8, 1);
-       add_config(1, 4); add_config(4, 1);
-       add_config(1, 2); add_config(2, 1);
+       add_config(1, 8);
+       add_config(8, 1);
+       add_config(1, 4);
+       add_config(4, 1);
+       add_config(1, 2);
+       add_config(2, 1);
        add_config(1, 1);
-       const auto &kernel_config = *std::min_element(kernel_configs.begin(),
-                                                     kernel_configs.end());
+       const auto &kernel_config =
+           *std::min_element(kernel_configs.begin(), kernel_configs.end());
        NVTE_CHECK(kernel_config.valid, "invalid kernel config");
        const size_t load_size = kernel_config.load_size;
        const size_t store_size = kernel_config.store_size;
        const size_t num_blocks = kernel_config.num_blocks;
 
        // Compile NVRTC kernel if needed and launch
-       auto& rtc_manager = rtc::KernelManager::instance();
-       const std::string kernel_label = concat_strings("cast_transpose"
-                                                       ",itype=", itype_name,
-                                                       ",otype=", otype_name,
-                                                       ",load_size=", load_size,
-                                                       ",store_size=", store_size);
+       auto &rtc_manager = rtc::KernelManager::instance();
+       const std::string kernel_label = concat_strings(
+           "cast_transpose"
+           ",itype=",
+           itype_name, ",otype=", otype_name, ",load_size=", load_size,
+           ",store_size=", store_size);
        if (!rtc_manager.is_compiled(kernel_label)) {
          std::string code = string_code_transpose_rtc_cast_transpose_cu;
@@ -327,27 +315,23 @@ void cast_transpose(const Tensor &input,
          code = regex_replace(code, "__STORE_SIZE__", store_size);
          code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile);
          code = regex_replace(code, "__BLOCK_SIZE__", block_size);
-         rtc_manager.compile(kernel_label,
-                             "cast_transpose_optimized_kernel",
-                             code,
+         rtc_manager.compile(kernel_label, "cast_transpose_optimized_kernel", code,
                              "transformer_engine/common/transpose/rtc/cast_transpose.cu");
        }
-       rtc_manager.launch(kernel_label,
-                          num_blocks, block_size, 0, stream,
+       rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream,
                           static_cast<const InputType *>(input.data.dptr),
                           reinterpret_cast<const CType *>(noop.data.dptr),
-                          static_cast<OutputType*>(cast_output.data.dptr),
-                          static_cast<OutputType*>(transposed_output.data.dptr),
-                          static_cast<const CType*>(cast_output.scale.dptr),
-                          static_cast<CType*>(cast_output.amax.dptr),
-                          row_length, num_rows);
+                          static_cast<OutputType *>(cast_output.data.dptr),
+                          static_cast<OutputType *>(transposed_output.data.dptr),
+                          static_cast<const CType *>(cast_output.scale.dptr),
+                          static_cast<CType *>(cast_output.amax.dptr), row_length, num_rows);
      } else {  // Statically-compiled general kernel
        constexpr size_t load_size = 4;
        constexpr size_t store_size = 4;
        constexpr size_t row_tile_size = load_size / itype_size * THREADS_PER_WARP;
        constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP;
-       const int num_blocks = (DIVUP(row_length, row_tile_size)
-                               * DIVUP(num_rows, col_tile_size));
+       const int num_blocks =
+           (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size));
        cast_transpose_general_kernel<load_size, store_size, InputType, OutputType>
            <<<num_blocks, block_size, 0, stream>>>(
                static_cast<const InputType *>(input.data.dptr),
@@ -355,39 +339,29 @@ void cast_transpose(const Tensor &input,
                static_cast<OutputType *>(cast_output.data.dptr),
                static_cast<OutputType *>(transposed_output.data.dptr),
                static_cast<const CType *>(cast_output.scale.dptr),
-               static_cast<CType *>(cast_output.amax.dptr),
-               row_length, num_rows);
-      }
-  );  // NOLINT(*)
+               static_cast<CType *>(cast_output.amax.dptr), row_length, num_rows);
+      });  // NOLINT(*)
   );  // NOLINT(*)
 }
 
 }  // namespace transformer_engine
 
-void nvte_cast_transpose(const NVTETensor input,
-                         NVTETensor cast_output,
-                         NVTETensor transposed_output,
-                         cudaStream_t stream) {
+void nvte_cast_transpose(const NVTETensor input, NVTETensor cast_output,
+                         NVTETensor transposed_output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_cast_transpose);
   using namespace transformer_engine;
   auto noop = Tensor();
-  cast_transpose(*reinterpret_cast<const Tensor*>(input),
-                 noop,
-                 reinterpret_cast<Tensor*>(cast_output),
-                 reinterpret_cast<Tensor*>(transposed_output),
-                 stream);
+  cast_transpose(*reinterpret_cast<const Tensor *>(input), noop,
+                 reinterpret_cast<Tensor *>(cast_output),
+                 reinterpret_cast<Tensor *>(transposed_output), stream);
 }
 
-void nvte_cast_transpose_with_noop(const NVTETensor input,
-                                   const NVTETensor noop,
-                                   NVTETensor cast_output,
-                                   NVTETensor transposed_output,
+void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop,
+                                   NVTETensor cast_output, NVTETensor transposed_output,
                                    cudaStream_t stream) {
   NVTE_API_CALL(nvte_cast_transpose_with_noop);
   using namespace transformer_engine;
-  cast_transpose(*reinterpret_cast<const Tensor*>(input),
-                 *reinterpret_cast<const Tensor*>(noop),
-                 reinterpret_cast<Tensor*>(cast_output),
-                 reinterpret_cast<Tensor*>(transposed_output),
-                 stream);
+  cast_transpose(*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(noop),
                 reinterpret_cast<Tensor *>(cast_output),
                 reinterpret_cast<Tensor *>(transposed_output), stream);
 }
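Both C entry points above funnel into the same cast_transpose implementation; the _with_noop variant just forwards a single-element flag tensor that the kernel checks (it returns early when noop[0] == 1.0f). A hedged call sketch, with tensor setup elided:

    // Sketch: input, noop, cast_out, transposed_out are assumed to be NVTETensor handles
    // created elsewhere (the FP8 outputs sharing amax/scale, per the checks above).
    cudaStream_t stream = 0;
    nvte_cast_transpose(input, cast_out, transposed_out, stream);
    // Same operation, but skipped on-device when the 1-element noop tensor holds 1.0f:
    nvte_cast_transpose_with_noop(input, noop, cast_out, transposed_out, stream);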
...@@ -4,16 +4,18 @@ ...@@ -4,16 +4,18 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <transformer_engine/transpose.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <transformer_engine/transpose.h>
#include <cfloat> #include <cfloat>
#include <iostream> #include <iostream>
#include <type_traits> #include <type_traits>
#include "../utils.cuh"
#include "../util/rtc.h"
#include "../util/string.h"
#include "../common.h" #include "../common.h"
#include "../util/math.h" #include "../util/math.h"
#include "../util/rtc.h"
#include "../util/string.h"
#include "../utils.cuh"
namespace transformer_engine { namespace transformer_engine {
...@@ -51,17 +53,9 @@ struct KernelConfig { ...@@ -51,17 +53,9 @@ struct KernelConfig {
size_t elements_per_store_c = 0; // Elements per L1 cache store to cast output size_t elements_per_store_c = 0; // Elements per L1 cache store to cast output
size_t elements_per_store_t = 0; // Elements per L1 cache store to transposed output size_t elements_per_store_t = 0; // Elements per L1 cache store to transposed output
KernelConfig(size_t row_length, KernelConfig(size_t row_length, size_t num_rows, size_t itype_size, size_t itype2_size,
size_t num_rows, size_t otype_size, size_t load_size_, size_t store_size_, bool is_dact_)
size_t itype_size, : load_size{load_size_}, store_size{store_size_}, is_dact{is_dact_} {
size_t itype2_size,
size_t otype_size,
size_t load_size_,
size_t store_size_,
bool is_dact_)
: load_size{load_size_}
, store_size{store_size_}
, is_dact{is_dact_} {
if (is_dact) { if (is_dact) {
if (load_size > desired_load_size_dact || store_size > desired_store_size_dact) { if (load_size > desired_load_size_dact || store_size > desired_store_size_dact) {
return; return;
...@@ -70,10 +64,8 @@ struct KernelConfig { ...@@ -70,10 +64,8 @@ struct KernelConfig {
// Check that tiles are correctly aligned // Check that tiles are correctly aligned
constexpr size_t cache_line_size = 128; constexpr size_t cache_line_size = 128;
if (load_size % itype_size != 0 if (load_size % itype_size != 0 || store_size % otype_size != 0 ||
|| store_size % otype_size != 0 cache_line_size % itype_size != 0 || cache_line_size % otype_size != 0) {
|| cache_line_size % itype_size != 0
|| cache_line_size % otype_size != 0) {
return; return;
} }
/* row_tile_elements */ /* row_tile_elements */
...@@ -95,14 +87,10 @@ struct KernelConfig { ...@@ -95,14 +87,10 @@ struct KernelConfig {
constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs
active_sm_count = std::min(DIVUP(num_blocks * n_warps_per_tile, warps_per_sm), active_sm_count = std::min(DIVUP(num_blocks * n_warps_per_tile, warps_per_sm),
static_cast<size_t>(cuda::sm_count())); static_cast<size_t>(cuda::sm_count()));
elements_per_load = (std::min(cache_line_size, tile_size_x * itype_size) elements_per_load = (std::min(cache_line_size, tile_size_x * itype_size) / itype_size);
/ itype_size); elements_per_load_dact = (std::min(cache_line_size, tile_size_x * itype2_size) / itype2_size);
elements_per_load_dact = (std::min(cache_line_size, tile_size_x * itype2_size) elements_per_store_c = (std::min(cache_line_size, tile_size_x * otype_size) / otype_size);
/ itype2_size); elements_per_store_t = (std::min(cache_line_size, tile_size_y * otype_size) / otype_size);
elements_per_store_c = (std::min(cache_line_size, tile_size_x * otype_size)
/ otype_size);
elements_per_store_t = (std::min(cache_line_size, tile_size_y * otype_size)
/ otype_size);
} }
/* Compare by estimated cost */ /* Compare by estimated cost */
...@@ -126,10 +114,10 @@ struct KernelConfig { ...@@ -126,10 +114,10 @@ struct KernelConfig {
const auto scale1 = l1 * sc1 * st1 * p1 * (is_dact ? la1 : 1); const auto scale1 = l1 * sc1 * st1 * p1 * (is_dact ? la1 : 1);
const auto scale2 = l2 * sc2 * st2 * p2 * (is_dact ? la2 : 1); const auto scale2 = l2 * sc2 * st2 * p2 * (is_dact ? la2 : 1);
const auto scale = scale1 * scale2; const auto scale = scale1 * scale2;
const auto cost1 = (scale/l1 + scale/sc1 + scale/st1 + (is_dact ? (scale / la1) : 0)) const auto cost1 =
/ p1; (scale / l1 + scale / sc1 + scale / st1 + (is_dact ? (scale / la1) : 0)) / p1;
const auto cost2 = (scale/l2 + scale/sc2 + scale/st2 + (is_dact ? (scale / la2) : 0)) const auto cost2 =
/ p2; (scale / l2 + scale / sc2 + scale / st2 + (is_dact ? (scale / la2) : 0)) / p2;
return cost1 < cost2; return cost1 < cost2;
} else { } else {
...@@ -138,15 +126,13 @@ struct KernelConfig { ...@@ -138,15 +126,13 @@ struct KernelConfig {
} }
}; };
template <bool IS_DBIAS, bool IS_FULL_TILE, int nvec_in, int nvec_out, typename OVec, typename CVec,
template <bool IS_DBIAS, bool IS_FULL_TILE, int nvec_in, int nvec_out, typename CType>
typename OVec, typename CVec, typename CType>
inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out], inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out],
OVec (&out_trans)[nvec_in], OVec (&out_trans)[nvec_in],
CVec &out_dbias, // NOLINT(*) CVec &out_dbias, // NOLINT(*)
typename OVec::type *output_cast_tile, typename OVec::type *output_cast_tile,
const size_t current_place, const size_t current_place, const size_t stride,
const size_t stride,
const CType scale, const CType scale,
CType &amax, // NOLINT(*) CType &amax, // NOLINT(*)
const int dbias_shfl_src_lane, const int dbias_shfl_src_lane,
...@@ -159,10 +145,10 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out], ...@@ -159,10 +145,10 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out],
step_dbias.clear(); step_dbias.clear();
} }
#pragma unroll #pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) { for (unsigned int i = 0; i < nvec_out; ++i) {
OVecC out_cast; OVecC out_cast;
#pragma unroll #pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) { for (unsigned int j = 0; j < nvec_in; ++j) {
const CType tmp = in[i].data.elt[j]; const CType tmp = in[i].data.elt[j];
if constexpr (IS_DBIAS) { if constexpr (IS_DBIAS) {
...@@ -180,7 +166,7 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out], ...@@ -180,7 +166,7 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out],
} }
if constexpr (IS_DBIAS) { if constexpr (IS_DBIAS) {
#pragma unroll #pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) { for (unsigned int j = 0; j < nvec_in; ++j) {
CType elt = step_dbias.data.elt[j]; CType elt = step_dbias.data.elt[j];
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp
...@@ -190,8 +176,7 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out], ...@@ -190,8 +176,7 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out],
} }
void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /*cast*/ void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /*cast*/
Tensor* workspace, Tensor *workspace, const int nvec_out) {
const int nvec_out) {
const size_t row_length = cast_output.data.shape[1]; const size_t row_length = cast_output.data.shape[1];
const size_t num_rows = cast_output.data.shape[0]; const size_t num_rows = cast_output.data.shape[0];
...@@ -204,13 +189,10 @@ void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, / ...@@ -204,13 +189,10 @@ void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /
workspace->data.dtype = DType::kFloat32; workspace->data.dtype = DType::kFloat32;
} }
template<int nvec, typename ComputeType, typename OutputType> template <int nvec, typename ComputeType, typename OutputType>
__global__ void __global__ void __launch_bounds__(reduce_dbias_num_threads)
__launch_bounds__(reduce_dbias_num_threads) reduce_dbias_kernel(OutputType *const dbias_output, const ComputeType *const dbias_partial,
reduce_dbias_kernel(OutputType* const dbias_output, const int row_length, const int num_rows) {
const ComputeType* const dbias_partial,
const int row_length,
const int num_rows) {
using ComputeVec = Vec<ComputeType, nvec>; using ComputeVec = Vec<ComputeType, nvec>;
using OutputVec = Vec<OutputType, nvec>; using OutputVec = Vec<OutputType, nvec>;
...@@ -220,23 +202,24 @@ reduce_dbias_kernel(OutputType* const dbias_output, ...@@ -220,23 +202,24 @@ reduce_dbias_kernel(OutputType* const dbias_output,
return; return;
} }
const ComputeType* const thread_in_base = dbias_partial + thread_id * nvec; const ComputeType *const thread_in_base = dbias_partial + thread_id * nvec;
OutputType* const thread_out_base = dbias_output + thread_id * nvec; OutputType *const thread_out_base = dbias_output + thread_id * nvec;
const int stride_in_vec = row_length / nvec; const int stride_in_vec = row_length / nvec;
ComputeVec ldg_vec; ComputeVec ldg_vec;
ComputeVec acc_vec; acc_vec.clear(); ComputeVec acc_vec;
acc_vec.clear();
for (int i = 0; i < num_rows; ++i) { for (int i = 0; i < num_rows; ++i) {
ldg_vec.load_from(thread_in_base, i * stride_in_vec); ldg_vec.load_from(thread_in_base, i * stride_in_vec);
#pragma unroll #pragma unroll
for (int e = 0; e < nvec; ++e) { for (int e = 0; e < nvec; ++e) {
acc_vec.data.elt[e] += ldg_vec.data.elt[e]; acc_vec.data.elt[e] += ldg_vec.data.elt[e];
} }
} }
OutputVec stg_vec; OutputVec stg_vec;
#pragma unroll #pragma unroll
for (int e = 0; e < nvec; ++e) { for (int e = 0; e < nvec; ++e) {
stg_vec.data.elt[e] = OutputType(acc_vec.data.elt[e]); stg_vec.data.elt[e] = OutputType(acc_vec.data.elt[e]);
} }
...@@ -244,42 +227,32 @@ reduce_dbias_kernel(OutputType* const dbias_output, ...@@ -244,42 +227,32 @@ reduce_dbias_kernel(OutputType* const dbias_output,
} }
template <typename InputType>
void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_length,
                  const size_t num_rows, const int nvec_out, cudaStream_t stream) {
  constexpr int reduce_dbias_store_bytes = 8;  // stg.64
  constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(InputType);
  NVTE_CHECK(row_length % reduce_dbias_nvec == 0, "Unsupported shape.");
  const size_t reduce_dbias_row_length = row_length;
  const size_t reduce_dbias_num_rows =
      DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
  const size_t reduce_dbias_num_blocks =
      DIVUP(row_length, reduce_dbias_num_threads * reduce_dbias_nvec);
  using DbiasOutputType = fp32;
  reduce_dbias_kernel<reduce_dbias_nvec, DbiasOutputType, InputType>
      <<<reduce_dbias_num_blocks, reduce_dbias_num_threads, 0, stream>>>(
          reinterpret_cast<InputType *>(dbias->data.dptr),
          reinterpret_cast<const fp32 *>(workspace.data.dptr), reduce_dbias_row_length,
          reduce_dbias_num_rows);
}
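
// Added commentary (not part of the original source): a worked example of the launch geometry
// above, assuming InputType = bf16, nvec_out = 4, THREADS_PER_WARP = 32 and
// reduce_dbias_num_threads = 256, for a 16384 x 4096 (num_rows x row_length) gradient:
//
//   reduce_dbias_nvec       = 8 bytes / sizeof(bf16)  = 4 elements per stg.64
//   reduce_dbias_num_rows   = DIVUP(16384, 4 * 32)    = 128 partial rows in the workspace
//   reduce_dbias_num_blocks = DIVUP(4096, 256 * 4)    = 4 blocks across the bias row
//
// so each thread owns one 4-element slice of the bias row and sums it over the 128 partial
// rows produced by the cast-transpose kernel before the final narrowing store.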
template <bool IS_DBIAS, bool IS_DACT, typename ComputeType, typename Param, int nvec_in,
          int nvec_out, typename ParamOP, ComputeType (*OP)(ComputeType, const ParamOP &)>
__global__ void __launch_bounds__(cast_transpose_num_threads)
    cast_transpose_fused_kernel_notaligned(const Param param, const size_t row_length,
                                           const size_t num_rows, const size_t num_tiles) {
  using IType = typename Param::InputType;
  using IType2 = typename Param::InputType2;
  using OType = typename Param::OutputType;
@@ -293,10 +266,10 @@ cast_transpose_fused_kernel_notaligned(const Param param,
  const int warp_id = threadIdx.x / THREADS_PER_WARP;
  const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
  const size_t num_tiles_x =
      (row_length + nvec_in * THREADS_PER_WARP - 1) / (nvec_in * THREADS_PER_WARP);
  const size_t tile_id =
      blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile;
  if (tile_id >= num_tiles) {
    return;
  }
@@ -304,33 +277,32 @@ cast_transpose_fused_kernel_notaligned(const Param param,
  const size_t tile_id_x = tile_id % num_tiles_x;
  const size_t tile_id_y = tile_id / num_tiles_x;
  const size_t tile_offset =
      (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP;
  const size_t tile_offset_transp =
      (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP;
  const IType *const my_input_tile = param.input + tile_offset;
  const IType2 *const my_act_input_tile = param.act_input + tile_offset;
  OType *const my_output_c_tile = param.output_c + tile_offset;
  OType *const my_output_t_tile = param.output_t + tile_offset_transp;
  CType *const my_partial_dbias_tile =
      param.workspace + (tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length);
  const size_t stride = row_length / nvec_in;
  const size_t output_stride = num_rows / nvec_out;
  const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP;
  const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP;
  const unsigned int tile_length =
      row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_length_rest;
  const unsigned int tile_height =
      row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_height_rest;
  OVec *const my_scratch =
      reinterpret_cast<OVec *>(scratch) +
      (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);
  CVec *const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
  IVec in[2][nvec_out];
  IVec2 act_in[2][nvec_out];
@@ -340,8 +312,8 @@ cast_transpose_fused_kernel_notaligned(const Param param,
  size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
  size_t current_row = (tile_id_y * THREADS_PER_WARP + warp_id_in_tile * n_iterations) * nvec_out;
  unsigned int my_place =
      (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
  CType amax = 0;
  const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
@@ -351,9 +323,8 @@ cast_transpose_fused_kernel_notaligned(const Param param,
  }
  {
    const bool valid_load = my_place < tile_length && warp_id_in_tile * n_iterations < tile_height;
#pragma unroll
    for (unsigned int i = 0; i < nvec_out; ++i) {
      if (valid_load) {
        const size_t ld_offset = current_stride + my_place + stride * i;
@@ -370,15 +341,15 @@ cast_transpose_fused_kernel_notaligned(const Param param,
    }
  }
#pragma unroll
  for (unsigned int i = 0; i < n_iterations; ++i) {
    const size_t current_place = current_stride + my_place;
    const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
    const unsigned int current_in = (i + 1) % 2;
    if (i < n_iterations - 1) {
      const bool valid_load =
          my_place_in < tile_length && warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll
      for (unsigned int j = 0; j < nvec_out; ++j) {
        if (valid_load) {
          const size_t ld_offset = current_stride + my_place_in + stride * (nvec_out + j);
@@ -395,27 +366,27 @@ cast_transpose_fused_kernel_notaligned(const Param param,
      }
    }
    CVec after_dact[nvec_out];  // NOLINT(*)
#pragma unroll
    for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
      for (unsigned int k = 0; k < nvec_in; ++k) {
        if constexpr (IS_DACT) {
          after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
                                      OP(act_in[current_in ^ 1][j].data.elt[k], {});
        } else {
          after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]);
        }
      }
    }
    const int dbias_shfl_src_lane =
        (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
    constexpr bool IS_FULL_TILE = false;
    const bool valid_store =
        (my_place < tile_length) && (warp_id_in_tile * n_iterations + i < tile_height);
    cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>(after_dact, out_space[i], partial_dbias,
                                                    my_output_c_tile, current_place, stride, scale,
                                                    amax, dbias_shfl_src_lane, valid_store);
    my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
    current_stride += nvec_out * stride;
@@ -423,16 +394,15 @@ cast_transpose_fused_kernel_notaligned(const Param param,
  }
  for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
    for (unsigned int j = 0; j < n_iterations; ++j) {
      my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
                 THREADS_PER_WARP] = out_space[j][i];
    }
    __syncthreads();
    my_place =
        (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
    current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
    for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
      const bool valid_store = my_place < tile_height;
      if (valid_store) {
@@ -449,10 +419,10 @@ cast_transpose_fused_kernel_notaligned(const Param param,
    my_dbias_scratch[threadIdx.x] = partial_dbias;
    __syncthreads();
    if (warp_id_in_tile == 0) {
#pragma unroll
      for (unsigned int i = 1; i < n_warps_per_tile; ++i) {
        CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
#pragma unroll
        for (unsigned int j = 0; j < nvec_in; ++j) {
          partial_dbias.data.elt[j] += tmp.data.elt[j];
        }
@@ -474,7 +444,7 @@ cast_transpose_fused_kernel_notaligned(const Param param,
  }
}
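
// Note (added commentary, not part of the original source): the kernel above double-buffers its
// global loads -- in[2][nvec_out] and act_in[2][nvec_out] hold the data consumed in the current
// iteration and the prefetch for the next one, with current_in = (i + 1) % 2 naming the slot
// being filled and current_in ^ 1 the slot that is ready to use. A minimal sketch of the same
// pattern, with hypothetical load()/consume() helpers standing in for the vector accesses:
//
//   float buf[2];
//   buf[0] = load(0);                         // prime the pipeline before the loop
//   for (int i = 0; i < n; ++i) {
//     const int cur = (i + 1) % 2;            // slot to refill
//     if (i < n - 1) buf[cur] = load(i + 1);  // prefetch the next element
//     consume(buf[cur ^ 1]);                  // work on the element loaded last time
//   }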
static const char *ActTypeToString[] = {
    "NoAct",    // 0
    "Sigmoid",  // 1
    "GeLU",     // 2
@@ -484,8 +454,7 @@ static const char* ActTypeToString[] = {
    "SReLU"  // 6
};
template <typename ComputeType, typename ParamOP, ComputeType (*OP)(ComputeType, const ParamOP &)>
int get_dactivation_type() {
  if (OP == &sigmoid<ComputeType, ComputeType>) {
    return 1;
@@ -505,13 +474,9 @@ int get_dactivation_type() {
}
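
// Added commentary (not part of the original source): the index returned here selects the entry
// of ActTypeToString that is spliced into the NVRTC kernel label and the __DACTIVATION_TYPE__
// placeholder further down; an instantiation whose OP matches the sigmoid branch shown above
// yields 1 and therefore the tag "Sigmoid". The remaining comparisons (elided in this hunk)
// presumably map the other supported dactivations to indices 2-6 the same way, with 0 ("NoAct")
// as the fallback.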
template <bool IS_DBIAS, bool IS_DACT, typename ComputeType, typename ParamOP,
          ComputeType (*OP)(ComputeType, const ParamOP &)>
void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor *cast_output,
                          Tensor *transposed_output, Tensor *dbias, Tensor *workspace,
                          cudaStream_t stream) {
  CheckInputTensor(input, "cast_transpose_fused_input");
  CheckOutputTensor(*cast_output, "cast_output");
@@ -537,10 +502,8 @@ void cast_transpose_fused(const Tensor &input,
  if constexpr (IS_DBIAS) {
    CheckOutputTensor(*dbias, "dbias");
    NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
    NVTE_CHECK(dbias->data.shape == std::vector<size_t>{row_length}, "Wrong shape of DBias.");
  }
  if constexpr (IS_DACT) {
    CheckInputTensor(act_input, "act_input");
@@ -548,17 +511,18 @@ void cast_transpose_fused(const Tensor &input,
    NVTE_CHECK(input.data.shape == act_input.data.shape, "Shapes of both inputs must match.");
  }
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          cast_output->data.dtype, OutputType, using InputType2 = InputType;
          using Param = CTDBiasDActParam<InputType, InputType2, OutputType, ComputeType>;
          constexpr int itype_size = sizeof(InputType);
          constexpr int itype2_size = sizeof(InputType2);
          constexpr int otype_size = sizeof(OutputType);
          const bool aligned =
              (row_length % THREADS_PER_WARP == 0) && (num_rows % THREADS_PER_WARP == 0);
          const bool jit_compiled = aligned && rtc::is_enabled();
          size_t load_size = (IS_DACT ? desired_load_size_dact : desired_load_size);
@@ -570,25 +534,29 @@ void cast_transpose_fused(const Tensor &input,
          std::vector<KernelConfig> kernel_configs;
          kernel_configs.reserve(16);
          auto add_config = [&](size_t load_size_config, size_t store_size_config) {
            kernel_configs.emplace_back(row_length, num_rows, itype_size, itype2_size, otype_size,
                                        load_size_config, store_size_config, IS_DACT);
          };
          add_config(8, 8);
          add_config(4, 8);
          add_config(8, 4);
          add_config(4, 4);
          add_config(2, 8);
          add_config(8, 2);
          add_config(2, 4);
          add_config(4, 2);
          add_config(2, 2);
          add_config(1, 8);
          add_config(8, 1);
          add_config(1, 4);
          add_config(4, 1);
          add_config(1, 2);
          add_config(2, 1);
          add_config(1, 1);
          // Select the kernel configuration with the lowest cost
          const auto &kernel_config =
              *std::min_element(kernel_configs.begin(), kernel_configs.end());
          NVTE_CHECK(kernel_config.valid, "invalid kernel config");
          load_size = kernel_config.load_size;
          store_size = kernel_config.store_size;
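          // Added commentary (not part of the original source): the list above enumerates
          // candidate (load, store) vector widths from 8-byte down to 1-byte accesses and keeps
          // the candidate that compares smallest; the actual cost model lives in KernelConfig's
          // ordering, which is not shown in this diff. A hypothetical sketch of such an ordering,
          // where invalid configurations always lose and wider accesses win:
          //
          //   bool operator<(const KernelConfig &other) const {
          //     if (valid != other.valid) return valid;  // valid configs sort first
          //     return load_size + store_size > other.load_size + other.store_size;
          //   }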
@@ -608,32 +576,45 @@ void cast_transpose_fused(const Tensor &input,
          if (!jit_compiled) {
            num_blocks = DIVUP(num_tiles * n_warps_per_tile, n_warps_per_block);
          }
          if constexpr (IS_DBIAS) {
            if (workspace->data.dptr == nullptr) {
              populate_cast_transpose_dbias_workspace_config(*cast_output, workspace, nvec_out);
              return;
            }
          }
          size_t VecOutputTypeSize;
          switch (nvec_out) {
            case 1:
              VecOutputTypeSize = sizeof(Vec<OutputType, 1>);
              break;
            case 2:
              VecOutputTypeSize = sizeof(Vec<OutputType, 2>);
              break;
            case 4:
              VecOutputTypeSize = sizeof(Vec<OutputType, 4>);
              break;
            case 8:
              VecOutputTypeSize = sizeof(Vec<OutputType, 8>);
              break;
          }
          size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile *
                                         (threads_per_warp + 1) * VecOutputTypeSize;
          if constexpr (IS_DBIAS) {
            size_t VecComputeTypeSize;
            switch (nvec_in) {
              case 1:
                VecComputeTypeSize = sizeof(Vec<ComputeType, 1>);
                break;
              case 2:
                VecComputeTypeSize = sizeof(Vec<ComputeType, 2>);
                break;
              case 4:
                VecComputeTypeSize = sizeof(Vec<ComputeType, 4>);
                break;
              case 8:
                VecComputeTypeSize = sizeof(Vec<ComputeType, 8>);
                break;
            }
            const size_t shared_size_dbias = cast_transpose_num_threads * VecComputeTypeSize;
            if (shared_size_transpose < shared_size_dbias) {
@@ -650,8 +631,7 @@ void cast_transpose_fused(const Tensor &input,
          param.scale_inv = reinterpret_cast<ComputeType *>(cast_output->scale_inv.dptr);
          if constexpr (IS_DBIAS) {
            param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
          }
          if constexpr (IS_DACT) {
            param.act_input = reinterpret_cast<const InputType2 *>(act_input.data.dptr);
          }
@@ -667,17 +647,13 @@ void cast_transpose_fused(const Tensor &input,
          }
            // Compile NVRTC kernel if needed and launch
            auto &rtc_manager = rtc::KernelManager::instance();
            const std::string kernel_label = concat_strings(
                "cast_transpose_fusion"
                ",itype=",
                itype_name, ",itype2=", itype2_name, ",otype=", otype_name,
                ",load_size=", load_size, ",store_size=", store_size, ",IS_DBIAS=", IS_DBIAS,
                ",IS_DACT=", IS_DACT, ",dactivationType=", ActTypeToString[dActType]);
            if (!rtc_manager.is_compiled(kernel_label)) {
              std::string code = string_code_transpose_rtc_cast_transpose_fusion_cu;
@@ -693,22 +669,18 @@ void cast_transpose_fused(const Tensor &input,
              code = regex_replace(code, "__DACTIVATION_TYPE__", dActType);
              rtc_manager.compile(
                  kernel_label, "cast_transpose_fusion_kernel_optimized", code,
                  "transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu");
            }
            rtc_manager.set_cache_config(kernel_label, CU_FUNC_CACHE_PREFER_SHARED);
            rtc_manager.launch(kernel_label, num_blocks, cast_transpose_num_threads,
                               shared_size_transpose, stream, param, row_length, num_rows,
                               num_tiles);
          } else {  // Statically-compiled general kernel
            constexpr size_t load_size = IS_DACT ? desired_load_size_dact : desired_load_size;
            constexpr size_t store_size = IS_DACT ? desired_store_size_dact : desired_store_size;
            constexpr size_t nvec_in = load_size / itype_size;
            constexpr size_t nvec_out = store_size / otype_size;
@@ -716,39 +688,30 @@ void cast_transpose_fused(const Tensor &input,
            NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
            cudaFuncSetAttribute(
                cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, ComputeType, Param,
                                                       nvec_in, nvec_out, Empty, OP>,
                cudaFuncAttributePreferredSharedMemoryCarveout, 100);
            cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, ComputeType, Param, nvec_in,
                                                   nvec_out, Empty, OP>
                <<<num_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
                    param, row_length, num_rows, num_tiles);
          }
          if constexpr (IS_DBIAS) {
            reduce_dbias<InputType>(*workspace, dbias, row_length, num_rows, nvec_out, stream);
          });  // NOLINT(*)
  );  // NOLINT(*)
}
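
// Added commentary (not part of the original source): for a hypothetical bf16 input, fp8 output,
// dGeLU + dbias fusion with the default 8-byte loads and stores, the NVRTC label assembled above
// would look roughly like
//
//   cast_transpose_fusion,itype=bf16,itype2=bf16,otype=fp8e4m3,load_size=8,store_size=8,
//       IS_DBIAS=1,IS_DACT=1,dactivationType=GeLU
//
// (the exact type-name strings come from the TYPE_SWITCH macros and may differ). The label keys
// rtc::KernelManager's cache, so each distinct configuration is compiled at most once per process
// and reused by later calls that map to the same label.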
template <int nvec_in, int nvec_out, typename CType, typename IType, typename OType,
          typename ParamOP, CType (*OP1)(CType, const ParamOP &),
          CType (*OP2)(CType, const ParamOP &)>
__global__ void __launch_bounds__(cast_transpose_num_threads)
    dgated_act_cast_transpose_kernel(const IType *const input, const IType *const act_input,
                                     OType *const output_c, OType *const output_t,
                                     const CType *const scale_ptr, CType *const amax,
                                     CType *const scale_inv, const size_t row_length,
                                     const size_t num_rows, const size_t num_tiles) {
  using IVec = Vec<IType, nvec_in>;
  using OVec = Vec<OType, nvec_out>;
  using CVec = Vec<CType, nvec_in>;
@@ -758,8 +721,8 @@ dgated_act_cast_transpose_kernel(const IType * const input,
  const int warp_id = threadIdx.x / THREADS_PER_WARP;
  const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
  const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
  const size_t tile_id =
      blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile;
  if (tile_id >= num_tiles) {
    return;
  }
@@ -767,30 +730,26 @@ dgated_act_cast_transpose_kernel(const IType * const input,
  const size_t tile_id_x = tile_id % num_tiles_x;
  const size_t tile_id_y = tile_id / num_tiles_x;
  const IType *const my_input_tile =
      input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP;
  const IType *const my_act_input_tile =
      act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP;
  const IType *const my_gate_input_tile =
      act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP +
      row_length;
  OType *const my_output_c_tile_0 =
      output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP;
  OType *const my_output_c_tile_1 =
      output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP +
      row_length;
  OType *const my_output_t_tile_0 =
      output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP;
  OType *const my_output_t_tile_1 =
      output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP +
      row_length * num_rows;
  OVec *const my_scratch =
      reinterpret_cast<OVec *>(scratch) +
      (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);
  IVec in[2][nvec_out];
  IVec act_in[2][nvec_out];
@@ -803,9 +762,8 @@ dgated_act_cast_transpose_kernel(const IType * const input,
  const size_t stride = row_length / nvec_in;
  const size_t output_stride = num_rows / nvec_out;
  size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
  unsigned int my_place =
      (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
  const size_t stride2 = 2 * row_length / nvec_in;
  size_t current_stride2 = warp_id_in_tile * n_iterations * nvec_out * stride2;
  CType max = 0;
@@ -813,19 +771,19 @@ dgated_act_cast_transpose_kernel(const IType * const input,
  CVec partial_dbias;
#pragma unroll
  for (unsigned int i = 0; i < nvec_out; ++i) {
    in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
    act_in[0][i].load_from(my_act_input_tile, current_stride2 + my_place + stride2 * i);
    gate_in[0][i].load_from(my_gate_input_tile, current_stride2 + my_place + stride2 * i);
  }
#pragma unroll
  for (unsigned int i = 0; i < n_iterations; ++i) {
    const size_t current_place = current_stride2 + my_place;
    const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
    const unsigned int current_in = (i + 1) % 2;
    if (i < n_iterations - 1) {
#pragma unroll
      for (unsigned int j = 0; j < nvec_out; ++j) {
        in[current_in][j].load_from(my_input_tile,
                                    current_stride + my_place_in + stride * (nvec_out + j));
@@ -837,9 +795,9 @@ dgated_act_cast_transpose_kernel(const IType * const input,
    }
    CVec after_dact[nvec_out];   // NOLINT(*)
    CVec after_dgate[nvec_out];  // NOLINT(*)
#pragma unroll
    for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
      for (unsigned int k = 0; k < nvec_in; ++k) {
        after_dact[j].data.elt[k] = OP1(act_in[current_in ^ 1][j].data.elt[k], {}) *
                                    CType(in[current_in ^ 1][j].data.elt[k]) *
@@ -856,15 +814,15 @@ dgated_act_cast_transpose_kernel(const IType * const input,
    constexpr bool valid_store = true;
    constexpr int dbias_shfl_src_lane = 0;
    cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>(after_dact, out_trans_0, partial_dbias,
                                                    my_output_c_tile_0, current_place, stride2,
                                                    scale, max, dbias_shfl_src_lane, valid_store);
    cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>(after_dgate, out_trans_1, partial_dbias,
                                                    my_output_c_tile_1, current_place, stride2,
                                                    scale, max, dbias_shfl_src_lane, valid_store);
#pragma unroll
    for (unsigned int j = 0; j < nvec_in; ++j) {
      out_space_0[i][j].data.vec = out_trans_0[j].data.vec;
      out_space_1[i][j].data.vec = out_trans_1[j].data.vec;
@@ -875,16 +833,15 @@ dgated_act_cast_transpose_kernel(const IType * const input,
  }
  for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
    for (unsigned int j = 0; j < n_iterations; ++j) {
      my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
                 THREADS_PER_WARP] = out_space_0[j][i];
    }
    __syncthreads();
    my_place =
        (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
    current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
    for (unsigned int j = 0; j < n_iterations; ++j) {
      my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_0,
                                                              current_stride + my_place);
@@ -892,16 +849,15 @@ dgated_act_cast_transpose_kernel(const IType * const input,
      current_stride += output_stride * nvec_in;
    }
    __syncthreads();
#pragma unroll
    for (unsigned int j = 0; j < n_iterations; ++j) {
      my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
                 THREADS_PER_WARP] = out_space_1[j][i];
    }
    __syncthreads();
    my_place =
        (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
    current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
    for (unsigned int j = 0; j < n_iterations; ++j) {
      my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_1,
                                                              current_stride + my_place);
@@ -925,22 +881,15 @@ dgated_act_cast_transpose_kernel(const IType * const input,
  }
}
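
// Note (added commentary, not part of the original source): the gated-activation input is stored
// as [act | gate] halves along each row, which is why the kernel above indexes act_input with
// stride2 = 2 * row_length / nvec_in and offsets the gate pointer by row_length. Schematically,
// for one row:
//
//   act_input : a_0 ... a_{row_length-1} | g_0 ... g_{row_length-1}
//   output_c  : dact_0 ... dact_{row_length-1} | dgate_0 ... dgate_{row_length-1}
//
// while the two transposed halves are written to output_t at offsets 0 and row_length * num_rows.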
template <int nvec_in, int nvec_out, typename CType, typename IType, typename OType,
          typename ParamOP, CType (*OP1)(CType, const ParamOP &),
          CType (*OP2)(CType, const ParamOP &)>
__global__ void __launch_bounds__(cast_transpose_num_threads)
    dgated_act_cast_transpose_kernel_notaligned(const IType *const input,
                                                const IType *const act_input, OType *const output_c,
                                                OType *const output_t, const CType *const scale_ptr,
                                                CType *const amax, CType *const scale_inv,
                                                const size_t row_length, const size_t num_rows,
                                                const size_t num_tiles) {
  using IVec = Vec<IType, nvec_in>;
  using OVec = Vec<OType, nvec_out>;
@@ -950,48 +899,44 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
  const int warp_id = threadIdx.x / THREADS_PER_WARP;
  const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
  const size_t num_tiles_x =
      (row_length + nvec_in * THREADS_PER_WARP - 1) / (nvec_in * THREADS_PER_WARP);
  const size_t tile_id =
      blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile;
  if (tile_id >= num_tiles) return;
  const size_t tile_id_x = tile_id % num_tiles_x;
  const size_t tile_id_y = tile_id / num_tiles_x;
  const IType *const my_input_tile =
      input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP;
  const IType *const my_act_input_tile =
      act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP;
  const IType *const my_gate_input_tile =
      act_input + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP +
      row_length;
  OType *const my_output_c_tile_0 =
      output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP;
  OType *const my_output_c_tile_1 =
      output_c + (tile_id_x * nvec_in + tile_id_y * row_length * 2 * nvec_out) * THREADS_PER_WARP +
      row_length;
  OType *const my_output_t_tile_0 =
      output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP;
  OType *const my_output_t_tile_1 =
      output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP +
      row_length * num_rows;
  const size_t stride = row_length / nvec_in;
  const size_t stride2 = 2 * row_length / nvec_in;
  const size_t output_stride = num_rows / nvec_out;
  const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP;
  const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP;
  const unsigned int tile_length =
      row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_length_rest;
  const unsigned int tile_height =
      row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_height_rest;
  OVec *const my_scratch =
      reinterpret_cast<OVec *>(scratch) +
      (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);
  IVec in[2][nvec_out];
  IVec act_in[2][nvec_out];
@@ -1003,26 +948,21 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
  size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
  size_t current_stride2 = warp_id_in_tile * n_iterations * nvec_out * stride2;
  unsigned int my_place =
      (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
  CType max = 0;
  const CType scale = scale_ptr != nullptr ? *scale_ptr : 1;
  CVec partial_dbias;
  {
    const bool valid_load = my_place < tile_length && warp_id_in_tile * n_iterations < tile_height;
#pragma unroll
    for (unsigned int i = 0; i < nvec_out; ++i) {
      if (valid_load) {
        in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
        act_in[0][i].load_from(my_act_input_tile, current_stride2 + my_place + stride2 * i);
        gate_in[0][i].load_from(my_gate_input_tile, current_stride2 + my_place + stride2 * i);
      } else {
        in[0][i].clear();
        act_in[0][i].clear();
@@ -1030,24 +970,24 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
      }
    }
  }
#pragma unroll
  for (unsigned int i = 0; i < n_iterations; ++i) {
    const size_t current_place = current_stride2 + my_place;
    const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
    const unsigned int current_in = (i + 1) % 2;
    if (i < n_iterations - 1) {
      {
        const bool valid_load =
            my_place_in < tile_length && warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll
        for (unsigned int j = 0; j < nvec_out; ++j) {
          if (valid_load) {
            in[current_in][j].load_from(my_input_tile,
                                        current_stride + my_place_in + stride * (nvec_out + j));
            act_in[current_in][j].load_from(
                my_act_input_tile, current_stride2 + my_place_in + stride2 * (nvec_out + j));
            gate_in[current_in][j].load_from(
                my_gate_input_tile, current_stride2 + my_place_in + stride2 * (nvec_out + j));
          } else {
            in[current_in][j].clear();
            act_in[current_in][j].clear();
@@ -1058,9 +998,9 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
    }
    CVec after_dact[nvec_out];   // NOLINT(*)
    CVec after_dgate[nvec_out];  // NOLINT(*)
#pragma unroll
    for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
      for (unsigned int k = 0; k < nvec_in; ++k) {
        after_dact[j].data.elt[k] = OP1(act_in[current_in ^ 1][j].data.elt[k], {}) *
                                    CType(in[current_in ^ 1][j].data.elt[k]) *
@@ -1075,17 +1015,17 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
    constexpr bool IS_DBIAS = false;
    constexpr bool IS_FULL_TILE = false;
    constexpr int dbias_shfl_src_lane = 0;
    const bool valid_store =
        (my_place < tile_length) && (warp_id_in_tile * n_iterations + i < tile_height);
    cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>(after_dact, out_trans_0, partial_dbias,
                                                    my_output_c_tile_0, current_place, stride2,
                                                    scale, max, dbias_shfl_src_lane, valid_store);
    cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>(after_dgate, out_trans_1, partial_dbias,
                                                    my_output_c_tile_1, current_place, stride2,
                                                    scale, max, dbias_shfl_src_lane, valid_store);
#pragma unroll
    for (unsigned int j = 0; j < nvec_in; ++j) {
      out_space_0[i][j].data.vec = out_trans_0[j].data.vec;
      out_space_1[i][j].data.vec = out_trans_1[j].data.vec;
@@ -1096,16 +1036,15 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
  }
  for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
    for (unsigned int j = 0; j < n_iterations; ++j) {
      my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
                 THREADS_PER_WARP] = out_space_0[j][i];
    }
    __syncthreads();
    my_place =
        (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
    current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
    for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
      const bool valid_store = my_place < tile_height;
      if (valid_store) {
@@ -1116,16 +1055,15 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
      current_stride += output_stride * nvec_in;
    }
    __syncthreads();
#pragma unroll
    for (unsigned int j = 0; j < n_iterations; ++j) {
      my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
                 THREADS_PER_WARP] = out_space_1[j][i];
    }
    __syncthreads();
    my_place =
        (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
    current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
    for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
      const bool valid_store = my_place < tile_height;
      if (valid_store) {
@@ -1152,13 +1090,10 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
  }
}
template <typename ComputeType, typename ParamOP, ComputeType (*OP1)(ComputeType, const ParamOP &),
          ComputeType (*OP2)(ComputeType, const ParamOP &)>
void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input,
                               Tensor *cast_output, Tensor *transposed_output,
                               cudaStream_t stream) {
  CheckInputTensor(input, "dgated_act_cast_transpose_input");
  CheckInputTensor(gated_act_input, "dgated_act_cast_transpose_gated_act_input");
@@ -1190,13 +1125,13 @@ void dgated_act_cast_transpose(const Tensor &input,
  NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr,
             "C and T outputs need to share scale inverse tensor.");
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          cast_output->data.dtype, OutputType, using InputType2 = InputType;
          /* dact fusion kernel uses more registers */
          constexpr int desired_load_size_dact = 4;
          constexpr int desired_store_size_dact = 4; constexpr int itype_size = sizeof(InputType);
          constexpr int otype_size = sizeof(OutputType);
          constexpr int nvec_in = desired_load_size_dact / itype_size;
          constexpr int nvec_out = desired_store_size_dact / otype_size;
@@ -1215,13 +1150,12 @@ void dgated_act_cast_transpose(const Tensor &input,
              (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>);
          if (full_tile) {
            cudaFuncSetAttribute(
                dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType,
                                                 OutputType, Empty, OP1, OP2>,
                cudaFuncAttributePreferredSharedMemoryCarveout, 100);
            dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType, OutputType,
                                             Empty, OP1, OP2>
                <<<n_blocks, cast_transpose_num_threads, shmem_size, stream>>>(
                    reinterpret_cast<const InputType *>(input.data.dptr),
                    reinterpret_cast<const InputType *>(gated_act_input.data.dptr),
@@ -1229,16 +1163,15 @@ void dgated_act_cast_transpose(const Tensor &input,
                    reinterpret_cast<OutputType *>(transposed_output->data.dptr),
                    reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
                    reinterpret_cast<fp32 *>(cast_output->amax.dptr),
                    reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr), row_length, num_rows,
                    n_tiles);
          } else {
            cudaFuncSetAttribute(
                dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType,
                                                            InputType, OutputType, Empty, OP1, OP2>,
                cudaFuncAttributePreferredSharedMemoryCarveout, 100);
            dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType, InputType,
                                                        OutputType, Empty, OP1, OP2>
                <<<n_blocks, cast_transpose_num_threads, shmem_size, stream>>>(
                    reinterpret_cast<const InputType *>(input.data.dptr),
                    reinterpret_cast<const InputType *>(gated_act_input.data.dptr),
@@ -1246,24 +1179,19 @@ void dgated_act_cast_transpose(const Tensor &input,
                    reinterpret_cast<OutputType *>(transposed_output->data.dptr),
                    reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
                    reinterpret_cast<fp32 *>(cast_output->amax.dptr),
                    reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr), row_length, num_rows,
                    n_tiles);
          });  // NOLINT(*)
  );  // NOLINT(*)
}
} // namespace } // namespace
} // namespace transformer_engine } // namespace transformer_engine
using ComputeType = typename transformer_engine::fp32;

void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor cast_output,
                               NVTETensor transposed_output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_cast_transpose_dbias);
  using namespace transformer_engine;
@@ -1274,22 +1202,14 @@ void nvte_cast_transpose_dbias(const NVTETensor input,
  constexpr const NVTETensor activation_input = nullptr;
  cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, nullptr>(
      *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(activation_input),
      reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
      reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
}
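The wrappers below all follow the same dispatch pattern: each C entry point fixes the compile-time flags and binds the activation derivative as a non-type function-pointer template argument, so the fused kernel can inline the derivative instead of calling through a runtime pointer. A minimal standalone sketch of that idiom, with illustrative names only (apply_fused and dgelu_ref are not Transformer Engine APIs):

// Illustration of passing a derivative as a compile-time function pointer.
#include <cmath>
#include <cstdio>

struct Empty {};  // mirrors the empty parameter struct used by these kernels

// Reference GELU derivative: Phi(x) + x * phi(x) (exact erf form).
float dgelu_ref(float x, const Empty &) {
  const float cdf = 0.5f * (1.0f + std::erf(x * 0.70710678f));
  const float pdf = std::exp(-0.5f * x * x) * 0.39894228f;
  return cdf + x * pdf;
}

// The derivative is a non-type template parameter, so the compiler can inline it
// at every call site, the same trick the fused cast/transpose kernels rely on.
template <float (*DAct)(float, const Empty &)>
void apply_fused(const float *grad, const float *act_in, float *out, int n) {
  for (int i = 0; i < n; ++i) out[i] = grad[i] * DAct(act_in[i], Empty{});
}

int main() {
  const float grad[3] = {1.0f, 1.0f, 1.0f};
  const float act_in[3] = {-1.0f, 0.0f, 1.0f};
  float out[3];
  apply_fused<&dgelu_ref>(grad, act_in, out, 3);
  for (float v : out) std::printf("%f\n", v);
  return 0;
}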
void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input,
                                     NVTETensor cast_output, NVTETensor transposed_output,
                                     NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu);
  using namespace transformer_engine;
@@ -1299,22 +1219,14 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
  constexpr auto dActivation = &dgelu<fp32, fp32>;
  cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
      *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(act_input),
      reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
      reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
}
void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor silu_input,
                                     NVTETensor cast_output, NVTETensor transposed_output,
                                     NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cast_transpose_dbias_dsilu);
  using namespace transformer_engine;
@@ -1324,22 +1236,14 @@ void nvte_cast_transpose_dbias_dsilu(const NVTETensor input,
  constexpr auto dActivation = &dsilu<fp32, fp32>;
  cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
      *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(silu_input),
      reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
      reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
}
void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor relu_input,
                                     NVTETensor cast_output, NVTETensor transposed_output,
                                     NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cast_transpose_dbias_drelu);
  using namespace transformer_engine;
@@ -1349,22 +1253,14 @@ void nvte_cast_transpose_dbias_drelu(const NVTETensor input,
  constexpr auto dActivation = &drelu<fp32, fp32>;
  cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
      *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(relu_input),
      reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
      reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
}
void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor srelu_input,
                                      NVTETensor cast_output, NVTETensor transposed_output,
                                      NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cast_transpose_dbias_dsrelu);
  using namespace transformer_engine;
@@ -1374,22 +1270,14 @@ void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input,
  constexpr auto dActivation = &dsrelu<fp32, fp32>;
  cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
      *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(srelu_input),
      reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
      reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
}
void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor qgelu_input,
                                      NVTETensor cast_output, NVTETensor transposed_output,
                                      NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_cast_transpose_dbias_dqgelu);
  using namespace transformer_engine;
@@ -1399,19 +1287,13 @@ void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input,
  constexpr auto dActivation = &dqgelu<fp32, fp32>;
  cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
      *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(qgelu_input),
      reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
      reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
}
void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
                                NVTETensor cast_output, NVTETensor transposed_output,
                                cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgeglu_cast_transpose);
  using namespace transformer_engine;
@@ -1420,17 +1302,13 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input,
  constexpr auto Activation = &gelu<fp32, fp32>;
  dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
      *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
      reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
      stream);
}
void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu_input,
                                 NVTETensor cast_output, NVTETensor transposed_output,
                                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dswiglu_cast_transpose);
  using namespace transformer_engine;
@@ -1439,17 +1317,13 @@ void nvte_dswiglu_cast_transpose(const NVTETensor input,
  constexpr auto Activation = &silu<fp32, fp32>;
  dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
      *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(swiglu_input),
      reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
      stream);
}
void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
                                NVTETensor cast_output, NVTETensor transposed_output,
                                cudaStream_t stream) {
  NVTE_API_CALL(nvte_dreglu_cast_transpose);
  using namespace transformer_engine;
@@ -1458,17 +1332,13 @@ void nvte_dreglu_cast_transpose(const NVTETensor input,
  constexpr auto Activation = &relu<fp32, fp32>;
  dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
      *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
      reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
      stream);
}
void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
                                 NVTETensor cast_output, NVTETensor transposed_output,
                                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dsreglu_cast_transpose);
  using namespace transformer_engine;
@@ -1477,17 +1347,13 @@ void nvte_dsreglu_cast_transpose(const NVTETensor input,
  constexpr auto Activation = &srelu<fp32, fp32>;
  dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
      *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
      reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
      stream);
}
void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input,
                                 NVTETensor cast_output, NVTETensor transposed_output,
                                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dqgeglu_cast_transpose);
  using namespace transformer_engine;
@@ -1496,9 +1362,7 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input,
  constexpr auto Activation = &qgelu<fp32, fp32>;
  dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
      *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(gated_act_input),
      reinterpret_cast<Tensor *>(cast_output), reinterpret_cast<Tensor *>(transposed_output),
      stream);
}
@@ -4,13 +4,15 @@
 * See LICENSE for license information.
 ************************************************************************/

#include <cuda_runtime.h>
#include <transformer_engine/transpose.h>

#include <cfloat>
#include <iostream>
#include <vector>

#include "../common.h"
#include "../utils.cuh"

namespace transformer_engine {
@@ -40,21 +42,14 @@ struct MultiCastTransposeArgs {
  int row_length_list[kMaxTensorsPerKernel];
  // Prefix sum (with leading zero) of CUDA blocks needed for each
  // tensor
  int block_range[kMaxTensorsPerKernel + 1];
  // Number of tensors being processed by kernel
  int num_tensors;
};
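block_range is a prefix sum over the per-tensor block counts, so a CUDA block can recover which tensor it belongs to with a short linear scan, exactly as the kernel below does. A small host-side sketch of the lookup, with made-up tile counts:

// Standalone illustration of the block_range lookup; the numbers are invented.
#include <cstdio>

int main() {
  const int num_tiles[3] = {6, 4, 10};  // blocks needed by each tensor
  int block_range[4] = {0, 0, 0, 0};    // prefix sum with a leading zero
  for (int i = 0; i < 3; ++i) block_range[i + 1] = block_range[i] + num_tiles[i];

  const int bid = 7;  // a hypothetical blockIdx.x
  int tensor_id = 0;
  while (block_range[tensor_id + 1] <= bid) ++tensor_id;
  std::printf("block %d -> tensor %d (local block %d)\n", bid, tensor_id,
              bid - block_range[tensor_id]);  // prints: block 7 -> tensor 1 (local block 1)
  return 0;
}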
template <int nvec_in, int nvec_out, bool aligned, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(threads_per_block)
    multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
  using IVec = Vec<IType, nvec_in>;
  using OVecC = Vec<OType, nvec_in>;
  using OVecT = Vec<OType, nvec_out>;
@@ -79,7 +74,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
  // Find tensor corresponding to block
  int tensor_id = 0;
  while (args.block_range[tensor_id + 1] <= bid) {
    ++tensor_id;
  }
  const IType* input = reinterpret_cast<const IType*>(args.input_list[tensor_id]);
@@ -104,11 +99,11 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
  // type, and transposes in registers.
  OVecT local_output_t[nvec_in][n_iterations];
  CType local_amax = 0;
#pragma unroll
  for (int iter = 0; iter < n_iterations; ++iter) {
    const int i1 = tidy + iter * bdimy;
    const int j1 = tidx;
#pragma unroll
    for (int i2 = 0; i2 < nvec_out; ++i2) {
      const int row = tile_row + i1 * nvec_out + i2;
      const int col = tile_col + j1 * nvec_in;
@@ -119,7 +114,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
      } else {
        local_input.clear();
        if (row < num_rows) {
#pragma unroll
          for (int j2 = 0; j2 < nvec_in; ++j2) {
            if (col + j2 < row_length) {
              local_input.data.elt[j2] = input[row * row_length + col + j2];
@@ -127,7 +122,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
          }
        }
      }
#pragma unroll
      for (int j2 = 0; j2 < nvec_in; ++j2) {
        const CType x = CType(local_input.data.elt[j2]);
        const OType y = OType(scale * x);
@@ -140,7 +135,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
        local_output_c.store_to(&output_c[row * row_length + col]);
      } else {
        if (row < num_rows) {
#pragma unroll
          for (int j2 = 0; j2 < nvec_in; ++j2) {
            if (col + j2 < row_length) {
              output_c[row * row_length + col + j2] = local_output_c.data.elt[j2];
@@ -152,17 +147,17 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
  }

  // Copy transposed output from registers to global memory
  __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#pragma unroll
  for (int j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll
    for (int iter = 0; iter < n_iterations; ++iter) {
      const int i1 = tidy + iter * bdimy;
      const int j1 = tidx;
      shared_output_t[j1][i1] = local_output_t[j2][iter];
    }
    __syncthreads();
#pragma unroll
    for (int iter = 0; iter < n_iterations; ++iter) {
      const int i1 = tidx;
      const int j1 = tidy + iter * bdimy;
@@ -172,7 +167,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
      shared_output_t[j1][i1].store_to(&output_t[col * num_rows + row]);
    } else {
      if (col < row_length) {
#pragma unroll
        for (int i2 = 0; i2 < nvec_out; ++i2) {
          if (row + i2 < num_rows) {
            output_t[col * num_rows + row + i2] = shared_output_t[j1][i1].data.elt[i2];
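The `[THREADS_PER_WARP][THREADS_PER_WARP + 1]` shared-memory tile used above (and in the other transpose kernels later in this diff) is the usual padding trick: the extra column shifts each row by one bank so that reading the tile along the transposed direction does not serialize on shared-memory bank conflicts. A minimal CUDA sketch of the idea, independent of this file's types (the 32x32 tile and the kernel name are local choices):

// Illustration of a padded shared-memory transpose tile.
#include <cuda_runtime.h>
#include <cstdio>

constexpr int TILE = 32;

__global__ void transpose_tile(const float *in, float *out, int n) {
  // The "+ 1" column keeps warp lanes on distinct banks for the column reads below.
  __shared__ float tile[TILE][TILE + 1];
  const int x = blockIdx.x * TILE + threadIdx.x;
  const int y = blockIdx.y * TILE + threadIdx.y;
  if (x < n && y < n) tile[threadIdx.y][threadIdx.x] = in[y * n + x];
  __syncthreads();
  const int tx = blockIdx.y * TILE + threadIdx.x;
  const int ty = blockIdx.x * TILE + threadIdx.y;
  if (tx < n && ty < n) out[ty * n + tx] = tile[threadIdx.x][threadIdx.y];
}

int main() {
  const int n = 64;
  float *in, *out;
  cudaMallocManaged(&in, n * n * sizeof(float));
  cudaMallocManaged(&out, n * n * sizeof(float));
  for (int i = 0; i < n * n; ++i) in[i] = static_cast<float>(i);
  const dim3 block(TILE, TILE), grid(n / TILE, n / TILE);
  transpose_tile<<<grid, block>>>(in, out, n);
  cudaDeviceSynchronize();
  std::printf("out[1] = %f (expected %f)\n", out[1], in[n]);  // element (0,1) of the transpose
  cudaFree(in);
  cudaFree(out);
  return 0;
}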
@@ -196,8 +191,7 @@ multi_cast_transpose_kernel(MultiCastTransposeArgs args) {
void multi_cast_transpose(const std::vector<Tensor*> input_list,
                          std::vector<Tensor*> cast_output_list,
                          std::vector<Tensor*> transposed_output_list, cudaStream_t stream) {
  // Check that number of tensors is valid
  NVTE_CHECK(cast_output_list.size() == input_list.size(),
             "Number of input and C output tensors must match");
@@ -218,15 +212,11 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
    CheckInputTensor(cast_output, "multi_cast_output_" + std::to_string(tensor_id));
    CheckInputTensor(transposed_output, "multi_transpose_output_" + std::to_string(tensor_id));
    NVTE_CHECK(input.data.dtype == itype, "Input tensor types do not match.");
    NVTE_CHECK(cast_output.data.dtype == otype, "C output tensor types do not match.");
    NVTE_CHECK(transposed_output.data.dtype == otype, "T output tensor types do not match.");

    NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions.");
    NVTE_CHECK(cast_output.data.shape == input.data.shape,
               "C output tensor shape does not match input tensor.");
    NVTE_CHECK(transposed_output.data.shape.size() == 2,
@@ -251,26 +241,27 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    // Launch kernel if argument struct is full
    if (kernel_args_aligned.num_tensors == kMaxTensorsPerKernel) {
      TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
          itype, InputType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType);
              constexpr int nvec_out = desired_store_size / sizeof(OutputType);
              const int n_blocks = kernel_args_aligned.block_range[kernel_args_aligned.num_tensors];
              multi_cast_transpose_kernel<nvec_in, nvec_out, true, fp32, InputType, OutputType>
                  <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_aligned););  // NOLINT(*)
      );  // NOLINT(*)
      kernel_args_aligned.num_tensors = 0;
    }
    if (kernel_args_unaligned.num_tensors == kMaxTensorsPerKernel) {
      TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
          itype, InputType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType);
              constexpr int nvec_out = desired_store_size / sizeof(OutputType);
              const int n_blocks =
                  kernel_args_unaligned.block_range[kernel_args_unaligned.num_tensors];
              multi_cast_transpose_kernel<nvec_in, nvec_out, false, fp32, InputType, OutputType>
                  <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_unaligned););  // NOLINT(*)
      );  // NOLINT(*)
      kernel_args_unaligned.num_tensors = 0;
    }
@@ -283,8 +274,8 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
    const int num_tiles = num_tiles_m * num_tiles_n;

    // Figure out whether to use aligned or unaligned kernel
    const bool aligned =
        ((num_tiles_m * tile_dim_m == num_rows) && (num_tiles_n * tile_dim_n == row_length));
    auto& kernel_args = aligned ? kernel_args_aligned : kernel_args_unaligned;

    // Add tensor to kernel argument struct
@@ -296,53 +287,48 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list,
    kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr;
    kernel_args.num_rows_list[pos] = num_rows;
    kernel_args.row_length_list[pos] = row_length;
    kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles;
    kernel_args.num_tensors++;
  }
  // Launch kernel
  if (kernel_args_aligned.num_tensors > 0) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
        itype, InputType,
        TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
            otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType);
            constexpr int nvec_out = desired_store_size / sizeof(OutputType);
            const int n_blocks = kernel_args_aligned.block_range[kernel_args_aligned.num_tensors];
            multi_cast_transpose_kernel<nvec_in, nvec_out, true, fp32, InputType, OutputType>
                <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_aligned););  // NOLINT(*)
    );  // NOLINT(*)
  }
  if (kernel_args_unaligned.num_tensors > 0) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
        itype, InputType,
        TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
            otype, OutputType, constexpr int nvec_in = desired_load_size / sizeof(InputType);
            constexpr int nvec_out = desired_store_size / sizeof(OutputType);
            const int n_blocks =
                kernel_args_unaligned.block_range[kernel_args_unaligned.num_tensors];
            multi_cast_transpose_kernel<nvec_in, nvec_out, false, fp32, InputType, OutputType>
                <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_unaligned););  // NOLINT(*)
    );  // NOLINT(*)
  }
}
}  // namespace transformer_engine

void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list,
                               NVTETensor* cast_output_list, NVTETensor* transposed_output_list,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_cast_transpose);
  using namespace transformer_engine;
  std::vector<Tensor*> input_list_, cast_output_list_, transposed_output_list_;
  for (size_t i = 0; i < num_tensors; ++i) {
    input_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(input_list[i])));
    cast_output_list_.push_back(reinterpret_cast<Tensor*>(cast_output_list[i]));
    transposed_output_list_.push_back(reinterpret_cast<Tensor*>(transposed_output_list[i]));
  }
  multi_cast_transpose(input_list_, cast_output_list_, transposed_output_list_, stream);
}
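nvte_multi_cast_transpose ultimately drives a pack-and-flush loop: tensors are appended to a fixed-capacity argument struct, a kernel is launched whenever the struct fills up, and the remainder is flushed at the end. A schematic of that control flow with illustrative names only (BatchArgs, process_batch and kMaxPerLaunch are stand-ins, not Transformer Engine APIs):

// Generic "pack until full, then flush" pattern.
#include <cstdio>
#include <vector>

constexpr int kMaxPerLaunch = 4;

struct BatchArgs {
  int items[kMaxPerLaunch];
  int num_items = 0;
};

void process_batch(const BatchArgs &args) {  // stands in for one kernel launch
  std::printf("launch with %d item(s)\n", args.num_items);
}

int main() {
  const std::vector<int> work = {1, 2, 3, 4, 5, 6, 7, 8, 9};
  BatchArgs args;
  for (int w : work) {
    if (args.num_items == kMaxPerLaunch) {  // argument struct full: flush first
      process_batch(args);
      args.num_items = 0;
    }
    args.items[args.num_items++] = w;
  }
  if (args.num_items > 0) process_batch(args);  // flush the remainder
  return 0;
}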
@@ -21,16 +21,11 @@ constexpr size_t block_size = __BLOCK_SIZE__;
}  // namespace

__global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel(
    const IType* __restrict__ const input, const CType* __restrict__ const noop,
    OType* __restrict__ const output_c, OType* __restrict__ const output_t,
    const CType* __restrict__ const scale_ptr, CType* __restrict__ const amax_ptr,
    const size_t row_length, const size_t num_rows) {
  if (noop != nullptr && noop[0] == 1.0f) return;
  // Vectorized load/store sizes
@@ -73,18 +68,18 @@ cast_transpose_optimized_kernel(const IType * __restrict__ const input,
  // Note: Each thread loads num_iterations subtiles, computes amax,
  // casts type, and transposes in registers.
  OVecT local_output_t[nvec_in][num_iterations];
#pragma unroll
  for (size_t iter = 0; iter < num_iterations; ++iter) {
    const size_t i1 = tidy + iter * bdimy;
    const size_t j1 = tidx;
#pragma unroll
    for (size_t i2 = 0; i2 < nvec_out; ++i2) {
      const size_t row = tile_row + i1 * nvec_out + i2;
      const size_t col = tile_col + j1 * nvec_in;
      IVec local_input;
      OVecC local_output_c;
      local_input.load_from(&input[row * row_length + col]);
#pragma unroll
      for (size_t j2 = 0; j2 < nvec_in; ++j2) {
        const CType in = static_cast<CType>(local_input.data.elt[j2]);
        const OType out = OType(in * scale);
@@ -98,17 +93,17 @@ cast_transpose_optimized_kernel(const IType * __restrict__ const input,
  }

  // Copy from registers to shared memory to global memory
  __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#pragma unroll
  for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll
    for (size_t iter = 0; iter < num_iterations; ++iter) {
      const size_t i1 = tidy + iter * bdimy;
      const size_t j1 = tidx;
      shared_output_t[j1][i1] = local_output_t[j2][iter];
    }
    __syncthreads();
#pragma unroll
    for (size_t iter = 0; iter < num_iterations; ++iter) {
      const size_t i1 = tidx;
      const size_t j1 = tidy + iter * bdimy;
@@ -4,8 +4,8 @@
 * See LICENSE for license information.
 ************************************************************************/

#include "util/math.h"
#include "utils.cuh"

using namespace transformer_engine;
@@ -32,7 +32,7 @@ using IVec2 = Vec<IType2, NVEC_IN>;
using OVec = Vec<OType, NVEC_OUT>;
using Param = CTDBiasDActParam<IType, IType2, OType, CType>;
using OP = CType (*)(const CType, const Empty &);
constexpr OP Activation[] = {
    nullptr,                  // 0
    &dsigmoid<CType, CType>,  // 1
@@ -45,14 +45,12 @@ constexpr OP Activation[] = {
}  // namespace
inline __device__ void cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_OUT],
                                                         OVec (&out_trans)[NVEC_IN],
                                                         CVec &out_dbias,  // NOLINT(*)
                                                         typename OVec::type *output_cast_tile,
                                                         const size_t current_place,
                                                         const size_t stride, const CType scale,
                                                         CType &amax,  // NOLINT(*)
                                                         const int dbias_shfl_src_lane) {
  using OVecC = Vec<OType, NVEC_IN>;
@@ -62,10 +60,10 @@ cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_OUT],
    step_dbias.clear();
  }
#pragma unroll
  for (unsigned int i = 0; i < NVEC_OUT; ++i) {
    OVecC out_cast;
#pragma unroll
    for (unsigned int j = 0; j < NVEC_IN; ++j) {
      const CType tmp = in[i].data.elt[j];
      if constexpr (IS_DBIAS) {
@@ -81,7 +79,7 @@ cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_OUT],
  }

  if constexpr (IS_DBIAS) {
#pragma unroll
    for (unsigned int j = 0; j < NVEC_IN; ++j) {
      CType elt = step_dbias.data.elt[j];
      elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane);  // shuffle data in a warp
@@ -90,19 +88,16 @@ cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_OUT],
    }
  }
__global__ void __launch_bounds__(BLOCK_SIZE)
    cast_transpose_fusion_kernel_optimized(const Param param, const size_t row_length,
                                           const size_t num_rows, const size_t num_tiles) {
  extern __shared__ char scratch[];

  const int warp_id = threadIdx.x / THREADS_PER_WARP;
  const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
  const size_t num_tiles_x = row_length / (NVEC_IN * THREADS_PER_WARP);
  const size_t tile_id =
      blockIdx.x * blockDim.x / (THREADS_PER_WARP * WARPS_PER_TILE) + warp_id / WARPS_PER_TILE;
  if (tile_id >= num_tiles) {
    return;
  }
@@ -110,24 +105,23 @@ cast_transpose_fusion_kernel_optimized(const Param param,
  const size_t tile_id_x = tile_id % num_tiles_x;
  const size_t tile_id_y = tile_id / num_tiles_x;

  const size_t tile_offset =
      (tile_id_x * NVEC_IN + tile_id_y * row_length * NVEC_OUT) * THREADS_PER_WARP;
  const size_t tile_offset_transp =
      (tile_id_y * NVEC_OUT + tile_id_x * num_rows * NVEC_IN) * THREADS_PER_WARP;

  const IType *const my_input_tile = param.input + tile_offset;
  const IType2 *const my_act_input_tile = param.act_input + tile_offset;
  OType *const my_output_c_tile = param.output_c + tile_offset;
  OType *const my_output_t_tile = param.output_t + tile_offset_transp;
  CType *const my_partial_dbias_tile =
      param.workspace + (tile_id_x * (NVEC_IN * THREADS_PER_WARP) + tile_id_y * row_length);

  OVec *const my_scratch =
      reinterpret_cast<OVec *>(scratch) +
      (my_id_in_warp + warp_id / WARPS_PER_TILE * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);
  CVec *const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);

  IVec in[2][NVEC_OUT];
  IVec2 act_in[2][NVEC_OUT];
@@ -140,8 +134,8 @@ cast_transpose_fusion_kernel_optimized(const Param param,
  const size_t output_stride = num_rows / NVEC_OUT;
  size_t current_stride = warp_id_in_tile * n_iterations * NVEC_OUT * stride;
  size_t current_row = (tile_id_y * THREADS_PER_WARP + warp_id_in_tile * n_iterations) * NVEC_OUT;
  unsigned int my_place =
      (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
  CType amax = 0.0f;
  const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
@@ -151,20 +145,20 @@ cast_transpose_fusion_kernel_optimized(const Param param,
    partial_dbias.clear();
  }
#pragma unroll
  for (unsigned int i = 0; i < NVEC_OUT; ++i) {
    in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
    if constexpr (IS_DACT) {
      act_in[0][i].load_from(my_act_input_tile, current_stride + my_place + stride * i);
    }
  }
#pragma unroll
  for (unsigned int i = 0; i < n_iterations; ++i) {
    const size_t current_place = current_stride + my_place;
    const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
    const unsigned int current_in = (i + 1) % 2;
    if (i < n_iterations - 1) {
#pragma unroll
      for (unsigned int j = 0; j < NVEC_OUT; ++j) {
        const size_t ld_offset = current_stride + my_place_in + stride * (NVEC_OUT + j);
        in[current_in][j].load_from(my_input_tile, ld_offset);
@@ -174,45 +168,42 @@ cast_transpose_fusion_kernel_optimized(const Param param,
      }
    }

    CVec in_cast_fp32[NVEC_OUT];  // NOLINT(*)
#pragma unroll
    for (unsigned int j = 0; j < NVEC_OUT; ++j) {
#pragma unroll
      for (unsigned int k = 0; k < NVEC_IN; ++k) {
        if constexpr (IS_DACT) {
          in_cast_fp32[j].data.elt[k] =
              static_cast<CType>(in[current_in ^ 1][j].data.elt[k]) *
              Activation[DACT_TYPE](act_in[current_in ^ 1][j].data.elt[k], {});
        } else {
          in_cast_fp32[j].data.elt[k] = static_cast<CType>(in[current_in ^ 1][j].data.elt[k]);
        }
      }
    }

    const int dbias_shfl_src_lane =
        (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP;

    cast_and_transpose_regs_optimized(in_cast_fp32, out_space[i], partial_dbias, my_output_c_tile,
                                      current_place, stride, scale, amax, dbias_shfl_src_lane);

    my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
    current_stride += NVEC_OUT * stride;
    current_row += NVEC_OUT;
  }
#pragma unroll
  for (unsigned int i = 0; i < NVEC_IN; ++i) {
#pragma unroll
    for (unsigned int j = 0; j < n_iterations; ++j) {
      my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
                 THREADS_PER_WARP] = out_space[j][i];
    }
    __syncthreads();
    my_place =
        (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
    current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * NVEC_IN;
    for (unsigned int j = 0; j < n_iterations; ++j) {
      my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
                                                              current_stride + my_place);
@@ -226,10 +217,10 @@ cast_transpose_fusion_kernel_optimized(const Param param,
    my_dbias_scratch[threadIdx.x] = partial_dbias;
    __syncthreads();
    if (warp_id_in_tile == 0) {
#pragma unroll
      for (unsigned int i = 1; i < WARPS_PER_TILE; ++i) {
        CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
#pragma unroll
        for (unsigned int j = 0; j < NVEC_IN; ++j) {
          partial_dbias.data.elt[j] += tmp.data.elt[j];
        }
@@ -239,7 +230,7 @@ cast_transpose_fusion_kernel_optimized(const Param param,
  }

  // warp tile amax reduce
  const CType max_block = reduce_max<BLOCK_SIZE / THREADS_PER_WARP>(amax, warp_id);
  if (threadIdx.x == 0) {
    if (param.amax != nullptr) {
@@ -19,12 +19,9 @@ constexpr size_t block_size = __BLOCK_SIZE__;
}  // namespace

__global__ void __launch_bounds__(block_size)
    transpose_optimized_kernel(const Type* __restrict__ const input, const float* const noop,
                               Type* __restrict__ const output, const size_t row_length,
                               const size_t num_rows) {
  if (noop != nullptr && noop[0] == 1.0f) return;
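The `noop` check at the top of these kernels reads a single device-side float: the host can enqueue the transpose unconditionally and later turn it into a no-op by setting that flag to 1.0f from another kernel or copy, without changing the launch logic. A minimal sketch of the pattern with a made-up kernel (scale_by_two is not part of this diff):

// Illustration of a device-side cancellation flag.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void scale_by_two(const float *noop, float *data, int n) {
  if (noop != nullptr && noop[0] == 1.0f) return;  // every thread exits early
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= 2.0f;
}

int main() {
  const int n = 4;
  float *data, *noop;
  cudaMallocManaged(&data, n * sizeof(float));
  cudaMallocManaged(&noop, sizeof(float));
  for (int i = 0; i < n; ++i) data[i] = 1.0f;
  noop[0] = 1.0f;  // flag set: the launch below does nothing
  scale_by_two<<<1, 64>>>(noop, data, n);
  cudaDeviceSynchronize();
  std::printf("data[0] = %f\n", data[0]);  // still 1.0
  cudaFree(data);
  cudaFree(noop);
  return 0;
}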
@@ -63,17 +60,17 @@ transpose_optimized_kernel(const Type * __restrict__ const input,
  // Note: Each thread loads num_iterations subtiles and transposes in
  // registers.
  OVec local_output[nvec_in][num_iterations];
#pragma unroll
  for (size_t iter = 0; iter < num_iterations; ++iter) {
    const size_t i1 = tidy + iter * bdimy;
    const size_t j1 = tidx;
#pragma unroll
    for (size_t i2 = 0; i2 < nvec_out; ++i2) {
      const size_t row = tile_row + i1 * nvec_out + i2;
      const size_t col = tile_col + j1 * nvec_in;
      IVec local_input;
      local_input.load_from(&input[row * row_length + col]);
#pragma unroll
      for (size_t j2 = 0; j2 < nvec_in; ++j2) {
        local_output[j2][iter].data.elt[i2] = local_input.data.elt[j2];
      }
@@ -81,17 +78,17 @@ transpose_optimized_kernel(const Type * __restrict__ const input,
  }

  // Copy from registers to shared memory to global memory
  __shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#pragma unroll
  for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll
    for (size_t iter = 0; iter < num_iterations; ++iter) {
      const size_t i1 = tidy + iter * bdimy;
      const size_t j1 = tidx;
      shared_output[j1][i1] = local_output[j2][iter];
    }
    __syncthreads();
#pragma unroll
    for (size_t iter = 0; iter < num_iterations; ++iter) {
      const size_t i1 = tidx;
      const size_t j1 = tidy + iter * bdimy;
@@ -4,13 +4,12 @@
 * See LICENSE for license information.
 ************************************************************************/

#include <cuda_runtime.h>
#include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/transpose.h>

#include <algorithm>

#include "../common.h"
#include "../util/rtc.h"
#include "../util/string.h"
@@ -46,24 +45,18 @@ struct KernelConfig {
  /* Elements per L1 cache store */
  size_t elements_per_store = 0;

  KernelConfig(size_t row_length, size_t num_rows, size_t type_size, size_t load_size_,
               size_t store_size_)
      : load_size{load_size_}, store_size{store_size_} {
    // Check that tiles are correctly aligned
    constexpr size_t cache_line_size = 128;
    if (load_size % type_size != 0 || store_size % type_size != 0 ||
        cache_line_size % type_size != 0) {
      return;
    }
    const size_t row_tile_elements = load_size * THREADS_PER_WARP / type_size;
    const size_t col_tile_elements = store_size * THREADS_PER_WARP / type_size;
    valid = (row_length % row_tile_elements == 0 && num_rows % col_tile_elements == 0);
    if (!valid) {
      return;
    }
...@@ -75,10 +68,8 @@ struct KernelConfig { ...@@ -75,10 +68,8 @@ struct KernelConfig {
constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs
active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm), active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm),
static_cast<size_t>(cuda::sm_count())); static_cast<size_t>(cuda::sm_count()));
elements_per_load = (std::min(cache_line_size, row_tile_elements * type_size) elements_per_load = (std::min(cache_line_size, row_tile_elements * type_size) / type_size);
/ type_size); elements_per_store = (std::min(cache_line_size, col_tile_elements * type_size) / type_size);
elements_per_store = (std::min(cache_line_size, col_tile_elements * type_size)
/ type_size);
} }
/* Compare by estimated cost */ /* Compare by estimated cost */
...@@ -93,8 +84,8 @@ struct KernelConfig { ...@@ -93,8 +84,8 @@ struct KernelConfig {
const auto &s2 = other.elements_per_store; const auto &s2 = other.elements_per_store;
const auto &p2 = other.active_sm_count; const auto &p2 = other.active_sm_count;
const auto scale = l1 * s1 * p1 * l2 * s2 * p2; const auto scale = l1 * s1 * p1 * l2 * s2 * p2;
const auto cost1 = (scale/l1 + scale/s1) / p1; const auto cost1 = (scale / l1 + scale / s1) / p1;
const auto cost2 = (scale/l2 + scale/s2) / p2; const auto cost2 = (scale / l2 + scale / s2) / p2;
return cost1 < cost2; return cost1 < cost2;
} else { } else {
return this->valid && !other.valid; return this->valid && !other.valid;
...@@ -103,12 +94,9 @@ struct KernelConfig { ...@@ -103,12 +94,9 @@ struct KernelConfig {
}; };
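The comparison operator above ranks candidate (load_size, store_size) configurations by an estimated cost: roughly one term per L1 transaction on the load side plus one on the store side, kept in integer arithmetic via a common scale factor, and divided by the number of SMs the launch would keep busy. A standalone sketch of that selection heuristic, with hypothetical names and none of the alignment checks:

#include <algorithm>
#include <cstddef>
#include <vector>

struct Config {
  bool valid = false;
  size_t elements_per_load = 0, elements_per_store = 0, active_sm_count = 1;
};

// Smaller estimated cost wins: fewer L1 transactions per element, more SMs busy.
bool cheaper(const Config &a, const Config &b) {
  if (a.valid && b.valid) {
    const auto scale = a.elements_per_load * a.elements_per_store * a.active_sm_count *
                       b.elements_per_load * b.elements_per_store * b.active_sm_count;
    const auto cost_a =
        (scale / a.elements_per_load + scale / a.elements_per_store) / a.active_sm_count;
    const auto cost_b =
        (scale / b.elements_per_load + scale / b.elements_per_store) / b.active_sm_count;
    return cost_a < cost_b;
  }
  return a.valid && !b.valid;
}

// Assumes candidates is non-empty, like the config list built in transpose().
Config pick_best(const std::vector<Config> &candidates) {
  return *std::min_element(candidates.begin(), candidates.end(), cheaper);
}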
template <size_t load_size, size_t store_size, typename Type> template <size_t load_size, size_t store_size, typename Type>
__global__ void __global__ void __launch_bounds__(block_size)
__launch_bounds__(block_size) transpose_general_kernel(const Type *__restrict__ const input, const fp32 *const noop,
transpose_general_kernel(const Type * __restrict__ const input, Type *__restrict__ const output, const size_t row_length,
const fp32 * const noop,
Type * __restrict__ const output,
const size_t row_length,
const size_t num_rows) { const size_t num_rows) {
if (noop != nullptr && noop[0] == 1.0f) return; if (noop != nullptr && noop[0] == 1.0f) return;
...@@ -147,25 +135,25 @@ transpose_general_kernel(const Type * __restrict__ const input, ...@@ -147,25 +135,25 @@ transpose_general_kernel(const Type * __restrict__ const input,
// Note: Each thread loads num_iterations subtiles and transposes in // Note: Each thread loads num_iterations subtiles and transposes in
// registers. // registers.
OVec local_output[nvec_in][num_iterations]; OVec local_output[nvec_in][num_iterations];
#pragma unroll #pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) { for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy; const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx; const size_t j1 = tidx;
#pragma unroll #pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) { for (size_t i2 = 0; i2 < nvec_out; ++i2) {
const size_t row = tile_row + i1 * nvec_out + i2; const size_t row = tile_row + i1 * nvec_out + i2;
const size_t col = tile_col + j1 * nvec_in; const size_t col = tile_col + j1 * nvec_in;
IVec local_input; IVec local_input;
local_input.clear(); local_input.clear();
if (row < num_rows) { if (row < num_rows) {
#pragma unroll #pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) { for (size_t j2 = 0; j2 < nvec_in; ++j2) {
if (col + j2 < row_length) { if (col + j2 < row_length) {
local_input.data.elt[j2] = input[row * row_length + col + j2]; local_input.data.elt[j2] = input[row * row_length + col + j2];
} }
} }
} }
#pragma unroll #pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) { for (size_t j2 = 0; j2 < nvec_in; ++j2) {
local_output[j2][iter].data.elt[i2] = local_input.data.elt[j2]; local_output[j2][iter].data.elt[i2] = local_input.data.elt[j2];
} }
...@@ -173,24 +161,24 @@ transpose_general_kernel(const Type * __restrict__ const input, ...@@ -173,24 +161,24 @@ transpose_general_kernel(const Type * __restrict__ const input,
} }
// Copy transposed output from registers to global memory // Copy transposed output from registers to global memory
__shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP+1]; __shared__ OVec shared_output[THREADS_PER_WARP][THREADS_PER_WARP + 1];
#pragma unroll #pragma unroll
for (size_t j2 = 0; j2 < nvec_in; ++j2) { for (size_t j2 = 0; j2 < nvec_in; ++j2) {
#pragma unroll #pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) { for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidy + iter * bdimy; const size_t i1 = tidy + iter * bdimy;
const size_t j1 = tidx; const size_t j1 = tidx;
shared_output[j1][i1] = local_output[j2][iter]; shared_output[j1][i1] = local_output[j2][iter];
} }
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
for (size_t iter = 0; iter < num_iterations; ++iter) { for (size_t iter = 0; iter < num_iterations; ++iter) {
const size_t i1 = tidx; const size_t i1 = tidx;
const size_t j1 = tidy + iter * bdimy; const size_t j1 = tidy + iter * bdimy;
const size_t row = tile_row + i1 * nvec_out; const size_t row = tile_row + i1 * nvec_out;
const size_t col = tile_col + j1 * nvec_in + j2; const size_t col = tile_col + j1 * nvec_in + j2;
if (col < row_length) { if (col < row_length) {
#pragma unroll #pragma unroll
for (size_t i2 = 0; i2 < nvec_out; ++i2) { for (size_t i2 = 0; i2 < nvec_out; ++i2) {
if (row + i2 < num_rows) { if (row + i2 < num_rows) {
output[col * num_rows + row + i2] = shared_output[j1][i1].data.elt[i2]; output[col * num_rows + row + i2] = shared_output[j1][i1].data.elt[i2];
...@@ -204,10 +192,7 @@ transpose_general_kernel(const Type * __restrict__ const input, ...@@ -204,10 +192,7 @@ transpose_general_kernel(const Type * __restrict__ const input,
} // namespace } // namespace
void transpose(const Tensor &input, void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream) {
const Tensor &noop,
Tensor *output_,
cudaStream_t stream) {
Tensor &output = *output_; Tensor &output = *output_;
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output.data.shape.size() == 2, "Output must have 2 dimensions."); NVTE_CHECK(output.data.shape.size() == 2, "Output must have 2 dimensions.");
...@@ -219,64 +204,64 @@ void transpose(const Tensor &input, ...@@ -219,64 +204,64 @@ void transpose(const Tensor &input,
NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated."); NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated."); NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated.");
NVTE_CHECK(input.data.dtype == output.data.dtype, NVTE_CHECK(input.data.dtype == output.data.dtype, "Input and output type must match.");
"Input and output type must match.");
// Number of elements in tensor // Number of elements in tensor
auto numel = [] (const Tensor &tensor) -> size_t { auto numel = [](const Tensor &tensor) -> size_t {
size_t acc = 1; size_t acc = 1;
for (const auto& dim : tensor.data.shape) { for (const auto &dim : tensor.data.shape) {
acc *= dim; acc *= dim;
} }
return acc; return acc;
}; };
if (noop.data.dptr != nullptr) { if (noop.data.dptr != nullptr) {
NVTE_CHECK(numel(noop) == 1, NVTE_CHECK(numel(noop) == 1, "Expected 1 element, ", "but found ", numel(noop), ".");
"Expected 1 element, ",
"but found ", numel(noop), ".");
NVTE_CHECK(noop.data.dtype == DType::kFloat32); NVTE_CHECK(noop.data.dtype == DType::kFloat32);
NVTE_CHECK(noop.data.dptr != nullptr); NVTE_CHECK(noop.data.dptr != nullptr);
} }
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(input.data.dtype, Type, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
constexpr const char *type_name = TypeInfo<Type>::name; input.data.dtype, Type, constexpr const char *type_name = TypeInfo<Type>::name;
constexpr size_t type_size = sizeof(Type); constexpr size_t type_size = sizeof(Type);
// Choose between runtime-compiled or statically-compiled kernel // Choose between runtime-compiled or statically-compiled kernel
const bool aligned = (row_length % THREADS_PER_WARP == 0 const bool aligned = (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0);
&& num_rows % THREADS_PER_WARP == 0);
if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel
// Pick kernel config // Pick kernel config
std::vector<KernelConfig> kernel_configs; std::vector<KernelConfig> kernel_configs;
kernel_configs.reserve(16); kernel_configs.reserve(16);
auto add_config = [&](size_t load_size, size_t store_size) { auto add_config = [&](size_t load_size, size_t store_size) {
kernel_configs.emplace_back(row_length, num_rows, type_size, kernel_configs.emplace_back(row_length, num_rows, type_size, load_size, store_size);
load_size, store_size);
}; };
add_config(8, 8); add_config(8, 8);
add_config(4, 8); add_config(8, 4); add_config(4, 8);
add_config(8, 4);
add_config(4, 4); add_config(4, 4);
add_config(2, 8); add_config(8, 2); add_config(2, 8);
add_config(2, 4); add_config(4, 2); add_config(8, 2);
add_config(2, 4);
add_config(4, 2);
add_config(2, 2); add_config(2, 2);
add_config(1, 8); add_config(8, 1); add_config(1, 8);
add_config(1, 4); add_config(4, 1); add_config(8, 1);
add_config(1, 2); add_config(2, 1); add_config(1, 4);
add_config(4, 1);
add_config(1, 2);
add_config(2, 1);
add_config(1, 1); add_config(1, 1);
const auto &kernel_config = *std::min_element(kernel_configs.begin(), const auto &kernel_config = *std::min_element(kernel_configs.begin(), kernel_configs.end());
kernel_configs.end());
NVTE_CHECK(kernel_config.valid, "invalid kernel config"); NVTE_CHECK(kernel_config.valid, "invalid kernel config");
const size_t load_size = kernel_config.load_size; const size_t load_size = kernel_config.load_size;
const size_t store_size = kernel_config.store_size; const size_t store_size = kernel_config.store_size;
const size_t num_blocks = kernel_config.num_blocks; const size_t num_blocks = kernel_config.num_blocks;
// Compile NVRTC kernel if needed and launch // Compile NVRTC kernel if needed and launch
auto& rtc_manager = rtc::KernelManager::instance(); auto &rtc_manager = rtc::KernelManager::instance();
const std::string kernel_label = concat_strings("transpose" const std::string kernel_label = concat_strings(
",type=", type_name, "transpose"
",load_size=", load_size, ",type=",
",store_size=", store_size); type_name, ",load_size=", load_size, ",store_size=", store_size);
if (!rtc_manager.is_compiled(kernel_label)) { if (!rtc_manager.is_compiled(kernel_label)) {
std::string code = string_code_transpose_rtc_transpose_cu; std::string code = string_code_transpose_rtc_transpose_cu;
code = regex_replace(code, "__TYPE__", type_name); code = regex_replace(code, "__TYPE__", type_name);
...@@ -284,56 +269,41 @@ void transpose(const Tensor &input, ...@@ -284,56 +269,41 @@ void transpose(const Tensor &input,
code = regex_replace(code, "__STORE_SIZE__", store_size); code = regex_replace(code, "__STORE_SIZE__", store_size);
code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile); code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile);
code = regex_replace(code, "__BLOCK_SIZE__", block_size); code = regex_replace(code, "__BLOCK_SIZE__", block_size);
rtc_manager.compile(kernel_label, rtc_manager.compile(kernel_label, "transpose_optimized_kernel", code,
"transpose_optimized_kernel",
code,
"transformer_engine/common/transpose/rtc/transpose.cu"); "transformer_engine/common/transpose/rtc/transpose.cu");
} }
rtc_manager.launch(kernel_label, rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream,
num_blocks, block_size, 0, stream,
static_cast<const Type *>(input.data.dptr), static_cast<const Type *>(input.data.dptr),
static_cast<const fp32 *>(noop.data.dptr), static_cast<const fp32 *>(noop.data.dptr),
static_cast<Type*>(output.data.dptr), static_cast<Type *>(output.data.dptr), row_length, num_rows);
row_length, num_rows);
} else { // Statically-compiled general kernel } else { // Statically-compiled general kernel
constexpr size_t load_size = 4; constexpr size_t load_size = 4;
constexpr size_t store_size = 4; constexpr size_t store_size = 4;
constexpr size_t row_tile_size = load_size / type_size * THREADS_PER_WARP; constexpr size_t row_tile_size = load_size / type_size * THREADS_PER_WARP;
constexpr size_t col_tile_size = store_size / type_size * THREADS_PER_WARP; constexpr size_t col_tile_size = store_size / type_size * THREADS_PER_WARP;
const int num_blocks = (DIVUP(row_length, row_tile_size) const int num_blocks = (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size));
* DIVUP(num_rows, col_tile_size)); transpose_general_kernel<load_size, store_size, Type>
transpose_general_kernel<load_size, store_size, Type><<<num_blocks, block_size, 0, stream>>>( <<<num_blocks, block_size, 0, stream>>>(static_cast<const Type *>(input.data.dptr),
static_cast<const Type *>(input.data.dptr),
static_cast<const fp32 *>(noop.data.dptr), static_cast<const fp32 *>(noop.data.dptr),
static_cast<Type *>(output.data.dptr), static_cast<Type *>(output.data.dptr),
row_length, num_rows); row_length, num_rows);
} }); // NOLINT(*)
); // NOLINT(*)
} }
} // namespace transformer_engine } // namespace transformer_engine
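In the runtime-compiled path of transpose() above, the kernel source is specialized by replacing placeholder tokens (__TYPE__, __LOAD_SIZE__, __STORE_SIZE__, ...) in the stored source string before it is handed to NVRTC. A minimal sketch of that substitution step using only the standard library; substitute() here is a hypothetical helper, not the project's regex_replace:

#include <regex>
#include <string>

// Replace every occurrence of a placeholder token with a numeric value,
// e.g. substitute(code, "__LOAD_SIZE__", 8) before runtime compilation.
template <typename T>
std::string substitute(std::string code, const std::string &placeholder, const T &value) {
  return std::regex_replace(code, std::regex(placeholder), std::to_string(value));
}

// Usage sketch:
//   std::string code = kernel_template;  // source containing __TYPE__, __LOAD_SIZE__, ...
//   code = std::regex_replace(code, std::regex("__TYPE__"), "float");
//   code = substitute(code, "__LOAD_SIZE__", 8);
//   code = substitute(code, "__STORE_SIZE__", 8);
//   // ...then compile `code` with NVRTC and launch through the driver API.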
void nvte_transpose(const NVTETensor input, void nvte_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose); NVTE_API_CALL(nvte_transpose);
using namespace transformer_engine; using namespace transformer_engine;
auto noop = Tensor(); auto noop = Tensor();
transpose(*reinterpret_cast<const Tensor*>(input), transpose(*reinterpret_cast<const Tensor *>(input), noop, reinterpret_cast<Tensor *>(output),
noop,
reinterpret_cast<Tensor*>(output),
stream); stream);
} }
void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
void nvte_transpose_with_noop(const NVTETensor input,
const NVTETensor noop,
NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose_with_noop); NVTE_API_CALL(nvte_transpose_with_noop);
using namespace transformer_engine; using namespace transformer_engine;
transpose(*reinterpret_cast<const Tensor*>(input), transpose(*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(noop),
*reinterpret_cast<const Tensor*>(noop), reinterpret_cast<Tensor *>(output), stream);
reinterpret_cast<Tensor*>(output),
stream);
} }
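For validating either code path, a host-side reference that mirrors the index mapping used throughout this file (output[col * num_rows + row] = input[row * row_length + col]) is enough. This is a test-side sketch, not part of the library:

#include <cstddef>
#include <vector>

// Host reference: input is num_rows x row_length (row-major),
// output is row_length x num_rows (row-major).
template <typename T>
std::vector<T> reference_transpose(const std::vector<T> &input,
                                   size_t row_length, size_t num_rows) {
  std::vector<T> output(row_length * num_rows);
  for (size_t row = 0; row < num_rows; ++row) {
    for (size_t col = 0; col < row_length; ++col) {
      output[col * num_rows + row] = input[row * row_length + col];
    }
  }
  return output;
}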
...@@ -4,18 +4,19 @@ ...@@ -4,18 +4,19 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <transformer_engine/transpose.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <transformer_engine/transpose.h>
#include <cfloat> #include <cfloat>
#include <iostream> #include <iostream>
#include <type_traits> #include <type_traits>
#include "../utils.cuh"
#include "../common.h" #include "../common.h"
#include "../utils.cuh"
namespace transformer_engine { namespace transformer_engine {
template <int nvec_in, int nvec_out, template <int nvec_in, int nvec_out, typename IVec, typename OVec, typename CVec, typename CType>
typename IVec, typename OVec, typename CVec, typename CType>
inline __device__ void transpose_regs_partial_dbias(const IVec (&in)[nvec_out], inline __device__ void transpose_regs_partial_dbias(const IVec (&in)[nvec_out],
OVec (&out_trans)[nvec_in], OVec (&out_trans)[nvec_in],
CVec &out_dbias, // NOLINT(*) CVec &out_dbias, // NOLINT(*)
...@@ -24,7 +25,8 @@ inline __device__ void transpose_regs_partial_dbias(const IVec (&in)[nvec_out], ...@@ -24,7 +25,8 @@ inline __device__ void transpose_regs_partial_dbias(const IVec (&in)[nvec_out],
using T = typename OVec::type; using T = typename OVec::type;
using OVecC = Vec<T, nvec_in>; using OVecC = Vec<T, nvec_in>;
CVec step_dbias; step_dbias.clear(); CVec step_dbias;
step_dbias.clear();
#pragma unroll #pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) { for (unsigned int i = 0; i < nvec_out; ++i) {
...@@ -73,11 +75,8 @@ struct TDBiasParam { ...@@ -73,11 +75,8 @@ struct TDBiasParam {
} // namespace } // namespace
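Conceptually, transpose_regs_partial_dbias above does two things per register subtile: it writes the transposed elements to the output vectors and accumulates a per-input-column sum (the future dbias) in compute precision, applying scale_inv so the reduction runs on dequantized values. A scalar sketch of that per-subtile logic, stripped of the Vec types (illustrative only, not the actual implementation):

// Scalar view of one nvec_out x nvec_in subtile: transpose into out_trans
// while accumulating dequantized column sums into out_dbias.
template <int nvec_in, int nvec_out>
void transpose_and_accumulate(const float (&in)[nvec_out][nvec_in],
                              float (&out_trans)[nvec_in][nvec_out],
                              float (&out_dbias)[nvec_in], float scale_inv) {
  for (int i = 0; i < nvec_out; ++i) {
    for (int j = 0; j < nvec_in; ++j) {
      out_trans[j][i] = in[i][j];             // transposed copy keeps the storage values
      out_dbias[j] += in[i][j] * scale_inv;   // bias sum accumulates in compute precision
    }
  }
}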
template <int nvec_in, int nvec_out, typename Param> template <int nvec_in, int nvec_out, typename Param>
__global__ void __global__ void __launch_bounds__(cast_transpose_num_threads)
__launch_bounds__(cast_transpose_num_threads) transpose_dbias_kernel(const Param param, const size_t row_length, const size_t num_rows,
transpose_dbias_kernel(const Param param,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) { const size_t num_tiles) {
using IType = typename Param::InputType; using IType = typename Param::InputType;
using OType = typename Param::OutputType; using OType = typename Param::OutputType;
...@@ -92,27 +91,24 @@ transpose_dbias_kernel(const Param param, ...@@ -92,27 +91,24 @@ transpose_dbias_kernel(const Param param,
const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP); const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
// const size_t num_tiles_y = num_rows / (nvec * THREADS_PER_WARP); // const size_t num_tiles_y = num_rows / (nvec * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + const size_t tile_id =
warp_id / n_warps_per_tile; blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return; if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x; const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x; const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = param.input + (tile_id_x * nvec_in + const IType *const my_input_tile =
tile_id_y * row_length * nvec_out) * param.input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP;
THREADS_PER_WARP; OType *const my_output_t_tile =
OType * const my_output_t_tile = param.output_t + (tile_id_y * nvec_out + param.output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP;
tile_id_x * num_rows * nvec_in) * CType *const my_partial_dbias_tile =
THREADS_PER_WARP; param.workspace + (tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length);
CType * const my_partial_dbias_tile = param.workspace +
(tile_id_x * (nvec_in * THREADS_PER_WARP) +
tile_id_y * row_length);
OVec * const my_scratch = reinterpret_cast<OVec *>(scratch) + OVec *const my_scratch =
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * reinterpret_cast<OVec *>(scratch) +
(THREADS_PER_WARP + 1); (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch); CVec *const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][nvec_out]; IVec in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
...@@ -123,9 +119,8 @@ transpose_dbias_kernel(const Param param, ...@@ -123,9 +119,8 @@ transpose_dbias_kernel(const Param param,
const size_t stride = row_length / nvec_in; const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out; const size_t output_stride = num_rows / nvec_out;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - unsigned int my_place =
warp_id_in_tile * n_iterations) % (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
THREADS_PER_WARP;
const CType scale_inv = param.scale_inv != nullptr ? *param.scale_inv : 1; const CType scale_inv = param.scale_inv != nullptr ? *param.scale_inv : 1;
partial_dbias.clear(); partial_dbias.clear();
...@@ -147,10 +142,7 @@ transpose_dbias_kernel(const Param param, ...@@ -147,10 +142,7 @@ transpose_dbias_kernel(const Param param,
} }
OVec out_trans[nvec_in]; // NOLINT(*) OVec out_trans[nvec_in]; // NOLINT(*)
transpose_regs_partial_dbias( transpose_regs_partial_dbias(
in[current_in ^ 1], in[current_in ^ 1], out_trans, partial_dbias, scale_inv,
out_trans,
partial_dbias,
scale_inv,
(my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP); (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP);
#pragma unroll #pragma unroll
...@@ -164,14 +156,13 @@ transpose_dbias_kernel(const Param param, ...@@ -164,14 +156,13 @@ transpose_dbias_kernel(const Param param,
for (unsigned int i = 0; i < nvec_in; ++i) { for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll #pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) { for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP - my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i]; THREADS_PER_WARP] = out_space[j][i];
} }
__syncthreads(); __syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % my_place =
THREADS_PER_WARP; (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
current_stride = i * output_stride + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; j < n_iterations; ++j) { for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place); current_stride + my_place);
...@@ -199,12 +190,9 @@ transpose_dbias_kernel(const Param param, ...@@ -199,12 +190,9 @@ transpose_dbias_kernel(const Param param,
} }
template <int nvec_in, int nvec_out, typename Param> template <int nvec_in, int nvec_out, typename Param>
__global__ void __global__ void __launch_bounds__(cast_transpose_num_threads)
__launch_bounds__(cast_transpose_num_threads) transpose_dbias_kernel_notaligned(const Param param, const size_t row_length,
transpose_dbias_kernel_notaligned(const Param param, const size_t num_rows, const size_t num_tiles) {
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
using IType = typename Param::InputType; using IType = typename Param::InputType;
using OType = typename Param::OutputType; using OType = typename Param::OutputType;
using CType = typename Param::ComputeType; using CType = typename Param::ComputeType;
...@@ -216,38 +204,35 @@ transpose_dbias_kernel_notaligned(const Param param, ...@@ -216,38 +204,35 @@ transpose_dbias_kernel_notaligned(const Param param,
const int warp_id = threadIdx.x / THREADS_PER_WARP; const int warp_id = threadIdx.x / THREADS_PER_WARP;
const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) / const size_t num_tiles_x =
(nvec_in * THREADS_PER_WARP); (row_length + nvec_in * THREADS_PER_WARP - 1) / (nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + const size_t tile_id =
warp_id / n_warps_per_tile; blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return; if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x; const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x; const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = param.input + (tile_id_x * nvec_in + const IType *const my_input_tile =
tile_id_y * row_length * nvec_out) * param.input + (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) * THREADS_PER_WARP;
THREADS_PER_WARP; OType *const my_output_t_tile =
OType * const my_output_t_tile = param.output_t + (tile_id_y * nvec_out + param.output_t + (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) * THREADS_PER_WARP;
tile_id_x * num_rows * nvec_in) * CType *const my_partial_dbias_tile =
THREADS_PER_WARP; param.workspace + (tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length);
CType * const my_partial_dbias_tile = param.workspace +
(tile_id_x * (nvec_in * THREADS_PER_WARP) +
tile_id_y * row_length);
const size_t stride = row_length / nvec_in; const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out; const size_t output_stride = num_rows / nvec_out;
const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP; const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP;
const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP; const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP;
const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP const unsigned int tile_length =
: row_length_rest; row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_length_rest;
const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP const unsigned int tile_height =
: row_height_rest; row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP : row_height_rest;
OVec * const my_scratch = reinterpret_cast<OVec *>(scratch) + OVec *const my_scratch =
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * reinterpret_cast<OVec *>(scratch) +
(THREADS_PER_WARP + 1); (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * (THREADS_PER_WARP + 1);
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch); CVec *const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][nvec_out]; IVec in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
...@@ -256,16 +241,14 @@ transpose_dbias_kernel_notaligned(const Param param, ...@@ -256,16 +241,14 @@ transpose_dbias_kernel_notaligned(const Param param,
CVec partial_dbias; CVec partial_dbias;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - unsigned int my_place =
warp_id_in_tile * n_iterations) % (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
THREADS_PER_WARP;
const CType scale_inv = param.scale_inv != nullptr ? *param.scale_inv : 1; const CType scale_inv = param.scale_inv != nullptr ? *param.scale_inv : 1;
partial_dbias.clear(); partial_dbias.clear();
{ {
const bool valid_load = my_place < tile_length && const bool valid_load = my_place < tile_length && warp_id_in_tile * n_iterations < tile_height;
warp_id_in_tile * n_iterations < tile_height;
#pragma unroll #pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) { for (unsigned int i = 0; i < nvec_out; ++i) {
if (valid_load) { if (valid_load) {
...@@ -280,8 +263,8 @@ transpose_dbias_kernel_notaligned(const Param param, ...@@ -280,8 +263,8 @@ transpose_dbias_kernel_notaligned(const Param param,
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2; const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) { if (i < n_iterations - 1) {
const bool valid_load = my_place_in < tile_length && const bool valid_load =
warp_id_in_tile * n_iterations + i + 1 < tile_height; my_place_in < tile_length && warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll #pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) { for (unsigned int j = 0; j < nvec_out; ++j) {
if (valid_load) { if (valid_load) {
...@@ -294,10 +277,7 @@ transpose_dbias_kernel_notaligned(const Param param, ...@@ -294,10 +277,7 @@ transpose_dbias_kernel_notaligned(const Param param,
} }
OVec out_trans[nvec_in]; // NOLINT(*) OVec out_trans[nvec_in]; // NOLINT(*)
transpose_regs_partial_dbias( transpose_regs_partial_dbias(
in[current_in ^ 1], in[current_in ^ 1], out_trans, partial_dbias, scale_inv,
out_trans,
partial_dbias,
scale_inv,
(my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP); (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP);
#pragma unroll #pragma unroll
...@@ -311,14 +291,13 @@ transpose_dbias_kernel_notaligned(const Param param, ...@@ -311,14 +291,13 @@ transpose_dbias_kernel_notaligned(const Param param,
for (unsigned int i = 0; i < nvec_in; ++i) { for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll #pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) { for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP - my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations) %
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i]; THREADS_PER_WARP] = out_space[j][i];
} }
__syncthreads(); __syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % my_place =
THREADS_PER_WARP; (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
current_stride = i * output_stride + current_stride = i * output_stride + warp_id_in_tile * n_iterations * output_stride * nvec_in;
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
const bool valid_store = my_place < tile_height; const bool valid_store = my_place < tile_height;
if (valid_store) { if (valid_store) {
...@@ -352,13 +331,10 @@ transpose_dbias_kernel_notaligned(const Param param, ...@@ -352,13 +331,10 @@ transpose_dbias_kernel_notaligned(const Param param,
constexpr size_t reduce_dbias_num_threads = 256; constexpr size_t reduce_dbias_num_threads = 256;
template<int nvec, typename ComputeType, typename OutputType> template <int nvec, typename ComputeType, typename OutputType>
__global__ void __global__ void __launch_bounds__(reduce_dbias_num_threads)
__launch_bounds__(reduce_dbias_num_threads) reduce_dbias_kernel(OutputType *const dbias_output, const ComputeType *const dbias_partial,
reduce_dbias_kernel(OutputType* const dbias_output, const int row_length, const int num_rows) {
const ComputeType* const dbias_partial,
const int row_length,
const int num_rows) {
using ComputeVec = Vec<ComputeType, nvec>; using ComputeVec = Vec<ComputeType, nvec>;
using OutputVec = Vec<OutputType, nvec>; using OutputVec = Vec<OutputType, nvec>;
...@@ -366,13 +342,14 @@ reduce_dbias_kernel(OutputType* const dbias_output, ...@@ -366,13 +342,14 @@ reduce_dbias_kernel(OutputType* const dbias_output,
if (thread_id * nvec >= row_length) return; if (thread_id * nvec >= row_length) return;
const ComputeType* const thread_in_base = dbias_partial + thread_id * nvec; const ComputeType *const thread_in_base = dbias_partial + thread_id * nvec;
OutputType* const thread_out_base = dbias_output + thread_id * nvec; OutputType *const thread_out_base = dbias_output + thread_id * nvec;
const int stride_in_vec = row_length / nvec; const int stride_in_vec = row_length / nvec;
ComputeVec ldg_vec; ComputeVec ldg_vec;
ComputeVec acc_vec; acc_vec.clear(); ComputeVec acc_vec;
acc_vec.clear();
for (int i = 0; i < num_rows; ++i) { for (int i = 0; i < num_rows; ++i) {
ldg_vec.load_from(thread_in_base, i * stride_in_vec); ldg_vec.load_from(thread_in_base, i * stride_in_vec);
#pragma unroll #pragma unroll
...@@ -390,8 +367,7 @@ reduce_dbias_kernel(OutputType* const dbias_output, ...@@ -390,8 +367,7 @@ reduce_dbias_kernel(OutputType* const dbias_output,
} }
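The dbias computation is two-stage: the transpose kernels leave one partial row of column sums per tile row in the fp32 workspace, and reduce_dbias_kernel above then sums those partial rows down each column and casts to the bias type. Ignoring the nvec vectorization, the second stage is a strided column sum; a minimal CUDA sketch with one column per thread (hypothetical names):

#include <cuda_runtime.h>

// partial: num_rows x row_length (row-major) partial sums in fp32.
// out:     row_length final column sums, cast to the bias type OutT.
template <typename OutT>
__global__ void reduce_columns(OutT *out, const float *partial,
                               int row_length, int num_rows) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= row_length) return;
  float acc = 0.f;
  for (int row = 0; row < num_rows; ++row) {
    acc += partial[row * row_length + col];
  }
  out[col] = static_cast<OutT>(acc);
}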
void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/
Tensor* workspace, Tensor *workspace, const int nvec_out) {
const int nvec_out) {
const size_t row_length = input.data.shape[1]; const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0]; const size_t num_rows = input.data.shape[0];
...@@ -405,37 +381,28 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/ ...@@ -405,37 +381,28 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/
} }
template <typename BiasType> template <typename BiasType>
void reduce_dbias(const Tensor &workspace, Tensor *dbias, void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_length,
const size_t row_length, const size_t num_rows, const int nvec_out, const size_t num_rows, const int nvec_out, cudaStream_t stream) {
cudaStream_t stream) {
constexpr int reduce_dbias_store_bytes = 8; // stg.64 constexpr int reduce_dbias_store_bytes = 8; // stg.64
constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(BiasType); constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(BiasType);
NVTE_CHECK(row_length % reduce_dbias_nvec == 0, "Unsupported shape."); NVTE_CHECK(row_length % reduce_dbias_nvec == 0, "Unsupported shape.");
const size_t reduce_dbias_row_length = row_length; const size_t reduce_dbias_row_length = row_length;
const size_t reduce_dbias_num_rows = DIVUP(num_rows, const size_t reduce_dbias_num_rows =
static_cast<size_t>(nvec_out * DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
THREADS_PER_WARP)); const size_t reduce_dbias_num_blocks =
const size_t reduce_dbias_num_blocks = DIVUP(row_length, DIVUP(row_length, reduce_dbias_num_threads * reduce_dbias_nvec);
reduce_dbias_num_threads * reduce_dbias_nvec);
reduce_dbias_kernel<reduce_dbias_nvec, fp32, BiasType> reduce_dbias_kernel<reduce_dbias_nvec, fp32, BiasType>
<<<reduce_dbias_num_blocks, <<<reduce_dbias_num_blocks, reduce_dbias_num_threads, 0, stream>>>(
reduce_dbias_num_threads,
0,
stream>>>(
reinterpret_cast<BiasType *>(dbias->data.dptr), reinterpret_cast<BiasType *>(dbias->data.dptr),
reinterpret_cast<const fp32 *>(workspace.data.dptr), reinterpret_cast<const fp32 *>(workspace.data.dptr), reduce_dbias_row_length,
reduce_dbias_row_length,
reduce_dbias_num_rows); reduce_dbias_num_rows);
} }
void fp8_transpose_dbias(const Tensor &input, void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor *dbias,
Tensor *transposed_output, Tensor *workspace, cudaStream_t stream) {
Tensor *dbias,
Tensor *workspace,
cudaStream_t stream) {
CheckInputTensor(input, "fp8_transpose_dbias_input"); CheckInputTensor(input, "fp8_transpose_dbias_input");
CheckOutputTensor(*transposed_output, "transposed_output"); CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias"); CheckOutputTensor(*dbias, "dbias");
...@@ -450,11 +417,12 @@ void fp8_transpose_dbias(const Tensor &input, ...@@ -450,11 +417,12 @@ void fp8_transpose_dbias(const Tensor &input,
NVTE_CHECK(transposed_output->data.dtype == input.data.dtype, NVTE_CHECK(transposed_output->data.dtype == input.data.dtype,
"T output must have the same type as input."); "T output must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias."); NVTE_CHECK(dbias->data.shape == std::vector<size_t>{row_length}, "Wrong shape of DBias.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dbias->data.dtype, BiasType, TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(input.data.dtype, Type, dbias->data.dtype, BiasType,
constexpr int type_size = sizeof(Type); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
input.data.dtype, Type, constexpr int type_size = sizeof(Type);
constexpr int nvec_in = desired_load_size / type_size; constexpr int nvec_in = desired_load_size / type_size;
constexpr int nvec_out = desired_store_size / type_size; constexpr int nvec_out = desired_store_size / type_size;
...@@ -465,7 +433,8 @@ void fp8_transpose_dbias(const Tensor &input, ...@@ -465,7 +433,8 @@ void fp8_transpose_dbias(const Tensor &input,
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) * const size_t n_tiles =
DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP)); DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP; const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block); const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
...@@ -473,58 +442,45 @@ void fp8_transpose_dbias(const Tensor &input, ...@@ -473,58 +442,45 @@ void fp8_transpose_dbias(const Tensor &input,
const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 && const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
num_rows % (nvec_out * THREADS_PER_WARP) == 0; num_rows % (nvec_out * THREADS_PER_WARP) == 0;
using ComputeType = fp32; using ComputeType = fp32; constexpr size_t shared_size_transpose =
constexpr size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile * cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) * (THREADS_PER_WARP + 1) * sizeof(Vec<Type, nvec_out>);
sizeof(Vec<Type, nvec_out>); constexpr size_t shared_size_dbias =
constexpr size_t shared_size_dbias = cast_transpose_num_threads * cast_transpose_num_threads * sizeof(Vec<ComputeType, nvec_in>);
sizeof(Vec<ComputeType, nvec_in>);
static_assert(shared_size_transpose >= shared_size_dbias); static_assert(shared_size_transpose >= shared_size_dbias);
using Param = TDBiasParam<Type, Type, ComputeType>; using Param = TDBiasParam<Type, Type, ComputeType>; Param param;
Param param;
param.input = reinterpret_cast<const Type *>(input.data.dptr); param.input = reinterpret_cast<const Type *>(input.data.dptr);
param.output_t = reinterpret_cast<Type *>(transposed_output->data.dptr); param.output_t = reinterpret_cast<Type *>(transposed_output->data.dptr);
param.scale_inv = reinterpret_cast<const ComputeType *>(transposed_output->scale_inv.dptr); param.scale_inv =
reinterpret_cast<const ComputeType *>(transposed_output->scale_inv.dptr);
param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr); param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
if (full_tile) { if (full_tile) {
cudaFuncSetAttribute(transpose_dbias_kernel<nvec_in, nvec_out, Param>, cudaFuncSetAttribute(transpose_dbias_kernel<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
100);
transpose_dbias_kernel<nvec_in, nvec_out, Param> transpose_dbias_kernel<nvec_in, nvec_out, Param>
<<<n_blocks, <<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
cast_transpose_num_threads, param, row_length, num_rows, n_tiles);
shared_size_transpose,
stream>>>(param, row_length, num_rows, n_tiles);
} else { } else {
cudaFuncSetAttribute(transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>, cudaFuncSetAttribute(transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
100);
transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param> transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>
<<<n_blocks, <<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
cast_transpose_num_threads, param, row_length, num_rows, n_tiles);
shared_size_transpose,
stream>>>(param, row_length, num_rows, n_tiles);
} }
reduce_dbias<BiasType>(*workspace, dbias, row_length, num_rows, nvec_out, stream); reduce_dbias<BiasType>(*workspace, dbias, row_length, num_rows, nvec_out,
); // NOLINT(*) stream);); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
} }
} // namespace transformer_engine } // namespace transformer_engine
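The launch configuration in fp8_transpose_dbias above is pure tile arithmetic: each tile covers nvec_in * THREADS_PER_WARP columns by nvec_out * THREADS_PER_WARP rows, n_warps_per_tile warps cooperate on a tile, and the block count is rounded up so every warp has a tile share. A worked example with assumed values (nvec_in = nvec_out = 4, 512 threads per block, n_warps_per_tile = 4; the real constants are defined elsewhere in this file):

#include <cstddef>
#include <cstdio>

constexpr size_t DIVUP(size_t x, size_t y) { return (x + y - 1) / y; }

int main() {
  // Assumed values for illustration only.
  const size_t row_length = 4096, num_rows = 2048;
  const size_t nvec_in = 4, nvec_out = 4, threads_per_warp = 32;
  const size_t n_warps_per_tile = 4, threads_per_block = 512;

  const size_t n_tiles = DIVUP(row_length, nvec_in * threads_per_warp) *
                         DIVUP(num_rows, nvec_out * threads_per_warp);        // 32 * 16 = 512
  const size_t warps_per_block = threads_per_block / threads_per_warp;        // 16
  const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, warps_per_block);  // 128
  std::printf("tiles=%zu blocks=%zu\n", n_tiles, n_blocks);
  return 0;
}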
void nvte_fp8_transpose_dbias(const NVTETensor input, void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_output,
NVTETensor transposed_output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTETensor dbias,
NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_transpose_dbias); NVTE_API_CALL(nvte_fp8_transpose_dbias);
using namespace transformer_engine; using namespace transformer_engine;
fp8_transpose_dbias(*reinterpret_cast<const Tensor*>(input), fp8_transpose_dbias(
reinterpret_cast<Tensor*>(transposed_output), *reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor*>(dbias), reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
reinterpret_cast<Tensor*>(workspace),
stream);
} }
...@@ -5,9 +5,10 @@ ...@@ -5,9 +5,10 @@
************************************************************************/ ************************************************************************/
#include <transformer_engine/cast.h> #include <transformer_engine/cast.h>
#include "../common.h" #include "../common.h"
#include "../utils.cuh"
#include "../util/vectorized_pointwise.h" #include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
namespace transformer_engine { namespace transformer_engine {
...@@ -15,9 +16,7 @@ namespace detail { ...@@ -15,9 +16,7 @@ namespace detail {
struct Empty {}; struct Empty {};
__device__ inline fp32 identity(fp32 value, const Empty&) { __device__ inline fp32 identity(fp32 value, const Empty &) { return value; }
return value;
}
struct DequantizeParam { struct DequantizeParam {
const fp32 *scale_inv; const fp32 *scale_inv;
...@@ -29,83 +28,63 @@ __device__ inline fp32 dequantize_func(fp32 value, const DequantizeParam &param) ...@@ -29,83 +28,63 @@ __device__ inline fp32 dequantize_func(fp32 value, const DequantizeParam &param)
} // namespace detail } // namespace detail
void fp8_quantize(const Tensor &input, void fp8_quantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "cast_input"); CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output"); CheckOutputTensor(*output, "cast_output");
NVTE_CHECK(!is_fp8_dtype(input.data.dtype), NVTE_CHECK(!is_fp8_dtype(input.data.dtype), "Input must be in higher precision.");
"Input must be in higher precision.");
NVTE_CHECK(is_fp8_dtype(output->data.dtype), NVTE_CHECK(is_fp8_dtype(output->data.dtype), "Output must have FP8 type.");
"Output must have FP8 type.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
const size_t N = product(input.data.shape); const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output->data.dtype, OType, input.data.dtype, IType,
constexpr int nvec = 32 / sizeof(IType); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, detail::Empty, detail::identity>( VectorizedUnaryKernelLauncher<nvec, detail::Empty, detail::identity>(
reinterpret_cast<const IType*>(input.data.dptr), reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr), reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr), reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32*>(output->amax.dptr), reinterpret_cast<fp32 *>(output->amax.dptr), N, {},
N, stream);); // NOLINT(*)
{},
stream);
); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
} }
void fp8_dequantize(const Tensor &input, void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "cast_input"); CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output"); CheckOutputTensor(*output, "cast_output");
NVTE_CHECK(is_fp8_dtype(input.data.dtype), NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
"Input must have FP8 type.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
"Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
const size_t N = product(input.data.shape); const size_t N = product(input.data.shape);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(input.data.dtype, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(output->data.dtype, OType, input.data.dtype, IType,
constexpr int nvec = 32 / sizeof(OType); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output->data.dtype, OType, constexpr int nvec = 32 / sizeof(OType);
detail::DequantizeParam p; detail::DequantizeParam p;
p.scale_inv = reinterpret_cast<const fp32*>(input.scale_inv.dptr); p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>( VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
reinterpret_cast<const IType*>(input.data.dptr), reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr), reinterpret_cast<OType *>(output->data.dptr), nullptr, nullptr, N, p,
nullptr, stream);); // NOLINT(*)
nullptr,
N,
p,
stream);
); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
} }
} // namespace transformer_engine } // namespace transformer_engine
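For orientation, the quantize/dequantize pair above follows the usual FP8 scaling recipe: the cast records the running absolute maximum (amax) for the next scale update and writes x * scale in the FP8 format, while dequantize multiplies by scale_inv to recover the original range. A float-only sketch of that relationship, with saturation and FP8 rounding deliberately omitted, so this is conceptual rather than the kernel's exact math:

#include <algorithm>
#include <cmath>
#include <vector>

// Conceptual quantize: track amax and apply the scale. A real implementation
// would also saturate and round to an FP8 format such as E4M3.
void quantize(const std::vector<float> &x, float scale,
              std::vector<float> &q, float &amax) {
  amax = 0.f;
  q.resize(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    amax = std::max(amax, std::fabs(x[i]));
    q[i] = x[i] * scale;
  }
}

// Conceptual dequantize: multiply by scale_inv = 1 / scale.
void dequantize(const std::vector<float> &q, float scale_inv, std::vector<float> &x) {
  x.resize(q.size());
  for (size_t i = 0; i < q.size(); ++i) x[i] = q[i] * scale_inv;
}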
void nvte_fp8_quantize(const NVTETensor input, void nvte_fp8_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_quantize); NVTE_API_CALL(nvte_fp8_quantize);
using namespace transformer_engine; using namespace transformer_engine;
fp8_quantize(*reinterpret_cast<const Tensor*>(input), fp8_quantize(*reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(output),
reinterpret_cast<Tensor*>(output),
stream); stream);
} }
void nvte_fp8_dequantize(const NVTETensor input, void nvte_fp8_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_dequantize); NVTE_API_CALL(nvte_fp8_dequantize);
using namespace transformer_engine; using namespace transformer_engine;
fp8_dequantize(*reinterpret_cast<const Tensor*>(input), fp8_dequantize(*reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(output),
reinterpret_cast<Tensor*>(output),
stream); stream);
} }
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
************************************************************************/ ************************************************************************/
#include <dlfcn.h> #include <dlfcn.h>
#include <filesystem> #include <filesystem>
#include "../common.h" #include "../common.h"
...@@ -40,27 +41,21 @@ class Library { ...@@ -40,27 +41,21 @@ class Library {
#endif // _WIN32 or _WIN64 or __WINDOW__ #endif // _WIN32 or _WIN64 or __WINDOW__
} }
Library(const Library&) = delete; // move-only Library(const Library &) = delete; // move-only
Library(Library&& other) noexcept { Library(Library &&other) noexcept { swap(*this, other); }
swap(*this, other);
}
Library& operator=(Library other) noexcept { Library &operator=(Library other) noexcept {
// Copy-and-swap idiom // Copy-and-swap idiom
swap(*this, other); swap(*this, other);
return *this; return *this;
} }
friend void swap(Library& first, Library& second) noexcept; friend void swap(Library &first, Library &second) noexcept;
void *get() noexcept { void *get() noexcept { return handle_; }
return handle_;
}
const void *get() const noexcept { const void *get() const noexcept { return handle_; }
return handle_;
}
/*! \brief Get pointer corresponding to symbol in shared library */ /*! \brief Get pointer corresponding to symbol in shared library */
void *get_symbol(const char *symbol) { void *get_symbol(const char *symbol) {
...@@ -78,13 +73,13 @@ class Library { ...@@ -78,13 +73,13 @@ class Library {
void *handle_ = nullptr; void *handle_ = nullptr;
}; };
void swap(Library& first, Library& second) noexcept { void swap(Library &first, Library &second) noexcept {
using std::swap; using std::swap;
swap(first.handle_, second.handle_); swap(first.handle_, second.handle_);
} }
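The Library class above is a move-only RAII wrapper around a dlopen handle built on the copy-and-swap idiom: the move constructor swaps with a default-constructed object, and a single by-value assignment operator covers move assignment. A self-contained sketch of the same idiom with a generic handle type (illustrative names):

#include <utility>

class Handle {
 public:
  Handle() = default;
  explicit Handle(void *h) : handle_{h} {}
  ~Handle() { /* release handle_ here */ }

  Handle(const Handle &) = delete;                        // move-only
  Handle(Handle &&other) noexcept { swap(*this, other); }
  Handle &operator=(Handle other) noexcept {              // copy-and-swap
    swap(*this, other);
    return *this;
  }

  friend void swap(Handle &a, Handle &b) noexcept {
    using std::swap;
    swap(a.handle_, b.handle_);
  }

 private:
  void *handle_ = nullptr;
};

Because the copy constructor is deleted, the by-value parameter can only be move-constructed, so `a = std::move(b);` compiles while copying does not, which matches the intent of the original class.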
/*! \brief Lazily-initialized shared library for CUDA driver */ /*! \brief Lazily-initialized shared library for CUDA driver */
Library& cuda_driver_lib() { Library &cuda_driver_lib() {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
constexpr char lib_name[] = "nvcuda.dll"; constexpr char lib_name[] = "nvcuda.dll";
#else #else
...@@ -98,9 +93,7 @@ Library& cuda_driver_lib() { ...@@ -98,9 +93,7 @@ Library& cuda_driver_lib() {
namespace cuda_driver { namespace cuda_driver {
void *get_symbol(const char *symbol) { void *get_symbol(const char *symbol) { return cuda_driver_lib().get_symbol(symbol); }
return cuda_driver_lib().get_symbol(symbol);
}
} // namespace cuda_driver } // namespace cuda_driver
......
...@@ -7,10 +7,10 @@ ...@@ -7,10 +7,10 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_ #ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_ #define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_
#include <string>
#include <cuda.h> #include <cuda.h>
#include <string>
#include "../common.h" #include "../common.h"
#include "../util/string.h" #include "../util/string.h"
...@@ -35,7 +35,7 @@ void *get_symbol(const char *symbol); ...@@ -35,7 +35,7 @@ void *get_symbol(const char *symbol);
template <typename... ArgTs> template <typename... ArgTs>
inline CUresult call(const char *symbol, ArgTs... args) { inline CUresult call(const char *symbol, ArgTs... args) {
using FuncT = CUresult(ArgTs...); using FuncT = CUresult(ArgTs...);
FuncT *func = reinterpret_cast<FuncT*>(get_symbol(symbol)); FuncT *func = reinterpret_cast<FuncT *>(get_symbol(symbol));
return (*func)(args...); return (*func)(args...);
} }
...@@ -48,9 +48,7 @@ inline CUresult call(const char *symbol, ArgTs... args) { ...@@ -48,9 +48,7 @@ inline CUresult call(const char *symbol, ArgTs... args) {
const CUresult status_NVTE_CHECK_CUDA_DRIVER = (expr); \ const CUresult status_NVTE_CHECK_CUDA_DRIVER = (expr); \
if (status_NVTE_CHECK_CUDA_DRIVER != CUDA_SUCCESS) { \ if (status_NVTE_CHECK_CUDA_DRIVER != CUDA_SUCCESS) { \
const char *desc_NVTE_CHECK_CUDA_DRIVER; \ const char *desc_NVTE_CHECK_CUDA_DRIVER; \
::transformer_engine::cuda_driver::call( \ ::transformer_engine::cuda_driver::call("cuGetErrorString", status_NVTE_CHECK_CUDA_DRIVER, \
"cuGetErrorString", \
status_NVTE_CHECK_CUDA_DRIVER, \
&desc_NVTE_CHECK_CUDA_DRIVER); \ &desc_NVTE_CHECK_CUDA_DRIVER); \
NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER); \ NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER); \
} \ } \
...@@ -58,8 +56,7 @@ inline CUresult call(const char *symbol, ArgTs... args) { ...@@ -58,8 +56,7 @@ inline CUresult call(const char *symbol, ArgTs... args) {
#define NVTE_CALL_CHECK_CUDA_DRIVER(symbol, ...) \ #define NVTE_CALL_CHECK_CUDA_DRIVER(symbol, ...) \
do { \ do { \
NVTE_CHECK_CUDA_DRIVER( \ NVTE_CHECK_CUDA_DRIVER(::transformer_engine::cuda_driver::call(#symbol, __VA_ARGS__)); \
::transformer_engine::cuda_driver::call(#symbol, __VA_ARGS__)); \
} while (false) } while (false)
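The call<> template above casts the pointer returned by get_symbol to the requested function type, so NVTE_CALL_CHECK_CUDA_DRIVER(cuSomething, args...) becomes a checked call through the lazily loaded driver library. Reduced to plain dlfcn usage (POSIX only; cuDriverGetVersion chosen because its signature is just an int pointer, and plain int stands in for CUresult to stay self-contained):

#include <dlfcn.h>
#include <cstdio>

// Look up a symbol in an already-loaded library and call it through a
// typed function pointer.
int call_driver_get_version(void *lib_handle, int *version) {
  using FuncT = int(int *);
  void *sym = dlsym(lib_handle, "cuDriverGetVersion");
  if (sym == nullptr) {
    std::fprintf(stderr, "dlsym failed: %s\n", dlerror());
    return -1;
  }
  FuncT *func = reinterpret_cast<FuncT *>(sym);
  return (*func)(version);
}

// Usage sketch:
//   void *lib = dlopen("libcuda.so.1", RTLD_LAZY);
//   int version = 0;
//   if (lib != nullptr) call_driver_get_version(lib, &version);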
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_ #endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_
...@@ -4,12 +4,13 @@ ...@@ -4,12 +4,13 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "../util/cuda_runtime.h"
#include <filesystem> #include <filesystem>
#include <mutex> #include <mutex>
#include "../common.h" #include "../common.h"
#include "../util/cuda_driver.h" #include "../util/cuda_driver.h"
#include "../util/cuda_runtime.h"
#include "../util/system.h" #include "../util/system.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -24,7 +25,7 @@ namespace { ...@@ -24,7 +25,7 @@ namespace {
} // namespace } // namespace
int num_devices() { int num_devices() {
auto query_num_devices = [] () -> int { auto query_num_devices = []() -> int {
int count; int count;
NVTE_CHECK_CUDA(cudaGetDeviceCount(&count)); NVTE_CHECK_CUDA(cudaGetDeviceCount(&count));
return count; return count;
...@@ -54,10 +55,10 @@ int sm_arch(int device_id) { ...@@ -54,10 +55,10 @@ int sm_arch(int device_id) {
device_id = current_device(); device_id = current_device();
} }
NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID"); NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
auto init = [&] () { auto init = [&]() {
cudaDeviceProp prop; cudaDeviceProp prop;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id)); NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id));
cache[device_id] = 10*prop.major + prop.minor; cache[device_id] = 10 * prop.major + prop.minor;
}; };
std::call_once(flags[device_id], init); std::call_once(flags[device_id], init);
return cache[device_id]; return cache[device_id];
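num_devices and sm_arch above (and sm_count below) share the same caching pattern: a static per-device array guarded by std::call_once, so each cudaGetDeviceProperties query runs at most once per device and is safe under concurrent callers. A stripped-down sketch with a fixed assumed maximum device count and error checking omitted:

#include <cuda_runtime.h>
#include <mutex>

// Query and cache the SM count for a device; the query runs once per device.
int cached_sm_count(int device_id) {
  constexpr int kMaxDevices = 64;  // assumption for the sketch
  static int cache[kMaxDevices];
  static std::once_flag flags[kMaxDevices];
  std::call_once(flags[device_id], [&]() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device_id);
    cache[device_id] = prop.multiProcessorCount;
  });
  return cache[device_id];
}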
...@@ -70,7 +71,7 @@ int sm_count(int device_id) { ...@@ -70,7 +71,7 @@ int sm_count(int device_id) {
device_id = current_device(); device_id = current_device();
} }
NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID"); NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID");
auto init = [&] () { auto init = [&]() {
cudaDeviceProp prop; cudaDeviceProp prop;
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id)); NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, device_id));
cache[device_id] = prop.multiProcessorCount; cache[device_id] = prop.multiProcessorCount;
...@@ -90,8 +91,7 @@ const std::string &include_directory(bool required) { ...@@ -90,8 +91,7 @@ const std::string &include_directory(bool required) {
if (need_to_check_env) { if (need_to_check_env) {
// Search for CUDA headers in common paths // Search for CUDA headers in common paths
using Path = std::filesystem::path; using Path = std::filesystem::path;
std::vector<std::pair<std::string, Path>> search_paths = { std::vector<std::pair<std::string, Path>> search_paths = {{"NVTE_CUDA_INCLUDE_DIR", ""},
{"NVTE_CUDA_INCLUDE_DIR", ""},
{"CUDA_HOME", ""}, {"CUDA_HOME", ""},
{"CUDA_DIR", ""}, {"CUDA_DIR", ""},
{"", string_path_cuda_include}, {"", string_path_cuda_include},
...@@ -131,7 +131,8 @@ const std::string &include_directory(bool required) { ...@@ -131,7 +131,8 @@ const std::string &include_directory(bool required) {
message += p; message += p;
} }
} }
message += (". " message +=
(". "
"Specify path to CUDA Toolkit headers " "Specify path to CUDA Toolkit headers "
"with NVTE_CUDA_INCLUDE_DIR " "with NVTE_CUDA_INCLUDE_DIR "
"or disable NVRTC support with NVTE_DISABLE_NVRTC=1."); "or disable NVRTC support with NVTE_DISABLE_NVRTC=1.");
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_ #define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_RUNTIME_H_
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <string> #include <string>
namespace transformer_engine { namespace transformer_engine {
......
...@@ -7,21 +7,19 @@ ...@@ -7,21 +7,19 @@
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ #ifndef TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ #define TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
#include <stdexcept>
#include <cublas_v2.h> #include <cublas_v2.h>
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <cudnn.h> #include <cudnn.h>
#include <nvrtc.h> #include <nvrtc.h>
#include <stdexcept>
#include "../util/string.h" #include "../util/string.h"
#define NVTE_ERROR(...) \ #define NVTE_ERROR(...) \
do { \ do { \
throw ::std::runtime_error( \ throw ::std::runtime_error(::transformer_engine::concat_strings( \
::transformer_engine::concat_strings( \ __FILE__ ":", __LINE__, " in function ", __func__, ": ", \
__FILE__ ":", __LINE__, \
" in function ", __func__, ": ", \
::transformer_engine::concat_strings(__VA_ARGS__))); \ ::transformer_engine::concat_strings(__VA_ARGS__))); \
} while (false) } while (false)
...@@ -37,8 +35,7 @@ ...@@ -37,8 +35,7 @@
do { \ do { \
const cudaError_t status_NVTE_CHECK_CUDA = (expr); \ const cudaError_t status_NVTE_CHECK_CUDA = (expr); \
if (status_NVTE_CHECK_CUDA != cudaSuccess) { \ if (status_NVTE_CHECK_CUDA != cudaSuccess) { \
NVTE_ERROR("CUDA Error: ", \ NVTE_ERROR("CUDA Error: ", cudaGetErrorString(status_NVTE_CHECK_CUDA)); \
cudaGetErrorString(status_NVTE_CHECK_CUDA)); \
} \ } \
} while (false) } while (false)
...@@ -46,8 +43,7 @@ ...@@ -46,8 +43,7 @@
do { \ do { \
const cublasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \ const cublasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \
if (status_NVTE_CHECK_CUBLAS != CUBLAS_STATUS_SUCCESS) { \ if (status_NVTE_CHECK_CUBLAS != CUBLAS_STATUS_SUCCESS) { \
NVTE_ERROR("cuBLAS Error: ", \ NVTE_ERROR("cuBLAS Error: ", cublasGetStatusString(status_NVTE_CHECK_CUBLAS)); \
cublasGetStatusString(status_NVTE_CHECK_CUBLAS)); \
} \ } \
} while (false) } while (false)
...@@ -55,8 +51,7 @@ ...@@ -55,8 +51,7 @@
do { \ do { \
const cudnnStatus_t status_NVTE_CHECK_CUDNN = (expr); \ const cudnnStatus_t status_NVTE_CHECK_CUDNN = (expr); \
if (status_NVTE_CHECK_CUDNN != CUDNN_STATUS_SUCCESS) { \ if (status_NVTE_CHECK_CUDNN != CUDNN_STATUS_SUCCESS) { \
NVTE_ERROR("cuDNN Error: ", \ NVTE_ERROR("cuDNN Error: ", cudnnGetErrorString(status_NVTE_CHECK_CUDNN), \
cudnnGetErrorString(status_NVTE_CHECK_CUDNN), \
". " \ ". " \
"For more information, enable cuDNN error logging " \ "For more information, enable cuDNN error logging " \
"by setting CUDNN_LOGERR_DBG=1 and " \ "by setting CUDNN_LOGERR_DBG=1 and " \
...@@ -68,8 +63,7 @@ ...@@ -68,8 +63,7 @@
do { \ do { \
const auto error = (expr); \ const auto error = (expr); \
if (error.is_bad()) { \ if (error.is_bad()) { \
NVTE_ERROR("cuDNN Error: ", \ NVTE_ERROR("cuDNN Error: ", error.err_msg, \
error.err_msg, \
". " \ ". " \
"For more information, enable cuDNN error logging " \ "For more information, enable cuDNN error logging " \
"by setting CUDNN_LOGERR_DBG=1 and " \ "by setting CUDNN_LOGERR_DBG=1 and " \
...@@ -81,8 +75,7 @@ ...@@ -81,8 +75,7 @@
do { \ do { \
const nvrtcResult status_NVTE_CHECK_NVRTC = (expr); \ const nvrtcResult status_NVTE_CHECK_NVRTC = (expr); \
if (status_NVTE_CHECK_NVRTC != NVRTC_SUCCESS) { \ if (status_NVTE_CHECK_NVRTC != NVRTC_SUCCESS) { \
NVTE_ERROR("NVRTC Error: ", \ NVTE_ERROR("NVRTC Error: ", nvrtcGetErrorString(status_NVTE_CHECK_NVRTC)); \
nvrtcGetErrorString(status_NVTE_CHECK_NVRTC)); \
} \ } \
} while (false) } while (false)
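The status-check macros above all share one shape; a minimal standalone variant (assumed name MY_CHECK_CUDA, not the library macro) shows why the do { ... } while (false) wrapper matters: the expansion behaves as a single statement, so it composes safely with if/else and a trailing semicolon.

#include <cuda_runtime_api.h>
#include <stdexcept>
#include <string>

#define MY_CHECK_CUDA(expr)                                                  \
  do {                                                                       \
    const cudaError_t status_ = (expr);                                      \
    if (status_ != cudaSuccess) {                                            \
      throw std::runtime_error(std::string(__FILE__ ":") +                   \
                               std::to_string(__LINE__) + ": CUDA Error: " + \
                               cudaGetErrorString(status_));                 \
    }                                                                        \
  } while (false)

int main() {
  int count = 0;
  MY_CHECK_CUDA(cudaGetDeviceCount(&count));  // throws with file/line context on failure
  return count > 0 ? 0 : 1;
}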
......
...@@ -21,8 +21,7 @@ template <typename OType, typename IType> ...@@ -21,8 +21,7 @@ template <typename OType, typename IType>
__device__ inline OType dgelu(const IType val, const Empty&) { __device__ inline OType dgelu(const IType val, const Empty&) {
const float cval = val; const float cval = val;
const float tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval)); const float tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
return 0.5f * cval * ((1.f - tanh_out * tanh_out) * return 0.5f * cval * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * cval * cval)) +
(0.79788456f + 0.1070322243f * cval * cval)) +
0.5f * (1.f + tanh_out); 0.5f * (1.f + tanh_out);
} }
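For reference, the constants in dgelu follow from differentiating the tanh approximation of GELU (standard calculus, not part of this change):

\[
\mathrm{gelu}(x) \approx \tfrac{1}{2}\,x\,\bigl(1 + \tanh g(x)\bigr), \qquad
g(x) = 0.79788456\,\bigl(x + 0.044715\,x^{3}\bigr), \qquad 0.79788456 \approx \sqrt{2/\pi},
\]
\[
\frac{d}{dx}\,\mathrm{gelu}(x)
  = \tfrac{1}{2}\bigl(1 + \tanh g(x)\bigr)
  + \tfrac{1}{2}\,x\,\bigl(1 - \tanh^{2} g(x)\bigr)\,g'(x), \qquad
g'(x) = 0.79788456\,\bigl(1 + 3 \cdot 0.044715\,x^{2}\bigr) \approx 0.79788456 + 0.1070322243\,x^{2},
\]

which matches the expression returned above.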
...@@ -48,8 +47,7 @@ __device__ inline OType qgelu(const IType val, const Empty& e) { ...@@ -48,8 +47,7 @@ __device__ inline OType qgelu(const IType val, const Empty& e) {
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType dqgelu(const IType val, const Empty& e) { __device__ inline OType dqgelu(const IType val, const Empty& e) {
const float cval = val; const float cval = val;
return cval * dsigmoid<float, float>(1.702f * cval, e) + return cval * dsigmoid<float, float>(1.702f * cval, e) + sigmoid<float, float>(1.702f * cval, e);
sigmoid<float, float>(1.702f * cval, e);
} }
template <typename OType, typename IType> template <typename OType, typename IType>
...@@ -65,22 +63,22 @@ __device__ inline OType dsilu(const IType val, const Empty& e) { ...@@ -65,22 +63,22 @@ __device__ inline OType dsilu(const IType val, const Empty& e) {
} }
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType relu(IType value, const Empty &) { __device__ inline OType relu(IType value, const Empty&) {
return fmaxf(value, 0.f); return fmaxf(value, 0.f);
} }
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType drelu(IType value, const Empty &) { __device__ inline OType drelu(IType value, const Empty&) {
return value > 0.f ? 1.f : 0.f; return value > 0.f ? 1.f : 0.f;
} }
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType srelu(IType value, const Empty &) { __device__ inline OType srelu(IType value, const Empty&) {
return value > 0 ? value * value : 0.f; return value > 0 ? value * value : 0.f;
} }
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType dsrelu(IType value, const Empty &) { __device__ inline OType dsrelu(IType value, const Empty&) {
return fmaxf(2.f * value, 0.f); return fmaxf(2.f * value, 0.f);
} }
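A small host-only sanity sketch (not part of the library) that checks dsrelu against a central finite difference of srelu; since srelu(x) = x^2 for x > 0, its derivative is 2x for x > 0 and 0 otherwise, i.e. fmaxf(2x, 0):

#include <cmath>
#include <cstdio>

float srelu(float x) { return x > 0.f ? x * x : 0.f; }
float dsrelu(float x) { return std::fmax(2.f * x, 0.f); }

int main() {
  const float h = 1e-3f;
  for (float x : {-2.f, -0.5f, 0.5f, 2.f}) {
    // The central difference approximates the analytic derivative away from x = 0.
    const float numeric = (srelu(x + h) - srelu(x - h)) / (2.f * h);
    std::printf("x=% .2f  analytic=% .3f  numeric=% .3f\n", x, dsrelu(x), numeric);
  }
  return 0;
}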
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "../util/rtc.h"
#include <cstdlib> #include <cstdlib>
#include <iostream> #include <iostream>
#include <utility> #include <utility>
...@@ -13,8 +15,6 @@ ...@@ -13,8 +15,6 @@
#include "../util/string.h" #include "../util/string.h"
#include "../util/system.h" #include "../util/system.h"
#include "../util/rtc.h"
namespace transformer_engine { namespace transformer_engine {
namespace rtc { namespace rtc {
...@@ -22,8 +22,8 @@ namespace rtc { ...@@ -22,8 +22,8 @@ namespace rtc {
namespace { namespace {
// Strings with headers for RTC kernels // Strings with headers for RTC kernels
#include "string_code_utils_cuh.h"
#include "string_code_util_math_h.h" #include "string_code_util_math_h.h"
#include "string_code_utils_cuh.h"
/*! \brief Latest compute capability that NVRTC supports /*! \brief Latest compute capability that NVRTC supports
* *
...@@ -56,29 +56,25 @@ bool is_enabled() { ...@@ -56,29 +56,25 @@ bool is_enabled() {
} }
Kernel::Kernel(std::string mangled_name, std::string compiled_code) Kernel::Kernel(std::string mangled_name, std::string compiled_code)
: mangled_name_{std::move(mangled_name)} : mangled_name_{std::move(mangled_name)},
, compiled_code_{std::move(compiled_code)} compiled_code_{std::move(compiled_code)},
, modules_(cuda::num_devices(), null_module) modules_(cuda::num_devices(), null_module),
, functions_(cuda::num_devices(), null_function) functions_(cuda::num_devices(), null_function),
, init_flags_{std::make_unique<std::vector<std::once_flag>>(cuda::num_devices())} { init_flags_{std::make_unique<std::vector<std::once_flag>>(cuda::num_devices())} {}
}
Kernel::~Kernel() { Kernel::~Kernel() {
for (int device_id=0; device_id<static_cast<int>(modules_.size()); ++device_id) { for (int device_id = 0; device_id < static_cast<int>(modules_.size()); ++device_id) {
// Unload CUDA modules if needed // Unload CUDA modules if needed
if (modules_[device_id] != null_module) { if (modules_[device_id] != null_module) {
CUdevice device; CUdevice device;
CUcontext context; CUcontext context;
if (cuda_driver::call("cuDeviceGet", &device, device_id) if (cuda_driver::call("cuDeviceGet", &device, device_id) != CUDA_SUCCESS) {
!= CUDA_SUCCESS) {
continue; continue;
} }
if (cuda_driver::call("cuDevicePrimaryCtxRetain", &context, device) if (cuda_driver::call("cuDevicePrimaryCtxRetain", &context, device) != CUDA_SUCCESS) {
!= CUDA_SUCCESS) {
continue; continue;
} }
if (cuda_driver::call("cuCtxSetCurrent", context) if (cuda_driver::call("cuCtxSetCurrent", context) != CUDA_SUCCESS) {
!= CUDA_SUCCESS) {
continue; continue;
} }
cuda_driver::call("cuModuleUnload", modules_[device_id]); cuda_driver::call("cuModuleUnload", modules_[device_id]);
...@@ -87,9 +83,7 @@ Kernel::~Kernel() { ...@@ -87,9 +83,7 @@ Kernel::~Kernel() {
} }
} }
Kernel::Kernel(Kernel&& other) noexcept { Kernel::Kernel(Kernel&& other) noexcept { swap(*this, other); }
swap(*this, other);
}
Kernel& Kernel::operator=(Kernel other) noexcept { Kernel& Kernel::operator=(Kernel other) noexcept {
// Copy-and-swap idiom // Copy-and-swap idiom
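A compact sketch of the copy-and-swap idiom the comment above refers to (illustrative Handle type, not rtc::Kernel): taking the assignment operand by value lets a single operator= cover both copy- and move-assignment, with swap providing the strong exception guarantee.

#include <utility>
#include <vector>

class Handle {
 public:
  Handle() = default;
  explicit Handle(std::vector<int> data) : data_{std::move(data)} {}
  Handle(const Handle &) = default;
  Handle(Handle &&other) noexcept { swap(*this, other); }
  Handle &operator=(Handle other) noexcept {  // by value: the copy or move happens at the call site
    swap(*this, other);
    return *this;
  }
  friend void swap(Handle &a, Handle &b) noexcept { std::swap(a.data_, b.data_); }

 private:
  std::vector<int> data_;
};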
...@@ -108,7 +102,7 @@ void swap(Kernel& first, Kernel& second) noexcept { ...@@ -108,7 +102,7 @@ void swap(Kernel& first, Kernel& second) noexcept {
CUfunction Kernel::get_function(int device_id) { CUfunction Kernel::get_function(int device_id) {
// Load kernel on device if needed // Load kernel on device if needed
auto load_on_device = [&] () { auto load_on_device = [&]() {
// Set driver context to proper device // Set driver context to proper device
CUdevice device; CUdevice device;
CUcontext context; CUcontext context;
...@@ -117,15 +111,11 @@ CUfunction Kernel::get_function(int device_id) { ...@@ -117,15 +111,11 @@ CUfunction Kernel::get_function(int device_id) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context); NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context);
// Load function into driver context // Load function into driver context
NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleLoadDataEx, NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleLoadDataEx, &modules_[device_id], compiled_code_.c_str(),
&modules_[device_id],
compiled_code_.c_str(),
0, // numOptions 0, // numOptions
nullptr, // options nullptr, // options
nullptr); // optionValues nullptr); // optionValues
NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleGetFunction, NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleGetFunction, &functions_[device_id], modules_[device_id],
&functions_[device_id],
modules_[device_id],
mangled_name_.c_str()); mangled_name_.c_str());
// Reset driver context // Reset driver context
...@@ -147,10 +137,8 @@ KernelManager& KernelManager::instance() { ...@@ -147,10 +137,8 @@ KernelManager& KernelManager::instance() {
return instance_; return instance_;
} }
void KernelManager::compile(const std::string &kernel_label, void KernelManager::compile(const std::string& kernel_label, const std::string& kernel_name,
const std::string &kernel_name, const std::string& code, const std::string& filename) {
const std::string &code,
const std::string &filename) {
std::lock_guard<std::mutex> lock_guard_(lock_); std::lock_guard<std::mutex> lock_guard_(lock_);
// Choose whether to compile to PTX or cubin // Choose whether to compile to PTX or cubin
...@@ -181,20 +169,14 @@ void KernelManager::compile(const std::string &kernel_label, ...@@ -181,20 +169,14 @@ void KernelManager::compile(const std::string &kernel_label,
constexpr int num_headers = 2; constexpr int num_headers = 2;
constexpr const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h}; constexpr const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h};
constexpr const char* include_names[num_headers] = {"utils.cuh", "util/math.h"}; constexpr const char* include_names[num_headers] = {"utils.cuh", "util/math.h"};
NVTE_CHECK_NVRTC(nvrtcCreateProgram(&program, NVTE_CHECK_NVRTC(nvrtcCreateProgram(&program, code.c_str(), filename.c_str(), num_headers,
code.c_str(), headers, include_names));
filename.c_str(),
num_headers,
headers,
include_names));
NVTE_CHECK_NVRTC(nvrtcAddNameExpression(program, kernel_name.c_str())); NVTE_CHECK_NVRTC(nvrtcAddNameExpression(program, kernel_name.c_str()));
const nvrtcResult compile_result = nvrtcCompileProgram(program, const nvrtcResult compile_result =
opts_ptrs.size(), nvrtcCompileProgram(program, opts_ptrs.size(), opts_ptrs.data());
opts_ptrs.data());
if (compile_result != NVRTC_SUCCESS) { if (compile_result != NVRTC_SUCCESS) {
// Display log if compilation failed // Display log if compilation failed
std::string log = concat_strings("NVRTC compilation log for ", std::string log = concat_strings("NVRTC compilation log for ", filename, ":\n");
filename, ":\n");
const size_t log_offset = log.size(); const size_t log_offset = log.size();
size_t log_size; size_t log_size;
NVTE_CHECK_NVRTC(nvrtcGetProgramLogSize(program, &log_size)); NVTE_CHECK_NVRTC(nvrtcGetProgramLogSize(program, &log_size));
...@@ -206,10 +188,8 @@ void KernelManager::compile(const std::string &kernel_label, ...@@ -206,10 +188,8 @@ void KernelManager::compile(const std::string &kernel_label,
} }
// Get mangled function name // Get mangled function name
const char *mangled_name; const char* mangled_name;
NVTE_CHECK_NVRTC(nvrtcGetLoweredName(program, NVTE_CHECK_NVRTC(nvrtcGetLoweredName(program, kernel_name.c_str(), &mangled_name));
kernel_name.c_str(),
&mangled_name));
// Get compiled code // Get compiled code
std::string compiled_code; std::string compiled_code;
...@@ -234,20 +214,19 @@ void KernelManager::compile(const std::string &kernel_label, ...@@ -234,20 +214,19 @@ void KernelManager::compile(const std::string &kernel_label,
NVTE_CHECK_NVRTC(nvrtcDestroyProgram(&program)); NVTE_CHECK_NVRTC(nvrtcDestroyProgram(&program));
} }
void KernelManager::set_cache_config(const std::string &kernel_label, CUfunc_cache cache_config) { void KernelManager::set_cache_config(const std::string& kernel_label, CUfunc_cache cache_config) {
const int device_id = cuda::current_device(); const int device_id = cuda::current_device();
const auto key = get_kernel_cache_key(kernel_label, device_id); const auto key = get_kernel_cache_key(kernel_label, device_id);
NVTE_CHECK(kernel_cache_.count(key) > 0, NVTE_CHECK(kernel_cache_.count(key) > 0, "Attempted to configure RTC kernel before compilation");
"Attempted to configure RTC kernel before compilation");
kernel_cache_.at(key).set_function_cache_config(device_id, cache_config); kernel_cache_.at(key).set_function_cache_config(device_id, cache_config);
} }
bool KernelManager::is_compiled(const std::string &kernel_label, int device_id) const { bool KernelManager::is_compiled(const std::string& kernel_label, int device_id) const {
const auto key = get_kernel_cache_key(kernel_label, device_id); const auto key = get_kernel_cache_key(kernel_label, device_id);
return kernel_cache_.count(key) > 0; return kernel_cache_.count(key) > 0;
} }
std::string KernelManager::get_kernel_cache_key(const std::string &kernel_label, std::string KernelManager::get_kernel_cache_key(const std::string& kernel_label,
int device_id) const { int device_id) const {
return concat_strings("sm=", cuda::sm_arch(device_id), ",", kernel_label); return concat_strings("sm=", cuda::sm_arch(device_id), ",", kernel_label);
} }
......