Unverified Commit 6a2dd785 authored by Oleg Goncharov, committed by GitHub

[Common] Added JIT-compiled fused cast transpose kernels (#903)



* Merged CT+dbias+dact into a single template
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Moved gated activations from the cast_transpose_fused to a separate cpp file
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Code clean up
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Code clean up
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Code clean up
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Code clean up
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Update transformer_engine/common/transpose/cast_transpose_fusion.cu
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>

* Update transformer_engine/common/transpose/cast_transpose_fusion.cu
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>

* Reverted the change with the file split
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Implemented JIT compiled kernels
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Code clean up
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Replaced aligned statically compiled kernels with JIT kernels. Added support of various activation functions for JIT kernels. Cleaned up the code per the code review.
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Code clean up
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

---------
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 793a54bf
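The heart of the change is runtime compilation: instead of statically instantiating the aligned fused kernel for every (type, vector-size) combination, the aligned path now builds a specialized kernel with NVRTC the first time a given configuration is seen. Below is a minimal, self-contained sketch of that compile-and-launch flow; it is illustrative only (Transformer Engine wraps all of this in rtc::KernelManager), and the kernel, placeholder, and sizes are made up for the example.

```cpp
// Minimal NVRTC compile-and-launch sketch (illustrative only; error checks
// omitted and all names invented). Build with: nvcc jit_sketch.cpp -lnvrtc -lcuda
#include <cuda.h>
#include <nvrtc.h>
#include <string>
#include <vector>

int main() {
  // Kernel source with a placeholder baked in, like __LOAD_SIZE__ in the patch.
  std::string code = R"(
    constexpr int NVEC = __NVEC__;
    extern "C" __global__ void scale_kernel(float *x, float s) {
      const int base = (blockIdx.x * blockDim.x + threadIdx.x) * NVEC;
      for (int k = 0; k < NVEC; ++k) x[base + k] *= s;
    }
  )";
  code.replace(code.find("__NVEC__"), 8, "4");  // specialize before compiling

  nvrtcProgram prog;
  nvrtcCreateProgram(&prog, code.c_str(), "scale.cu", 0, nullptr, nullptr);
  nvrtcCompileProgram(prog, 0, nullptr);  // real code should check the log
  size_t ptx_size;
  nvrtcGetPTXSize(prog, &ptx_size);
  std::vector<char> ptx(ptx_size);
  nvrtcGetPTX(prog, ptx.data());
  nvrtcDestroyProgram(&prog);

  cuInit(0);
  CUdevice dev;   cuDeviceGet(&dev, 0);
  CUcontext ctx;  cuCtxCreate(&ctx, 0, dev);
  CUmodule mod;   cuModuleLoadData(&mod, ptx.data());
  CUfunction fn;  cuModuleGetFunction(&fn, mod, "scale_kernel");

  CUdeviceptr d_x;
  cuMemAlloc(&d_x, 256 * sizeof(float));
  cuMemsetD32(d_x, 0, 256);
  float s = 2.0f;
  void *args[] = {&d_x, &s};
  cuLaunchKernel(fn, 1, 1, 1, 64, 1, 1, 0, nullptr, args, nullptr);  // 64 threads x NVEC=4
  cuCtxSynchronize();

  cuMemFree(d_x);
  cuModuleUnload(mod);
  cuCtxDestroy(ctx);
  return 0;
}
```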
@@ -67,12 +67,18 @@ endfunction()
 list(GET CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES 0 cuda_include_path)
 make_string_header("${cuda_include_path}"
                    string_path_cuda_include)
+make_string_header_from_file(transpose/rtc/cast_transpose_fusion.cu
+                             string_code_transpose_rtc_cast_transpose_fusion_cu)
 make_string_header_from_file(transpose/rtc/cast_transpose.cu
                              string_code_transpose_rtc_cast_transpose_cu)
 make_string_header_from_file(transpose/rtc/transpose.cu
                              string_code_transpose_rtc_transpose_cu)
 make_string_header_from_file(utils.cuh
                              string_code_utils_cuh)
+make_string_header_from_file(util/math.h
+                             string_code_util_math_h)
 target_include_directories(transformer_engine PRIVATE
                            "${CMAKE_CURRENT_BINARY_DIR}/string_headers")
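The make_string_header_from_file helper exists so the RTC kernel source (and the headers it needs, like util/math.h) can ship inside the binary as strings that NVRTC compiles at runtime. A plausible shape for the generated header is sketched below; the real generator's naming and escaping may differ.

```cpp
// Hypothetical shape of the generated string header (assumption, for
// illustration only): the entire source file becomes one string constant.
constexpr char string_code_transpose_rtc_cast_transpose_fusion_cu[] =
    R"rtc( /* ...entire contents of transpose/rtc/cast_transpose_fusion.cu... */ )rtc";
```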
@@ -10,65 +10,154 @@
 #include <iostream>
 #include <type_traits>
 #include "../utils.cuh"
+#include "../util/rtc.h"
+#include "../util/string.h"
 #include "../common.h"
 #include "../util/math.h"
 
 namespace transformer_engine {
 
+namespace {
+
+// String with RTC kernel implementation
+#include "string_code_transpose_rtc_cast_transpose_fusion_cu.h"
+
 // STUFF TO TUNE
-constexpr unsigned int n_warps_per_tile = 8;
-constexpr unsigned int max_threads_per_block = 256;
-static_assert(n_warps_per_tile * THREADS_PER_WARP <= max_threads_per_block);
-constexpr unsigned int cast_transpose_num_threads = n_warps_per_tile * THREADS_PER_WARP;
+constexpr size_t n_warps_per_tile = 8;
+constexpr size_t desired_load_size = 8;
+constexpr size_t desired_store_size = 8;
+constexpr size_t desired_load_size_dact = 4;   // dAct fusion kernels use more registers
+constexpr size_t desired_store_size_dact = 4;
+
+constexpr size_t threads_per_warp = static_cast<size_t>(THREADS_PER_WARP);
+constexpr size_t max_threads_per_block = 256;
 constexpr size_t reduce_dbias_num_threads = 256;
+constexpr size_t cast_transpose_num_threads = n_warps_per_tile * threads_per_warp;
+constexpr size_t n_warps_per_block = cast_transpose_num_threads / threads_per_warp;
+static_assert(cast_transpose_num_threads <= max_threads_per_block);
+
+/* Performance heuristics for optimized kernel parameters */
+struct KernelConfig {
+  size_t load_size = 0;               // Vector load size
+  size_t store_size = 0;              // Vector store size to transposed output
+  bool valid = false;                 // Whether config is valid
+  bool is_dact = false;               // Whether dact is used
+  size_t num_blocks = 0;              // Number of CUDA blocks
+  size_t active_sm_count = 0;         // Number of active SMs
+  size_t elements_per_load = 0;       // Elements per L1 cache load
+  size_t elements_per_load_dact = 0;  // Elements per L1 cache load dact
+  size_t elements_per_store_c = 0;    // Elements per L1 cache store to cast output
+  size_t elements_per_store_t = 0;    // Elements per L1 cache store to transposed output
+
+  KernelConfig(size_t row_length,
+               size_t num_rows,
+               size_t itype_size,
+               size_t itype2_size,
+               size_t otype_size,
+               size_t load_size_,
+               size_t store_size_,
+               bool is_dact_)
+    : load_size{load_size_}
+    , store_size{store_size_}
+    , is_dact{is_dact_} {
+    if (is_dact) {
+      if (load_size > desired_load_size_dact || store_size > desired_store_size_dact) {
+        return;
+      }
+    }
+
+    // Check that tiles are correctly aligned
+    constexpr size_t cache_line_size = 128;
+    if (load_size % itype_size != 0
+        || store_size % otype_size != 0
+        || cache_line_size % itype_size != 0
+        || cache_line_size % otype_size != 0) {
+      return;
+    }
+    /* row_tile_elements */
+    const size_t tile_size_x = (load_size * THREADS_PER_WARP) / itype_size;
+    /* col_tile_elements */
+    const size_t tile_size_y = (store_size * THREADS_PER_WARP) / otype_size;
+    const size_t num_tiles_x = row_length / tile_size_x;
+    const size_t num_tiles_y = num_rows / tile_size_y;
+    valid = (row_length % tile_size_x == 0 && num_rows % tile_size_y == 0);
+    if (!valid) {
+      return;
+    }
+
+    // Number of CUDA blocks
+    num_blocks = num_tiles_x * num_tiles_y;
+
+    // Parameters for performance model
+    constexpr size_t warps_per_sm = 16;  // Rough estimate for saturated SMs
+    active_sm_count = std::min(DIVUP(num_blocks * n_warps_per_tile, warps_per_sm),
+                               static_cast<size_t>(cuda::sm_count()));
+    elements_per_load = (std::min(cache_line_size, tile_size_x * itype_size)
+                         / itype_size);
+    elements_per_load_dact = (std::min(cache_line_size, tile_size_x * itype2_size)
+                              / itype2_size);
+    elements_per_store_c = (std::min(cache_line_size, tile_size_x * otype_size)
+                            / otype_size);
+    elements_per_store_t = (std::min(cache_line_size, tile_size_y * otype_size)
+                            / otype_size);
+  }
+
+  /* Compare by estimated cost */
+  bool operator<(const KernelConfig &other) const {
+    if (this->valid && other.valid) {
+      // cost ~ (1/elements_per_load
+      //         + 1/elements_per_load_dact
+      //         + 1/elements_per_store_c
+      //         + 1/elements_per_store_t) / active_sms
+      // Note: Integer arithmetic ensures stable ordering
+      const auto &l1 = this->elements_per_load;
+      const auto &la1 = this->elements_per_load_dact;
+      const auto &sc1 = this->elements_per_store_c;
+      const auto &st1 = this->elements_per_store_t;
+      const auto &p1 = this->active_sm_count;
+      const auto &l2 = other.elements_per_load;
+      const auto &la2 = other.elements_per_load_dact;
+      const auto &sc2 = other.elements_per_store_c;
+      const auto &st2 = other.elements_per_store_t;
+      const auto &p2 = other.active_sm_count;
+      const auto scale1 = l1 * sc1 * st1 * p1 * (is_dact ? la1 : 1);
+      const auto scale2 = l2 * sc2 * st2 * p2 * (is_dact ? la2 : 1);
+      const auto scale = scale1 * scale2;
+      const auto cost1 = (scale/l1 + scale/sc1 + scale/st1 + (is_dact ? (scale / la1) : 0))
+                         / p1;
+      const auto cost2 = (scale/l2 + scale/sc2 + scale/st2 + (is_dact ? (scale / la2) : 0))
+                         / p2;
+      return cost1 < cost2;
+    } else {
+      return this->valid && !other.valid;
+    }
+  }
+};
+
-template <bool full_tile, int nvec_in, int nvec_out, typename IVec, typename OVec, typename CType>
-inline __device__ void cast_and_transpose_regs(const IVec (&in)[nvec_out],
+template <bool IS_DBIAS, bool IS_FULL_TILE, int nvec_in, int nvec_out,
+          typename OVec, typename CVec, typename CType>
+inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out],
                                                OVec (&out_trans)[nvec_in],
+                                               CVec &out_dbias,  // NOLINT(*)
                                                typename OVec::type *output_cast_tile,
                                                const size_t current_place,
                                                const size_t stride,
-                                               CType &max,  // NOLINT(*)
                                                const CType scale,
+                                               CType &amax,  // NOLINT(*)
+                                               const int dbias_shfl_src_lane,
                                                const bool valid_store) {
-  using T = typename OVec::type;
-  using OVecC = Vec<T, nvec_in>;
-#pragma unroll
-  for (unsigned int i = 0; i < nvec_out; ++i) {
-    OVecC out_cast;
-#pragma unroll
-    for (unsigned int j = 0; j < nvec_in; ++j) {
-      const CType tmp = static_cast<CType>(in[i].data.elt[j]);
-      const T elt_o = T(scale * tmp);
-      out_cast.data.elt[j] = elt_o;
-      out_trans[j].data.elt[i] = elt_o;  // thread tile transpose
-      __builtin_assume(max >= 0);
-      max = fmaxf(fabsf(tmp), max);
-    }
-    if (full_tile || valid_store) {
-      out_cast.store_to(output_cast_tile, current_place + stride * i);
-    }
-  }
-}
-
-template <bool full_tile, int nvec_in, int nvec_out,
-          typename IVec, typename OVec, typename CVec, typename CType>
-inline __device__ void cast_and_transpose_regs_partial_dbias(const IVec (&in)[nvec_out],
-                                                             OVec (&out_trans)[nvec_in],
-                                                             CVec &out_dbias,  // NOLINT(*)
-                                                             typename OVec::type *output_cast_tile,
-                                                             const size_t current_place,
-                                                             const size_t stride,
-                                                             CType &max,  // NOLINT(*)
-                                                             const CType scale,
-                                                             const int dbias_shfl_src_lane,
-                                                             const bool valid_store) {
-  using T = typename OVec::type;
-  using OVecC = Vec<T, nvec_in>;
-  CVec step_dbias; step_dbias.clear();
+  using OType = typename OVec::type;
+  using OVecC = Vec<OType, nvec_in>;
+
+  CVec step_dbias;
+  if constexpr (IS_DBIAS) {
+    step_dbias.clear();
+  }
+
 #pragma unroll
   for (unsigned int i = 0; i < nvec_out; ++i) {
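The KernelConfig comparator above avoids floating-point division entirely: it cross-multiplies both estimated costs by the product of all denominators, so the ordering is exact in integer arithmetic. A standalone worked example of the same comparison, with made-up config numbers (non-dact case):

```cpp
// Worked example of the integer cost comparison (sketch, invented numbers).
// cost ~ (1/load + 1/store_c + 1/store_t) / active_sms; multiplying both
// costs by the product of all denominators keeps the comparison exact.
#include <cstdio>
#include <cstddef>

int main() {
  // Config A: 32 elems/load, 32/store_c, 8/store_t, 100 active SMs
  // Config B: 16 elems/load, 16/store_c, 16/store_t, 108 active SMs
  const size_t l1 = 32, sc1 = 32, st1 = 8,  p1 = 100;
  const size_t l2 = 16, sc2 = 16, st2 = 16, p2 = 108;
  const size_t scale = (l1 * sc1 * st1 * p1) * (l2 * sc2 * st2 * p2);
  const size_t cost1 = (scale / l1 + scale / sc1 + scale / st1) / p1;
  const size_t cost2 = (scale / l2 + scale / sc2 + scale / st2) / p2;
  // Float check: A ~ (1/32 + 1/32 + 1/8) / 100 = 0.001875
  //              B ~ (1/16 + 1/16 + 1/16) / 108 ~ 0.001736  -> B wins
  std::printf("A better than B? %s\n", cost1 < cost2 ? "yes" : "no");  // "no"
  return 0;
}
```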
@@ -76,49 +165,30 @@ inline __device__ void cast_and_transpose_regs_partial_dbias(const IVec (&in)[nvec_out],
 #pragma unroll
     for (unsigned int j = 0; j < nvec_in; ++j) {
       const CType tmp = in[i].data.elt[j];
-      const T elt_o = T(scale * tmp);
-
-      /* dbias: thread tile local accumulation */
-      step_dbias.data.elt[j] += tmp;
-
-      out_cast.data.elt[j] = elt_o;
-      out_trans[j].data.elt[i] = elt_o;  // thread tile transpose
-      __builtin_assume(max >= 0);
-      max = fmaxf(fabsf(tmp), max);
+      if constexpr (IS_DBIAS) {
+        step_dbias.data.elt[j] += tmp;  // dbias: thread tile local accumulation
+      }
+      out_cast.data.elt[j] = static_cast<OType>(tmp * scale);
+      out_trans[j].data.elt[i] = static_cast<OType>(tmp * scale);  // thread tile transpose
+      __builtin_assume(amax >= 0);
+      amax = fmaxf(fabsf(tmp), amax);
     }
-    if (full_tile || valid_store) {
+    if (IS_FULL_TILE || valid_store) {
       out_cast.store_to(output_cast_tile, current_place + stride * i);
     }
   }
-#pragma unroll
-  for (unsigned int j = 0; j < nvec_in; ++j) {
-    CType elt = step_dbias.data.elt[j];
-    elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane);  // shuffle data in warp
-    out_dbias.data.elt[j] += elt;
+  if constexpr (IS_DBIAS) {
+#pragma unroll
+    for (unsigned int j = 0; j < nvec_in; ++j) {
+      CType elt = step_dbias.data.elt[j];
+      elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane);  // shuffle data in a warp
+      out_dbias.data.elt[j] += elt;
+    }
   }
 }
 
-namespace {
-
-template <typename IType, typename IType2, typename OType, typename CType>
-struct CTDBiasDGeluParam {
-  using InputType = IType;
-  using InputType2 = IType2;
-  using OutputType = OType;
-  using ComputeType = CType;
-  const IType *input;
-  const IType2 *act_input;
-  OType *output_c;
-  OType *output_t;
-  const CType *scale_ptr;
-  CType *amax;
-  CType *workspace;
-};
-
-}  // namespace
-
 void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output,  /*cast*/
                                                     Tensor* workspace,
                                                     const int nvec_out) {
@@ -191,7 +261,8 @@ void reduce_dbias(const Tensor &workspace,
   const size_t reduce_dbias_num_blocks = DIVUP(row_length,
                                                reduce_dbias_num_threads * reduce_dbias_nvec);
 
-  reduce_dbias_kernel<reduce_dbias_nvec, fp32, InputType>
+  using DbiasOutputType = fp32;
+  reduce_dbias_kernel<reduce_dbias_nvec, DbiasOutputType, InputType>
     <<<reduce_dbias_num_blocks, reduce_dbias_num_threads, 0, stream>>>
       (reinterpret_cast<InputType *>(dbias->data.dptr),
        reinterpret_cast<const fp32 *>(workspace.data.dptr),
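For reference, the reduction that reduce_dbias_kernel performs is a column sum over the partial-dbias rows accumulated in the fp32 workspace, cast to the bias dtype at the end. A simplified, unvectorized sketch of that operation (not the TE kernel, which processes reduce_dbias_nvec elements per thread):

```cpp
// Standalone sketch of the dbias reduction: each thread owns one output
// column, sums that column over all partial rows in the fp32 workspace,
// then casts the result to the bias type.
template <typename OutputType>
__global__ void reduce_columns_sketch(OutputType *dbias,
                                      const float *workspace,
                                      int num_partial_rows,
                                      int row_length) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= row_length) return;
  float acc = 0.f;
  for (int row = 0; row < num_partial_rows; ++row) {
    acc += workspace[row * static_cast<size_t>(row_length) + col];
  }
  dbias[col] = static_cast<OutputType>(acc);
}
```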
@@ -200,181 +271,8 @@ void reduce_dbias(const Tensor &workspace,
 }
 
-template <bool IS_DBIAS, bool IS_DACT, typename ComputeType, typename ParamOP,
-          int nvec_in, int nvec_out, typename Param,
-          ComputeType (*OP)(ComputeType, const ParamOP&)>
-__global__ void
-__launch_bounds__(cast_transpose_num_threads)
-cast_transpose_fused_kernel(const Param param,
-                            const size_t row_length,
-                            const size_t num_rows,
-                            const size_t num_tiles) {
-  using IType = typename Param::InputType;
-  using IType2 = typename Param::InputType2;
-  using OType = typename Param::OutputType;
-  using CType = typename Param::ComputeType;
-  using IVec = Vec<IType, nvec_in>;
-  using IVec2 = Vec<IType2, nvec_in>;
-  using OVec = Vec<OType, nvec_out>;
-  using CVec = Vec<CType, nvec_in>;
-
-  extern __shared__ char scratch[];
-
-  const int warp_id = threadIdx.x / THREADS_PER_WARP;
-  const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
-  const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
-  // const size_t num_tiles_y = num_rows / (nvec * THREADS_PER_WARP);
-  const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
-                         warp_id / n_warps_per_tile;
-  if (tile_id >= num_tiles) {
-    return;
-  }
-  const size_t tile_id_x = tile_id % num_tiles_x;
-  const size_t tile_id_y = tile_id / num_tiles_x;
-
-  const size_t tile_offset_out = (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) *
-                                 THREADS_PER_WARP;
-  const size_t tile_offset_in = (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) *
-                                THREADS_PER_WARP;
-
-  const IType * const my_input_tile = param.input + tile_offset_out;
-  const IType2 * const my_act_input_tile = param.act_input + tile_offset_out;
-  OType * const my_output_c_tile = param.output_c + tile_offset_out;
-  OType * const my_output_t_tile = param.output_t + tile_offset_in;
-  CType * const my_partial_dbias_tile = param.workspace +
-      (tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length);
-
-  OVec * const my_scratch = reinterpret_cast<OVec *>(scratch) +
-                            (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
-                            (THREADS_PER_WARP + 1);
-  CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
-
-  IVec in[2][nvec_out];
-  IVec2 act_in[2][nvec_out];
-  const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
-  constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
-  OVec out_space[n_iterations][nvec_in];
-  CVec partial_dbias;
-
-  const size_t stride = row_length / nvec_in;
-  const size_t output_stride = num_rows / nvec_out;
-  size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
-  unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
-                          THREADS_PER_WARP;
-  CType max = 0;
-  const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
-  if constexpr (IS_DBIAS) {
-    partial_dbias.clear();
-  }
-#pragma unroll
-  for (unsigned int i = 0; i < nvec_out; ++i) {
-    const size_t ld_offset = current_stride + my_place + stride * i;
-    in[0][i].load_from(my_input_tile, ld_offset);
-    act_in[0][i].load_from(my_act_input_tile, ld_offset);
-  }
-#pragma unroll
-  for (unsigned int i = 0; i < n_iterations; ++i) {
-    const size_t current_place = current_stride + my_place;
-    const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
-    const unsigned int current_in = (i + 1) % 2;
-    if (i < n_iterations - 1) {
-#pragma unroll
-      for (unsigned int j = 0; j < nvec_out; ++j) {
-        const size_t ld_offset = current_stride + my_place_in + stride * (nvec_out + j);
-        in[current_in][j].load_from(my_input_tile, ld_offset);
-        act_in[current_in][j].load_from(my_act_input_tile, ld_offset);
-      }
-    }
-    CVec after_dact[nvec_out];  // NOLINT(*)
-#pragma unroll
-    for (unsigned int j = 0; j < nvec_out; ++j) {
-#pragma unroll
-      for (unsigned int k = 0; k < nvec_in; ++k) {
-        if constexpr (IS_DACT) {
-          after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
-                                      OP(act_in[current_in ^ 1][j].data.elt[k], {});
-        } else {
-          after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]);
-        }
-      }
-    }
-    OVec out_trans[nvec_in];  // NOLINT(*)
-    if constexpr (IS_DBIAS) {
-      const size_t dbias_shfl_src_lane =
-          (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
-      cast_and_transpose_regs_partial_dbias<true>(after_dact, out_trans,
-                                                  partial_dbias, my_output_c_tile,
-                                                  current_place, stride, max, scale,
-                                                  dbias_shfl_src_lane,
-                                                  true);
-    } else {
-      cast_and_transpose_regs<true>(after_dact, out_trans, my_output_c_tile,
-                                    current_place, stride, max, scale, true);
-    }
-#pragma unroll
-    for (unsigned int j = 0; j < nvec_in; ++j) {
-      out_space[i][j].data.vec = out_trans[j].data.vec;
-    }
-    my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
-    current_stride += nvec_out * stride;
-  }
-
-  for (unsigned int i = 0; i < nvec_in; ++i) {
-#pragma unroll
-    for (unsigned int j = 0; j < n_iterations; ++j) {
-      my_scratch[(my_id_in_warp + THREADS_PER_WARP -
-                  j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
-    }
-    __syncthreads();
-    my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
-               THREADS_PER_WARP;
-    current_stride = i * output_stride +
-                     warp_id_in_tile * n_iterations * output_stride * nvec_in;
-    for (unsigned int j = 0; j < n_iterations; ++j) {
-      my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
-                                                              current_stride + my_place);
-      my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
-      current_stride += output_stride * nvec_in;
-    }
-    __syncthreads();
-  }
-
-  if constexpr (IS_DBIAS) {
-    my_dbias_scratch[threadIdx.x] = partial_dbias;
-    __syncthreads();
-    // TODO(ptredak): check if the regular reduction is better
-    if (warp_id_in_tile == 0) {
-#pragma unroll
-      for (unsigned int i = 1; i < n_warps_per_tile; ++i) {
-        CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
-#pragma unroll
-        for (unsigned int j = 0; j < nvec_in; ++j) {
-          partial_dbias.data.elt[j] += tmp.data.elt[j];
-        }
-      }
-      partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
-    }
-  }
-
-  /* warp tile amax reduce*/
-  max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
-
-  if (threadIdx.x == 0) {
-    static_assert(std::is_same<CType, float>::value);
-    if (param.amax != nullptr) {
-      atomicMaxFloat(param.amax, max);
-    }
-  }
-}
-
-template <bool IS_DBIAS, bool IS_DACT, typename ComputeType, typename ParamOP,
-          int nvec_in, int nvec_out, typename Param,
+template <bool IS_DBIAS, bool IS_DACT, typename ComputeType, typename Param,
+          int nvec_in, int nvec_out, typename ParamOP,
           ComputeType (*OP)(ComputeType, const ParamOP&)>
 __global__ void
 __launch_bounds__(cast_transpose_num_threads)
@@ -395,10 +293,10 @@ cast_transpose_fused_kernel_notaligned(const Param param,
   const int warp_id = threadIdx.x / THREADS_PER_WARP;
   const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
-  const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) /
-                             (nvec_in * THREADS_PER_WARP);
-  const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
-                         warp_id / n_warps_per_tile;
+  const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1)
+                             / (nvec_in * THREADS_PER_WARP);
+  const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile)
+                         + warp_id / n_warps_per_tile;
   if (tile_id >= num_tiles) {
     return;
   }
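As context for the tile arithmetic above: one warp tile covers nvec_in * THREADS_PER_WARP elements along a row and nvec_out * THREADS_PER_WARP along a column, so the tile counts fall out of simple arithmetic. A compile-time checked example with assumed sizes:

```cpp
// Tile arithmetic example (assumed sizes, verified at compile time).
// div_up mirrors TE's DIVUP macro. With fp16 input and load_size = 8,
// nvec_in = 8 / 2 = 4, so a warp tile spans 4 * 32 = 128 elements along a
// row; a 4096-wide matrix therefore decomposes into 32 tiles in x.
#include <cstddef>

constexpr size_t div_up(size_t x, size_t y) { return (x + y - 1) / y; }

static_assert(div_up(4096, 4 * 32) == 32, "32 warp tiles along the row dimension");
```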
@@ -406,17 +304,18 @@ cast_transpose_fused_kernel_notaligned(const Param param,
   const size_t tile_id_x = tile_id % num_tiles_x;
   const size_t tile_id_y = tile_id / num_tiles_x;
 
-  const size_t tile_offset_out = (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) *
-                                 THREADS_PER_WARP;
-  const size_t tile_offset_in = (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) *
-                                THREADS_PER_WARP;
+  const size_t tile_offset = (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out)
+                             * THREADS_PER_WARP;
+  const size_t tile_offset_transp = (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in)
+                                    * THREADS_PER_WARP;
 
-  const IType * const my_input_tile = param.input + tile_offset_out;
-  const IType2 * const my_act_input_tile = param.act_input + tile_offset_out;
-  OType * const my_output_c_tile = param.output_c + tile_offset_out;
-  OType * const my_output_t_tile = param.output_t + tile_offset_in;
-  CType * const my_partial_dbias_tile = param.workspace +
-      (tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length);
+  const IType * const my_input_tile = param.input + tile_offset;
+  const IType2 * const my_act_input_tile = param.act_input + tile_offset;
+  OType * const my_output_c_tile = param.output_c + tile_offset;
+  OType * const my_output_t_tile = param.output_t + tile_offset_transp;
+  CType * const my_partial_dbias_tile = param.workspace
+                                        + (tile_id_x * (nvec_in * THREADS_PER_WARP)
+                                           + tile_id_y * row_length);
 
   const size_t stride = row_length / nvec_in;
   const size_t output_stride = num_rows / nvec_out;
@@ -427,9 +326,9 @@ cast_transpose_fused_kernel_notaligned(const Param param,
   const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP
                                                                       : row_height_rest;
 
-  OVec * const my_scratch = reinterpret_cast<OVec *>(scratch) +
-                            (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
-                            (THREADS_PER_WARP + 1);
+  OVec * const my_scratch = reinterpret_cast<OVec *>(scratch)
+                            + (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP)
+                            * (THREADS_PER_WARP + 1);
 
   CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
@@ -438,15 +337,15 @@ cast_transpose_fused_kernel_notaligned(const Param param,
   const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
   constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
   OVec out_space[n_iterations][nvec_in];
-  CVec partial_dbias;
 
   size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
-  unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
-                           warp_id_in_tile * n_iterations) %
-                          THREADS_PER_WARP;
-  CType max = 0;
+  size_t current_row = (tile_id_y * THREADS_PER_WARP + warp_id_in_tile * n_iterations) * nvec_out;
+  unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations)
+                          % THREADS_PER_WARP;
+  CType amax = 0;
   const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
+  CVec partial_dbias;
   if constexpr (IS_DBIAS) {
     partial_dbias.clear();
   }
@@ -459,10 +358,14 @@ cast_transpose_fused_kernel_notaligned(const Param param,
       if (valid_load) {
         const size_t ld_offset = current_stride + my_place + stride * i;
         in[0][i].load_from(my_input_tile, ld_offset);
-        act_in[0][i].load_from(my_act_input_tile, ld_offset);
+        if constexpr (IS_DACT) {
+          act_in[0][i].load_from(my_act_input_tile, ld_offset);
+        }
       } else {
         in[0][i].clear();
-        act_in[0][i].clear();
+        if constexpr (IS_DACT) {
+          act_in[0][i].clear();
+        }
       }
     }
   }
@@ -478,12 +381,16 @@ cast_transpose_fused_kernel_notaligned(const Param param,
 #pragma unroll
       for (unsigned int j = 0; j < nvec_out; ++j) {
         if (valid_load) {
-          const size_t ld_offset = current_stride + my_place_in + stride*(nvec_out + j);
+          const size_t ld_offset = current_stride + my_place_in + stride * (nvec_out + j);
           in[current_in][j].load_from(my_input_tile, ld_offset);
-          act_in[current_in][j].load_from(my_act_input_tile, ld_offset);
+          if constexpr (IS_DACT) {
+            act_in[current_in][j].load_from(my_act_input_tile, ld_offset);
+          }
         } else {
           in[current_in][j].clear();
-          act_in[current_in][j].clear();
+          if constexpr (IS_DACT) {
+            act_in[current_in][j].clear();
+          }
         }
       }
     }
@@ -493,49 +400,39 @@ cast_transpose_fused_kernel_notaligned(const Param param,
 #pragma unroll
       for (unsigned int k = 0; k < nvec_in; ++k) {
         if constexpr (IS_DACT) {
-          after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
-                                      OP(act_in[current_in ^ 1][j].data.elt[k], {});
+          after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k])
+                                      * OP(act_in[current_in ^ 1][j].data.elt[k], {});
         } else {
           after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]);
         }
       }
     }
-    OVec out_trans[nvec_in];  // NOLINT(*)
-    const bool valid_store = my_place < tile_length &&
-                             warp_id_in_tile * n_iterations + i < tile_height;
-
-    if constexpr (IS_DBIAS) {
-      const size_t dbias_shfl_src_lane =
-          (my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
-      cast_and_transpose_regs_partial_dbias<false>(after_dact, out_trans,
-                                                   partial_dbias, my_output_c_tile,
-                                                   current_place, stride, max, scale,
-                                                   dbias_shfl_src_lane,
-                                                   valid_store);
-    } else {
-      cast_and_transpose_regs<false>(after_dact, out_trans, my_output_c_tile,
-                                     current_place, stride, max, scale, valid_store);
-    }
-#pragma unroll
-    for (unsigned int j = 0; j < nvec_in; ++j) {
-      out_space[i][j].data.vec = out_trans[j].data.vec;
-    }
+    const int dbias_shfl_src_lane = (my_id_in_warp + i + warp_id_in_tile * n_iterations)
+                                    % THREADS_PER_WARP;
+    constexpr bool IS_FULL_TILE = false;
+    const bool valid_store = (my_place < tile_length)
+                             && (warp_id_in_tile * n_iterations + i < tile_height);
+
+    cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>
+      (after_dact, out_space[i], partial_dbias, my_output_c_tile, current_place,
+       stride, scale, amax, dbias_shfl_src_lane, valid_store);
+
     my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
     current_stride += nvec_out * stride;
+    current_row += nvec_out;
   }
 
   for (unsigned int i = 0; i < nvec_in; ++i) {
 #pragma unroll
     for (unsigned int j = 0; j < n_iterations; ++j) {
-      my_scratch[(my_id_in_warp + THREADS_PER_WARP -
-                  j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
+      my_scratch[(my_id_in_warp + THREADS_PER_WARP - j - warp_id_in_tile * n_iterations)
+                 % THREADS_PER_WARP] = out_space[j][i];
     }
     __syncthreads();
-    my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
-               THREADS_PER_WARP;
-    current_stride = i * output_stride +
-                     warp_id_in_tile * n_iterations * output_stride * nvec_in;
+    my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations)
+               % THREADS_PER_WARP;
+    current_stride = i * output_stride
+                     + warp_id_in_tile * n_iterations * output_stride * nvec_in;
     for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
       const bool valid_store = my_place < tile_height;
       if (valid_store) {
@@ -551,7 +448,6 @@ cast_transpose_fused_kernel_notaligned(const Param param,
   if constexpr (IS_DBIAS) {
     my_dbias_scratch[threadIdx.x] = partial_dbias;
     __syncthreads();
-    // TODO(ptredak): check if the regular reduction is better
     if (warp_id_in_tile == 0) {
 #pragma unroll
       for (unsigned int i = 1; i < n_warps_per_tile; ++i) {
@@ -568,16 +464,46 @@ cast_transpose_fused_kernel_notaligned(const Param param,
   }
 
   /* warp tile amax reduce*/
-  max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
+  amax = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(amax, warp_id);
 
   if (threadIdx.x == 0) {
     static_assert(std::is_same<CType, float>::value);
     if (param.amax != nullptr) {
-      atomicMaxFloat(param.amax, max);
+      atomicMaxFloat(param.amax, amax);
     }
   }
 }
 
+static const char* ActTypeToString[] = {
+  "NoAct",    // 0
+  "Sigmoid",  // 1
+  "GeLU",     // 2
+  "QGeLU",    // 3
+  "SiLU",     // 4
+  "ReLU",     // 5
+  "SReLU"     // 6
+};
+
+template <typename ComputeType, typename ParamOP,
+          ComputeType (*OP)(ComputeType, const ParamOP&)>
+int get_dactivation_type() {
+  if (OP == &sigmoid<ComputeType, ComputeType>) {
+    return 1;
+  } else if (OP == &dgelu<ComputeType, ComputeType>) {
+    return 2;
+  } else if (OP == &dqgelu<ComputeType, ComputeType>) {
+    return 3;
+  } else if (OP == &dsilu<ComputeType, ComputeType>) {
+    return 4;
+  } else if (OP == &drelu<ComputeType, ComputeType>) {
+    return 5;
+  } else if (OP == &dsrelu<ComputeType, ComputeType>) {
+    return 6;
+  } else {
+    return 0;
+  }
+}
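get_dactivation_type turns the compile-time OP template argument into a runtime ID that can be embedded in the RTC cache label. A self-contained analog of that trick (stub host functions stand in for the real device activations; this is not TE code):

```cpp
// Comparing a function-pointer template parameter against known functions
// folds to a constant per instantiation, yielding an ID usable at runtime.
#include <cstdio>

struct Empty {};

float dgelu_stub(float x, const Empty&) { return x; }
float drelu_stub(float x, const Empty&) { return x; }

template <float (*OP)(float, const Empty&)>
int activation_id() {
  if (OP == &dgelu_stub) return 2;  // "GeLU" in ActTypeToString
  if (OP == &drelu_stub) return 5;  // "ReLU"
  return 0;                         // "NoAct"
}

int main() {
  std::printf("%d %d\n", activation_id<&dgelu_stub>(),
                         activation_id<&drelu_stub>());  // prints "2 5"
  return 0;
}
```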
 
 template <bool IS_DBIAS, bool IS_DACT, typename ComputeType, typename ParamOP,
           ComputeType (*OP)(ComputeType, const ParamOP&)>
 void cast_transpose_fused(const Tensor &input,
@@ -587,6 +513,10 @@ void cast_transpose_fused(const Tensor &input,
                           Tensor *dbias,
                           Tensor *workspace,
                           cudaStream_t stream) {
+  CheckInputTensor(input, "cast_transpose_fused_input");
+  CheckOutputTensor(*cast_output, "cast_output");
+  CheckOutputTensor(*transposed_output, "transposed_output");
+
   NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
   NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
   NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
@@ -606,12 +536,14 @@ void cast_transpose_fused(const Tensor &input,
              "C and T outputs need to share scale tensor.");
 
   if constexpr (IS_DBIAS) {
+    CheckOutputTensor(*dbias, "dbias");
     NVTE_CHECK(dbias->data.dtype == input.data.dtype,
                "DBias must have the same type as input.");
     NVTE_CHECK(dbias->data.shape == std::vector<size_t>{ row_length },
                "Wrong shape of DBias.");
   }
   if constexpr (IS_DACT) {
+    CheckInputTensor(act_input, "act_input");
     NVTE_CHECK(input.data.dtype == act_input.data.dtype, "Types of both inputs must match.");
     NVTE_CHECK(input.data.shape == act_input.data.shape, "Shapes of both inputs must match.");
   }
@@ -619,14 +551,64 @@ void cast_transpose_fused(const Tensor &input,
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
     TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
       using InputType2 = InputType;
-      /* dact fusion kernel uses more registers */
-      constexpr int load_size = (IS_DACT ? 4 : 8);
-      constexpr int store_size = (IS_DACT ? 4 : 8);
+      using Param = CTDBiasDActParam<InputType, InputType2, OutputType, ComputeType>;
       constexpr int itype_size = sizeof(InputType);
+      constexpr int itype2_size = sizeof(InputType2);
       constexpr int otype_size = sizeof(OutputType);
-      constexpr int nvec_in = load_size / itype_size;
-      constexpr int nvec_out = store_size / otype_size;
+
+      const bool aligned = (row_length % THREADS_PER_WARP == 0)
+                           && (num_rows % THREADS_PER_WARP == 0);
+      const bool jit_compiled = aligned && rtc::is_enabled();
+
+      size_t load_size = (IS_DACT ? desired_load_size_dact : desired_load_size);
+      size_t store_size = (IS_DACT ? desired_store_size_dact : desired_store_size);
+      size_t num_blocks;
+
+      if (jit_compiled) {
+        // Pick kernel config
+        std::vector<KernelConfig> kernel_configs;
+        kernel_configs.reserve(16);
+        auto add_config = [&](size_t load_size_config, size_t store_size_config) {
+          kernel_configs.emplace_back(row_length, num_rows,
+                                      itype_size, itype2_size, otype_size,
+                                      load_size_config, store_size_config,
+                                      IS_DACT);
+        };
+        add_config(8, 8);
+        add_config(4, 8); add_config(8, 4);
+        add_config(4, 4);
+        add_config(2, 8); add_config(8, 2);
+        add_config(2, 4); add_config(4, 2);
+        add_config(2, 2);
+        add_config(1, 8); add_config(8, 1);
+        add_config(1, 4); add_config(4, 1);
+        add_config(1, 2); add_config(2, 1);
+        add_config(1, 1);
+
+        // Select the kernel configuration with the lowest cost
+        const auto &kernel_config = *std::min_element(kernel_configs.begin(),
+                                                      kernel_configs.end());
+        NVTE_CHECK(kernel_config.valid, "invalid kernel config");
+        load_size = kernel_config.load_size;
+        store_size = kernel_config.store_size;
+        num_blocks = kernel_config.num_blocks;
+      }
+
+      const size_t nvec_in = load_size / itype_size;
+      const size_t nvec_out = store_size / otype_size;
+      const size_t tile_size_x = nvec_in * threads_per_warp;
+      const size_t tile_size_y = nvec_out * threads_per_warp;
+      const size_t num_tiles_x = DIVUP(row_length, tile_size_x);
+      const size_t num_tiles_y = DIVUP(num_rows, tile_size_y);
+      const size_t num_tiles = num_tiles_x * num_tiles_y;
+
+      NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
+      NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
+
+      if (!jit_compiled) {
+        num_blocks = DIVUP(num_tiles * n_warps_per_tile, n_warps_per_block);
+      }
 
       if constexpr (IS_DBIAS) {
         if (workspace->data.dptr == nullptr) {
           populate_cast_transpose_dbias_workspace_config(*cast_output,
@@ -635,44 +617,37 @@ void cast_transpose_fused(const Tensor &input,
         }
       }
 
-      CheckInputTensor(input, "cast_transpose_fused_input");
-      CheckOutputTensor(*cast_output, "cast_output");
-      CheckOutputTensor(*transposed_output, "transposed_output");
-      if constexpr (IS_DBIAS) {
-        CheckOutputTensor(*dbias, "dbias");
-      }
-      if constexpr (IS_DACT) {
-        CheckInputTensor(act_input, "act_input");
-      }
-
-      NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
-      NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
-      const size_t n_tiles =
-        DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
-        DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
-      const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
-      const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
-
-      const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
-                             num_rows % (nvec_out * THREADS_PER_WARP) == 0;
-      // using ComputeType = fp32;
-      constexpr size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile *
-                                               (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>);
+      size_t VecOutputTypeSize;
+      switch (nvec_out) {
+        case 1: VecOutputTypeSize = sizeof(Vec<OutputType, 1>); break;
+        case 2: VecOutputTypeSize = sizeof(Vec<OutputType, 2>); break;
+        case 4: VecOutputTypeSize = sizeof(Vec<OutputType, 4>); break;
+        case 8: VecOutputTypeSize = sizeof(Vec<OutputType, 8>); break;
+      }
+      size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile
+                                     * (threads_per_warp + 1) * VecOutputTypeSize;
+
       if constexpr (IS_DBIAS) {
-        constexpr size_t shared_size_dbias =
-          cast_transpose_num_threads * sizeof(Vec<ComputeType, nvec_in>);
-        static_assert(shared_size_transpose >= shared_size_dbias);
+        size_t VecComputeTypeSize;
+        switch (nvec_in) {
+          case 1: VecComputeTypeSize = sizeof(Vec<ComputeType, 1>); break;
+          case 2: VecComputeTypeSize = sizeof(Vec<ComputeType, 2>); break;
+          case 4: VecComputeTypeSize = sizeof(Vec<ComputeType, 4>); break;
+          case 8: VecComputeTypeSize = sizeof(Vec<ComputeType, 8>); break;
+        }
+        const size_t shared_size_dbias = cast_transpose_num_threads * VecComputeTypeSize;
+        if (shared_size_transpose < shared_size_dbias) {
+          shared_size_transpose = shared_size_dbias;
+        }
       }
 
-      using Param = CTDBiasDGeluParam<InputType, InputType2, OutputType, ComputeType>;
       Param param;
       param.input = reinterpret_cast<const InputType *>(input.data.dptr);
       param.output_c = reinterpret_cast<OutputType *>(cast_output->data.dptr);
       param.output_t = reinterpret_cast<OutputType *>(transposed_output->data.dptr);
-      param.scale_ptr = reinterpret_cast<const ComputeType *>(cast_output->scale.dptr);
-      param.amax = reinterpret_cast<ComputeType *>(cast_output->amax.dptr);
+      param.scale_ptr = reinterpret_cast<const ComputeType *>(transposed_output->scale.dptr);
+      param.amax = reinterpret_cast<ComputeType *>(transposed_output->amax.dptr);
+      param.scale_inv = reinterpret_cast<ComputeType *>(cast_output->scale_inv.dptr);
       if constexpr (IS_DBIAS) {
         param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
       }
@@ -680,26 +655,75 @@ void cast_transpose_fused(const Tensor &input,
       if constexpr (IS_DACT) {
         param.act_input = reinterpret_cast<const InputType2 *>(act_input.data.dptr);
       }
 
-      if (full_tile) {
-        cudaFuncSetAttribute(
-          cast_transpose_fused_kernel
-            <IS_DBIAS, IS_DACT, ComputeType, Empty, nvec_in, nvec_out, Param, OP>,
-          cudaFuncAttributePreferredSharedMemoryCarveout,
-          100);
-        cast_transpose_fused_kernel
-          <IS_DBIAS, IS_DACT, ComputeType, Empty, nvec_in, nvec_out, Param, OP>
-          <<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>
-            (param, row_length, num_rows, n_tiles);
-      } else {
+      // Runtime-compiled tuned kernel
+      if (jit_compiled) {
+        constexpr const char *itype_name = TypeInfo<InputType>::name;
+        constexpr const char *itype2_name = TypeInfo<InputType2>::name;
+        constexpr const char *otype_name = TypeInfo<OutputType>::name;
+
+        int dActType = 0;
+        if constexpr (IS_DACT) {
+          dActType = get_dactivation_type<ComputeType, ParamOP, OP>();
+        }
+
+        // Compile NVRTC kernel if needed and launch
+        auto& rtc_manager = rtc::KernelManager::instance();
+        const std::string kernel_label =
+          concat_strings("cast_transpose_fusion"
+                         ",itype=", itype_name,
+                         ",itype2=", itype2_name,
+                         ",otype=", otype_name,
+                         ",load_size=", load_size,
+                         ",store_size=", store_size,
+                         ",IS_DBIAS=", IS_DBIAS,
+                         ",IS_DACT=", IS_DACT,
+                         ",dactivationType=", ActTypeToString[dActType]);
+
+        if (!rtc_manager.is_compiled(kernel_label)) {
+          std::string code = string_code_transpose_rtc_cast_transpose_fusion_cu;
+          code = regex_replace(code, "__ITYPE__", itype_name);
+          code = regex_replace(code, "__ITYPE2__", itype2_name);
+          code = regex_replace(code, "__OTYPE__", otype_name);
+          code = regex_replace(code, "__LOAD_SIZE__", load_size);
+          code = regex_replace(code, "__STORE_SIZE__", store_size);
+          code = regex_replace(code, "__WARPS_PER_TILE__", n_warps_per_tile);
+          code = regex_replace(code, "__BLOCK_SIZE__", cast_transpose_num_threads);
+          code = regex_replace(code, "__IS_DBIAS__", IS_DBIAS);
+          code = regex_replace(code, "__IS_DACT__", IS_DACT);
+          code = regex_replace(code, "__DACTIVATION_TYPE__", dActType);
+          rtc_manager.compile(
+            kernel_label,
+            "cast_transpose_fusion_kernel_optimized",
+            code,
+            "transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu");
+        }
+        rtc_manager.set_cache_config(kernel_label, CU_FUNC_CACHE_PREFER_SHARED);
+        rtc_manager.launch(kernel_label,
+                           num_blocks, cast_transpose_num_threads, shared_size_transpose, stream,
+                           param, row_length, num_rows, num_tiles);
+      } else {  // Statically-compiled general kernel
+        constexpr size_t load_size = IS_DACT ? desired_load_size_dact :
+                                               desired_load_size;
+        constexpr size_t store_size = IS_DACT ? desired_store_size_dact :
+                                                desired_store_size;
+        constexpr size_t nvec_in = load_size / itype_size;
+        constexpr size_t nvec_out = store_size / otype_size;
+
+        NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
+        NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
+
         cudaFuncSetAttribute(
           cast_transpose_fused_kernel_notaligned
-            <IS_DBIAS, IS_DACT, ComputeType, Empty, nvec_in, nvec_out, Param, OP>,
+            <IS_DBIAS, IS_DACT, ComputeType, Param, nvec_in, nvec_out, Empty, OP>,
           cudaFuncAttributePreferredSharedMemoryCarveout,
           100);
         cast_transpose_fused_kernel_notaligned
-          <IS_DBIAS, IS_DACT, ComputeType, Empty, nvec_in, nvec_out, Param, OP>
-          <<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>
-            (param, row_length, num_rows, n_tiles);
+          <IS_DBIAS, IS_DACT, ComputeType, Param, nvec_in, nvec_out, Empty, OP>
+          <<<num_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>
+            (param, row_length, num_rows, num_tiles);
       }
 
       if constexpr (IS_DBIAS) {
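Each distinct combination of types, vector sizes, and fusion flags gets its own kernel_label, and on a cache miss the __PLACEHOLDER__ tokens in the embedded source are substituted textually before NVRTC compilation. A self-contained analog of that specialization step (std::regex here stands in for TE's regex_replace helper from util/string.h; the label and type-name strings are invented for the example):

```cpp
#include <iostream>
#include <regex>
#include <string>

int main() {
  // One cache label per (types, vector sizes, fusion flags, activation) combo:
  const std::string kernel_label =
      "cast_transpose_fusion,itype=fp16,itype2=fp16,otype=fp8e4m3,"
      "load_size=4,store_size=4,IS_DBIAS=1,IS_DACT=1,dactivationType=GeLU";

  // On a cache miss, placeholders in the embedded source are substituted
  // textually, then the result is handed to NVRTC:
  std::string code = "using IType = __ITYPE__;\n"
                     "constexpr size_t LOAD_SIZE = __LOAD_SIZE__;\n";
  code = std::regex_replace(code, std::regex("__ITYPE__"), "fp16");
  code = std::regex_replace(code, std::regex("__LOAD_SIZE__"), "4");
  std::cout << kernel_label << "\n" << code;
  return 0;
}
```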
@@ -734,8 +758,8 @@ dgated_act_cast_transpose_kernel(const IType * const input,
   const int warp_id = threadIdx.x / THREADS_PER_WARP;
   const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
   const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
-  const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
-                         warp_id / n_warps_per_tile;
+  const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile)
+                         + warp_id / n_warps_per_tile;
   if (tile_id >= num_tiles) {
     return;
   }
@@ -786,6 +810,9 @@ dgated_act_cast_transpose_kernel(const IType * const input,
   size_t current_stride2 = warp_id_in_tile * n_iterations * nvec_out * stride2;
   CType max = 0;
   const CType scale = scale_ptr != nullptr ? *scale_ptr : 1;
+
+  CVec partial_dbias;
+
 #pragma unroll
   for (unsigned int i = 0; i < nvec_out; ++i) {
     in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
@@ -822,11 +849,21 @@ dgated_act_cast_transpose_kernel(const IType * const input,
       }
     }
     OVec out_trans_0[nvec_in];  // NOLINT(*)
-    cast_and_transpose_regs<true>(after_dact, out_trans_0, my_output_c_tile_0,
-                                  current_place, stride2, max, scale, true);
     OVec out_trans_1[nvec_in];  // NOLINT(*)
-    cast_and_transpose_regs<true>(after_dgate, out_trans_1, my_output_c_tile_1,
-                                  current_place, stride2, max, scale, true);
+
+    constexpr bool IS_DBIAS = false;
+    constexpr bool IS_FULL_TILE = true;
+    constexpr bool valid_store = true;
+    constexpr int dbias_shfl_src_lane = 0;
+
+    cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>
+      (after_dact, out_trans_0, partial_dbias, my_output_c_tile_0, current_place, stride2,
+       scale, max, dbias_shfl_src_lane, valid_store);
+    cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>
+      (after_dgate, out_trans_1, partial_dbias, my_output_c_tile_1, current_place, stride2,
+       scale, max, dbias_shfl_src_lane, valid_store);
+
 #pragma unroll
     for (unsigned int j = 0; j < nvec_in; ++j) {
       out_space_0[i][j].data.vec = out_trans_0[j].data.vec;
@@ -896,15 +933,15 @@ template <int nvec_in, int nvec_out,
 __global__ void
 __launch_bounds__(cast_transpose_num_threads)
 dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
                                             const IType * const act_input,
                                             OType * const output_c,
                                             OType * const output_t,
                                             const CType * const scale_ptr,
                                             CType * const amax,
                                             CType * const scale_inv,
                                             const size_t row_length,
                                             const size_t num_rows,
                                             const size_t num_tiles) {
   using IVec = Vec<IType, nvec_in>;
   using OVec = Vec<OType, nvec_out>;
   using CVec = Vec<CType, nvec_in>;
@@ -971,6 +1008,9 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
                           THREADS_PER_WARP;
   CType max = 0;
   const CType scale = scale_ptr != nullptr ? *scale_ptr : 1;
+
+  CVec partial_dbias;
+
   {
     const bool valid_load = my_place < tile_length &&
                             warp_id_in_tile * n_iterations < tile_height;
@@ -1031,12 +1071,20 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
     }
     OVec out_trans_0[nvec_in];  // NOLINT(*)
     OVec out_trans_1[nvec_in];  // NOLINT(*)
-    const bool valid_store = my_place < tile_length &&
-                             warp_id_in_tile * n_iterations + i < tile_height;
-    cast_and_transpose_regs<false>(after_dact, out_trans_0, my_output_c_tile_0,
-                                   current_place, stride2, max, scale, valid_store);
-    cast_and_transpose_regs<false>(after_dgate, out_trans_1, my_output_c_tile_1,
-                                   current_place, stride2, max, scale, valid_store);
+
+    constexpr bool IS_DBIAS = false;
+    constexpr bool IS_FULL_TILE = false;
+    constexpr int dbias_shfl_src_lane = 0;
+    const bool valid_store = (my_place < tile_length)
+                             && (warp_id_in_tile * n_iterations + i < tile_height);
+
+    cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>
+      (after_dact, out_trans_0, partial_dbias, my_output_c_tile_0, current_place, stride2,
+       scale, max, dbias_shfl_src_lane, valid_store);
+    cast_and_transpose_regs<IS_DBIAS, IS_FULL_TILE>
+      (after_dgate, out_trans_1, partial_dbias, my_output_c_tile_1, current_place, stride2,
+       scale, max, dbias_shfl_src_lane, valid_store);
+
 #pragma unroll
     for (unsigned int j = 0; j < nvec_in; ++j) {
       out_space_0[i][j].data.vec = out_trans_0[j].data.vec;
@@ -1204,9 +1252,13 @@ void dgated_act_cast_transpose(const Tensor &input,
     );  // NOLINT(*)
   );  // NOLINT(*)
 }
 
+}  // namespace
+
 }  // namespace transformer_engine
 
+using ComputeType = typename transformer_engine::fp32;
+
 void nvte_cast_transpose_dbias(const NVTETensor input,
                                NVTETensor cast_output,
                                NVTETensor transposed_output,
@@ -1221,7 +1273,7 @@ void nvte_cast_transpose_dbias(const NVTETensor input,
 
   constexpr const NVTETensor activation_input = nullptr;
 
-  cast_transpose_fused<IS_DBIAS, IS_DACT, fp32, Empty, nullptr>(
+  cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, nullptr>(
     *reinterpret_cast<const Tensor*>(input),
     *reinterpret_cast<const Tensor*>(activation_input),
     reinterpret_cast<Tensor*>(cast_output),
@@ -1246,7 +1298,7 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
 
   constexpr auto dActivation = &dgelu<fp32, fp32>;
 
-  cast_transpose_fused<IS_DBIAS, IS_DACT, fp32, Empty, dActivation>(
+  cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
     *reinterpret_cast<const Tensor*>(input),
     *reinterpret_cast<const Tensor*>(act_input),
     reinterpret_cast<Tensor*>(cast_output),
@@ -1271,7 +1323,7 @@ void nvte_cast_transpose_dbias_dsilu(const NVTETensor input,
 
   constexpr auto dActivation = &dsilu<fp32, fp32>;
 
-  cast_transpose_fused<IS_DBIAS, IS_DACT, fp32, Empty, dActivation>(
+  cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
    *reinterpret_cast<const Tensor*>(input),
     *reinterpret_cast<const Tensor*>(silu_input),
     reinterpret_cast<Tensor*>(cast_output),
@@ -1296,7 +1348,7 @@ void nvte_cast_transpose_dbias_drelu(const NVTETensor input,
 
   constexpr auto dActivation = &drelu<fp32, fp32>;
 
-  cast_transpose_fused<IS_DBIAS, IS_DACT, fp32, Empty, dActivation>(
+  cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
     *reinterpret_cast<const Tensor*>(input),
     *reinterpret_cast<const Tensor*>(relu_input),
     reinterpret_cast<Tensor*>(cast_output),
@@ -1321,7 +1373,7 @@ void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input,
 
   constexpr auto dActivation = &dsrelu<fp32, fp32>;
 
-  cast_transpose_fused<IS_DBIAS, IS_DACT, fp32, Empty, dActivation>(
+  cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
     *reinterpret_cast<const Tensor*>(input),
    *reinterpret_cast<const Tensor*>(srelu_input),
     reinterpret_cast<Tensor*>(cast_output),
@@ -1346,7 +1398,7 @@ void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input,
 
   constexpr auto dActivation = &dqgelu<fp32, fp32>;
 
-  cast_transpose_fused<IS_DBIAS, IS_DACT, fp32, Empty, dActivation>(
+  cast_transpose_fused<IS_DBIAS, IS_DACT, ComputeType, Empty, dActivation>(
     *reinterpret_cast<const Tensor*>(input),
     *reinterpret_cast<const Tensor*>(qgelu_input),
     reinterpret_cast<Tensor*>(cast_output),
@@ -1367,7 +1419,7 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input,
 
   constexpr auto dActivation = &dgelu<fp32, fp32>;
   constexpr auto Activation = &gelu<fp32, fp32>;
 
-  dgated_act_cast_transpose<fp32, Empty, dActivation, Activation>(
+  dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
     *reinterpret_cast<const Tensor*>(input),
     *reinterpret_cast<const Tensor*>(gated_act_input),
     reinterpret_cast<Tensor*>(cast_output),
@@ -1386,7 +1438,7 @@ void nvte_dswiglu_cast_transpose(const NVTETensor input,
 
   constexpr auto dActivation = &dsilu<fp32, fp32>;
   constexpr auto Activation = &silu<fp32, fp32>;
 
-  dgated_act_cast_transpose<fp32, Empty, dActivation, Activation>(
+  dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
     *reinterpret_cast<const Tensor*>(input),
     *reinterpret_cast<const Tensor*>(swiglu_input),
     reinterpret_cast<Tensor*>(cast_output),
@@ -1405,7 +1457,7 @@ void nvte_dreglu_cast_transpose(const NVTETensor input,
 
   constexpr auto dActivation = &drelu<fp32, fp32>;
   constexpr auto Activation = &relu<fp32, fp32>;
 
-  dgated_act_cast_transpose<fp32, Empty, dActivation, Activation>(
+  dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
     *reinterpret_cast<const Tensor*>(input),
     *reinterpret_cast<const Tensor*>(gated_act_input),
     reinterpret_cast<Tensor*>(cast_output),
@@ -1424,7 +1476,7 @@ void nvte_dsreglu_cast_transpose(const NVTETensor input,
 
   constexpr auto dActivation = &dsrelu<fp32, fp32>;
   constexpr auto Activation = &srelu<fp32, fp32>;
 
-  dgated_act_cast_transpose<fp32, Empty, dActivation, Activation>(
+  dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
     *reinterpret_cast<const Tensor*>(input),
     *reinterpret_cast<const Tensor*>(gated_act_input),
     reinterpret_cast<Tensor*>(cast_output),
@@ -1443,7 +1495,7 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input,
 
   constexpr auto dActivation = &dqgelu<fp32, fp32>;
   constexpr auto Activation = &qgelu<fp32, fp32>;
 
-  dgated_act_cast_transpose<fp32, Empty, dActivation, Activation>(
+  dgated_act_cast_transpose<ComputeType, Empty, dActivation, Activation>(
     *reinterpret_cast<const Tensor*>(input),
     *reinterpret_cast<const Tensor*>(gated_act_input),
     reinterpret_cast<Tensor*>(cast_output),
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "utils.cuh"
#include "util/math.h"
using namespace transformer_engine;
namespace {
// Parameters
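// NOTE: tokens of the form __X__ below are plain-text placeholders that the
// host substitutes with concrete types and values before NVRTC compiles
// this source.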
using CType = float;
using IType = __ITYPE__;
using IType2 = __ITYPE2__;
using OType = __OTYPE__;
constexpr size_t LOAD_SIZE = __LOAD_SIZE__;
constexpr size_t STORE_SIZE = __STORE_SIZE__;
constexpr size_t WARPS_PER_TILE = __WARPS_PER_TILE__;
constexpr size_t BLOCK_SIZE = __BLOCK_SIZE__;
constexpr bool IS_DBIAS = __IS_DBIAS__;
constexpr bool IS_DACT = __IS_DACT__;
constexpr size_t DACT_TYPE = __DACTIVATION_TYPE__;
constexpr size_t NVEC_IN = LOAD_SIZE / sizeof(IType);
constexpr size_t NVEC_OUT = STORE_SIZE / sizeof(OType);
using CVec = Vec<CType, NVEC_IN>;
using IVec = Vec<IType, NVEC_IN>;
using IVec2 = Vec<IType2, NVEC_IN>;
using OVec = Vec<OType, NVEC_OUT>;
using Param = CTDBiasDActParam<IType, IType2, OType, CType>;
using OP = CType (*)(const CType, const Empty&);
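// Table of activation-gradient functions, indexed at runtime by DACT_TYPE
// (entry 0 is a placeholder for the no-dact case).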
constexpr OP Activation[] = {
nullptr, // 0
&dsigmoid<CType, CType>, // 1
&dgelu<CType, CType>, // 2
&dqgelu<CType, CType>, // 3
&dsilu<CType, CType>, // 4
&drelu<CType, CType>, // 5
&dsrelu<CType, CType> // 6
};
} // namespace
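// Casts one NVEC_OUT x NVEC_IN register tile: stores the cast rows to the
// output tile, writes the transposed columns into out_trans, accumulates the
// running amax, and (when IS_DBIAS is set) the per-column dbias partials.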
inline __device__ void
cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_OUT],
OVec (&out_trans)[NVEC_IN],
CVec &out_dbias, // NOLINT(*)
typename OVec::type *output_cast_tile,
const size_t current_place,
const size_t stride,
const CType scale,
CType &amax, // NOLINT(*)
const int dbias_shfl_src_lane) {
using OVecC = Vec<OType, NVEC_IN>;
CVec step_dbias;
if constexpr (IS_DBIAS) {
step_dbias.clear();
}
#pragma unroll
for (unsigned int i = 0; i < NVEC_OUT; ++i) {
OVecC out_cast;
#pragma unroll
for (unsigned int j = 0; j < NVEC_IN; ++j) {
const CType tmp = in[i].data.elt[j];
if constexpr (IS_DBIAS) {
step_dbias.data.elt[j] += tmp; // dbias: thread tile local accumulation
}
out_cast.data.elt[j] = static_cast<OType>(tmp * scale);
out_trans[j].data.elt[i] = static_cast<OType>(tmp * scale); // thread tile transpose
__builtin_assume(amax >= 0);
amax = fmaxf(fabsf(tmp), amax);
}
out_cast.store_to(output_cast_tile, current_place + stride * i);
}
if constexpr (IS_DBIAS) {
#pragma unroll
for (unsigned int j = 0; j < NVEC_IN; ++j) {
CType elt = step_dbias.data.elt[j];
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in a warp
out_dbias.data.elt[j] += elt;
}
}
}
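// Fused cast + transpose (+ dbias reduction, + activation gradient) kernel.
// Each tile is processed cooperatively by WARPS_PER_TILE warps, with
// double-buffered input loads overlapping the cast/transpose work.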
__global__ void
__launch_bounds__(BLOCK_SIZE)
cast_transpose_fusion_kernel_optimized(const Param param,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = row_length / (NVEC_IN * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * WARPS_PER_TILE)
+ warp_id / WARPS_PER_TILE;
if (tile_id >= num_tiles) {
return;
}
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const size_t tile_offset = (tile_id_x * NVEC_IN + tile_id_y * row_length * NVEC_OUT)
* THREADS_PER_WARP;
const size_t tile_offset_transp = (tile_id_y * NVEC_OUT + tile_id_x * num_rows * NVEC_IN)
* THREADS_PER_WARP;
const IType * const my_input_tile = param.input + tile_offset;
const IType2 * const my_act_input_tile = param.act_input + tile_offset;
OType * const my_output_c_tile = param.output_c + tile_offset;
OType * const my_output_t_tile = param.output_t + tile_offset_transp;
CType * const my_partial_dbias_tile = param.workspace
+ (tile_id_x * (NVEC_IN * THREADS_PER_WARP)
+ tile_id_y * row_length);
OVec * const my_scratch = reinterpret_cast<OVec *>(scratch)
+ (my_id_in_warp + warp_id / WARPS_PER_TILE * THREADS_PER_WARP)
* (THREADS_PER_WARP + 1);
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][NVEC_OUT];
IVec2 act_in[2][NVEC_OUT];
const unsigned int warp_id_in_tile = warp_id % WARPS_PER_TILE;
constexpr unsigned int n_iterations = THREADS_PER_WARP / WARPS_PER_TILE;
OVec out_space[n_iterations][NVEC_IN];
const size_t stride = row_length / NVEC_IN;
const size_t output_stride = num_rows / NVEC_OUT;
size_t current_stride = warp_id_in_tile * n_iterations * NVEC_OUT * stride;
size_t current_row = (tile_id_y * THREADS_PER_WARP + warp_id_in_tile * n_iterations) * NVEC_OUT;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations)
% THREADS_PER_WARP;
CType amax = 0.0f;
const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
CVec partial_dbias;
if constexpr (IS_DBIAS) {
partial_dbias.clear();
}
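  // Preload the first set of input (and, with IS_DACT, activation-input)
  // vectors to prime the double-buffered pipeline.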
#pragma unroll
for (unsigned int i = 0; i < NVEC_OUT; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
if constexpr (IS_DACT) {
act_in[0][i].load_from(my_act_input_tile, current_stride + my_place + stride * i);
}
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
#pragma unroll
for (unsigned int j = 0; j < NVEC_OUT; ++j) {
const size_t ld_offset = current_stride + my_place_in + stride * (NVEC_OUT + j);
in[current_in][j].load_from(my_input_tile, ld_offset);
if constexpr (IS_DACT) {
act_in[current_in][j].load_from(my_act_input_tile, ld_offset);
}
}
}
CVec in_cast_fp32[NVEC_OUT]; // NOLINT(*)
#pragma unroll
for (unsigned int j = 0; j < NVEC_OUT; ++j) {
#pragma unroll
for (unsigned int k = 0; k < NVEC_IN; ++k) {
if constexpr (IS_DACT) {
in_cast_fp32[j].data.elt[k] =
static_cast<CType>(in[current_in ^ 1][j].data.elt[k])
* Activation[DACT_TYPE](act_in[current_in ^ 1][j].data.elt[k], {});
} else {
in_cast_fp32[j].data.elt[k] =
static_cast<CType>(in[current_in ^ 1][j].data.elt[k]);
}
}
}
const int dbias_shfl_src_lane = (my_id_in_warp + i + warp_id_in_tile * n_iterations)
% THREADS_PER_WARP;
cast_and_transpose_regs_optimized(in_cast_fp32, out_space[i], partial_dbias,
my_output_c_tile, current_place,
stride, scale, amax, dbias_shfl_src_lane);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += NVEC_OUT * stride;
current_row += NVEC_OUT;
}
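  // Stage the transposed vectors through shared memory so the writes to the
  // transposed output tile are coalesced.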
#pragma unroll
for (unsigned int i = 0; i < NVEC_IN; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP
- j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations)
% THREADS_PER_WARP;
current_stride = i * output_stride
+ warp_id_in_tile * n_iterations * output_stride * NVEC_IN;
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * NVEC_IN;
}
__syncthreads();
}
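  // Tile-level dbias reduction: each warp publishes its partials to shared
  // memory, then warp 0 sums them and stores the per-tile result to the
  // workspace for a later final reduction.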
if constexpr (IS_DBIAS) {
my_dbias_scratch[threadIdx.x] = partial_dbias;
__syncthreads();
if (warp_id_in_tile == 0) {
#pragma unroll
for (unsigned int i = 1; i < WARPS_PER_TILE; ++i) {
CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
#pragma unroll
for (unsigned int j = 0; j < NVEC_IN; ++j) {
partial_dbias.data.elt[j] += tmp.data.elt[j];
}
}
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
}
}
  // Block-level amax reduction; thread 0 atomically updates the global amax
const CType max_block = reduce_max<BLOCK_SIZE/THREADS_PER_WARP>(amax, warp_id);
if (threadIdx.x == 0) {
if (param.amax != nullptr) {
atomicMaxFloat(param.amax, max_block);
}
}
}
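For orientation, the __X__ placeholders in the file above are resolved on the
host by plain textual substitution before the source string reaches NVRTC (see
the nvrtcCreateProgram call in the rtc.cpp hunk below). A minimal sketch of
that step, using a hypothetical substitute() helper; the actual host-side code
path is not part of this hunk:

#include <string>

// Hypothetical helper: replace every occurrence of `token` in `code`.
static void substitute(std::string &code, const std::string &token,
                       const std::string &value) {
  for (size_t pos = 0; (pos = code.find(token, pos)) != std::string::npos;
       pos += value.size()) {
    code.replace(pos, token.size(), value);
  }
}

// Illustrative parameter choices only; the real values depend on the tensor
// types and the tuned LOAD_SIZE/STORE_SIZE/WARPS_PER_TILE configuration.
//   std::string code = /* embedded cast_transpose_fusion.cu source string */;
//   substitute(code, "__ITYPE__", "__nv_bfloat16");
//   substitute(code, "__ITYPE2__", "__nv_bfloat16");
//   substitute(code, "__OTYPE__", "__nv_fp8_e4m3");
//   substitute(code, "__LOAD_SIZE__", "8");
//   substitute(code, "__STORE_SIZE__", "8");
//   substitute(code, "__WARPS_PER_TILE__", "4");
//   substitute(code, "__BLOCK_SIZE__", "128");   // 4 warps * 32 threads
//   substitute(code, "__IS_DBIAS__", "true");
//   substitute(code, "__IS_DACT__", "true");
//   substitute(code, "__DACTIVATION_TYPE__", "2");  // dgelu in the table above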
@@ -23,6 +23,7 @@ namespace {
 // Strings with headers for RTC kernels
 #include "string_code_utils_cuh.h"
+#include "string_code_util_math_h.h"
 /*! \brief Latest compute capability that NVRTC supports
  *
@@ -136,6 +137,10 @@ CUfunction Kernel::get_function(int device_id) {
   return functions_[device_id];
 }
+void Kernel::set_function_cache_config(int device_id, CUfunc_cache cache_config) {
+  NVTE_CALL_CHECK_CUDA_DRIVER(cuFuncSetCacheConfig, get_function(device_id), cache_config);
+}
 KernelManager& KernelManager::instance() {
   NVTE_CHECK(is_enabled(), "NVRTC support is not enabled");
   static KernelManager instance_;
@@ -173,9 +178,9 @@ void KernelManager::compile(const std::string &kernel_label,
   // Compile source
   nvrtcProgram program;
-  constexpr int num_headers = 1;
-  constexpr const char* headers[num_headers] = {string_code_utils_cuh};
-  constexpr const char* include_names[num_headers] = {"utils.cuh"};
+  constexpr int num_headers = 2;
+  constexpr const char* headers[num_headers] = {string_code_utils_cuh, string_code_util_math_h};
+  constexpr const char* include_names[num_headers] = {"utils.cuh", "util/math.h"};
   NVTE_CHECK_NVRTC(nvrtcCreateProgram(&program,
                                       code.c_str(),
                                       filename.c_str(),
@@ -229,6 +234,14 @@ void KernelManager::compile(const std::string &kernel_label,
   NVTE_CHECK_NVRTC(nvrtcDestroyProgram(&program));
 }
+void KernelManager::set_cache_config(const std::string &kernel_label, CUfunc_cache cache_config) {
+  const int device_id = cuda::current_device();
+  const auto key = get_kernel_cache_key(kernel_label, device_id);
+  NVTE_CHECK(kernel_cache_.count(key) > 0,
+             "Attempted to configure RTC kernel before compilation");
+  kernel_cache_.at(key).set_function_cache_config(device_id, cache_config);
+}
 bool KernelManager::is_compiled(const std::string &kernel_label, int device_id) const {
   const auto key = get_kernel_cache_key(kernel_label, device_id);
   return kernel_cache_.count(key) > 0;
...
@@ -85,6 +85,12 @@ class Kernel {
    */
   CUfunction get_function(int device_id);
+  /*! \brief Sets the preferred cache configuration for a function
+   *
+   * Wrapper of the CUDA Driver API function "cuFuncSetCacheConfig"
+   */
+  void set_function_cache_config(int device_id, CUfunc_cache cache_config);
 private:
   /*! \brief Mangled function name */
   std::string mangled_name_;
@@ -166,6 +172,15 @@ class KernelManager {
                          std::forward<ArgTs>(args)...);
   }
+  /*! \brief Sets the preferred cache configuration for a kernel in the current context
+   *
+   * Assumes the kernel has already been compiled.
+   *
+   * \param[in] kernel_label Unique identifying string for kernel
+   * \param[in] cache_config Preferred cache configuration
+   */
+  void set_cache_config(const std::string &kernel_label, CUfunc_cache cache_config);
 private:
   /*! \brief Compiled kernels */
   std::unordered_map<std::string, Kernel> kernel_cache_;
...
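A usage sketch for the new cache-config API (the kernel label is illustrative,
and the transformer_engine::rtc namespace is assumed from this header;
CU_FUNC_CACHE_PREFER_SHARED is a standard CUDA driver enum):

// Sketch only: must run after the kernel has been compiled for the current
// device, since set_cache_config checks the kernel cache and fails otherwise.
auto &manager = transformer_engine::rtc::KernelManager::instance();
manager.set_cache_config("my_kernel_label",
                         CU_FUNC_CACHE_PREFER_SHARED);  // favor shared memory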
@@ -163,6 +163,25 @@ struct TypeToVec2<nv_bfloat16> {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename IType, typename IType2, typename OType, typename CType>
+struct CTDBiasDActParam {
+  using InputType = IType;
+  using InputType2 = IType2;
+  using OutputType = OType;
+  using ComputeType = CType;
+  const IType *input;
+  const IType2 *act_input;
+  OType *output_c;
+  OType *output_t;
+  const CType *scale_ptr;
+  CType *amax;
+  CType *scale_inv;
+  CType *workspace;
+  CType *warp_scales_inv;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
 template<int INDEX>
 struct Get {
   template<typename T, typename R>
...
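CTDBiasDActParam is the single argument the JIT kernel receives (see the Param
alias in the RTC source above). A host-side sketch of packing it, with field
names taken from the struct and Tensor accessors as an assumption about the
surrounding code; the real launch path is in cast_transpose_fusion.cu:

// Illustrative only: pack the kernel arguments into the new param struct.
using Param = CTDBiasDActParam<InputType, InputType2, OutputType, fp32>;
Param param;
param.input     = static_cast<const InputType *>(input.data.dptr);
param.act_input = static_cast<const InputType2 *>(act_input.data.dptr);
param.output_c  = static_cast<OutputType *>(cast_output->data.dptr);
param.output_t  = static_cast<OutputType *>(transposed_output->data.dptr);
param.scale_ptr = static_cast<const fp32 *>(cast_output->scale.dptr);
param.amax      = static_cast<fp32 *>(cast_output->amax.dptr);
param.scale_inv = static_cast<fp32 *>(cast_output->scale_inv.dptr);
param.workspace = static_cast<fp32 *>(workspace->data.dptr);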