Unverified Commit fc989613 authored by Oleg Goncharov, committed by GitHub

[Common] Fused cast transpose kernels refactoring (#884)



* Merged CT+dbias+dact into a single template
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Moved gated activations from cast_transpose_fusion into a separate cpp file
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Code clean up
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Code clean up
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Code clean up
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Code clean up
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Update transformer_engine/common/transpose/cast_transpose_fusion.cu
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>

* Update transformer_engine/common/transpose/cast_transpose_fusion.cu
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>

* Reverted the change with the file split
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

---------
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 868c7d30
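The core of the refactoring is visible in the diff below: the separate cast-transpose, cast-transpose+dbias, and cast-transpose+dbias+dact kernels collapse into one cast_transpose_fused_kernel template whose optional work is selected by bool template parameters. A minimal, self-contained sketch of that compile-time dispatch pattern (all names here are hypothetical, not this commit's API):

    template <bool IS_DBIAS, bool IS_DACT>
    __global__ void fused_sketch(const float* grad, const float* act_in,
                                 float* out, float* dbias_ws, size_t n) {
      const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i >= n) return;
      float v = grad[i];
      if constexpr (IS_DACT) {
        v *= act_in[i];  // stands in for OP(act_input, {}), e.g. a dgelu
      }
      out[i] = v;
      if constexpr (IS_DBIAS) {
        // stands in for the tiled partial-dbias accumulation in the real kernel
        atomicAdd(&dbias_ws[i % 32], v);
      }
    }

    // Each instantiation compiles to a fully specialized kernel, so untaken
    // branches cost nothing at run time:
    //   fused_sketch<true, true><<<grid, block>>>(grad, act, out, ws, n);
    //   fused_sketch<false, false><<<grid, block>>>(grad, nullptr, out, nullptr, n);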
@@ -15,6 +15,13 @@
namespace transformer_engine {
// STUFF TO TUNE
constexpr unsigned int n_warps_per_tile = 8;
constexpr unsigned int max_threads_per_block = 256;
static_assert(n_warps_per_tile * THREADS_PER_WARP <= max_threads_per_block);
constexpr unsigned int cast_transpose_num_threads = n_warps_per_tile * THREADS_PER_WARP;
constexpr size_t reduce_dbias_num_threads = 256;
template <bool full_tile, int nvec_in, int nvec_out, typename IVec, typename OVec, typename CType>
inline __device__ void cast_and_transpose_regs(const IVec (&in)[nvec_out],
OVec (&out_trans)[nvec_in],
@@ -26,10 +33,10 @@ inline __device__ void cast_and_transpose_regs(const IVec (&in)[nvec_out],
const bool valid_store) {
using T = typename OVec::type;
using OVecC = Vec<T, nvec_in>;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
OVecC out_cast;
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
const CType tmp = static_cast<CType>(in[i].data.elt[j]);
const T elt_o = T(scale * tmp);
@@ -63,10 +70,10 @@ inline __device__ void cast_and_transpose_regs_partial_dbias(const IVec (&in)[nv
CVec step_dbias; step_dbias.clear();
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
OVecC out_cast;
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
const CType tmp = in[i].data.elt[j];
const T elt_o = T(scale * tmp);
@@ -85,7 +92,7 @@ inline __device__ void cast_and_transpose_regs_partial_dbias(const IVec (&in)[nv
}
}
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
CType elt = step_dbias.data.elt[j];
elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane); // shuffle data in warp
@@ -93,30 +100,8 @@ inline __device__ void cast_and_transpose_regs_partial_dbias(const IVec (&in)[nv
}
}
// STUFF TO TUNE
constexpr unsigned int n_warps_per_tile = 4;
constexpr int desired_load_size = 8;
constexpr int desired_store_size = 8;
constexpr unsigned int max_threads_per_block = 256;
static_assert(n_warps_per_tile * THREADS_PER_WARP <= max_threads_per_block);
constexpr unsigned int cast_transpose_num_threads = n_warps_per_tile * THREADS_PER_WARP;
namespace {
template <typename IType, typename OType, typename CType>
struct CTDBiasParam {
using InputType = IType;
using OutputType = OType;
using ComputeType = CType;
const IType *input;
OType *output_c;
OType *output_t;
const CType *scale_ptr;
CType *amax;
CType *workspace;
};
template <typename IType, typename IType2, typename OType, typename CType>
struct CTDBiasDGeluParam {
using InputType = IType;
@@ -134,316 +119,21 @@ struct CTDBiasDGeluParam {
} // namespace
template <int nvec_in, int nvec_out, typename Param>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_dbias_kernel(const Param param,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
using IType = typename Param::InputType;
using OType = typename Param::OutputType;
using CType = typename Param::ComputeType;
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
using CVec = Vec<CType, nvec_in>;
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
// const size_t num_tiles_y = num_rows / (nvec * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = param.input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_c_tile = param.output_c + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_t_tile = param.output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
CType * const my_partial_dbias_tile = param.workspace +
(tile_id_x * (nvec_in * THREADS_PER_WARP) +
tile_id_y * row_length);
OVec * const my_scratch = reinterpret_cast<OVec *>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
(THREADS_PER_WARP + 1);
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space[n_iterations][nvec_in];
CVec partial_dbias;
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
partial_dbias.clear();
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
}
}
OVec out_trans[nvec_in]; // NOLINT(*)
cast_and_transpose_regs_partial_dbias<true>(in[current_in ^ 1], out_trans,
partial_dbias, my_output_c_tile,
current_place, stride, max, scale,
(my_id_in_warp + i +
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP,
true);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space[i][j].data.vec = out_trans[j].data.vec;
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += nvec_out * stride;
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
my_dbias_scratch[threadIdx.x] = partial_dbias;
__syncthreads();
// TODO(ptredak): check if the regular reduction is better
if (warp_id_in_tile == 0) {
#pragma unroll
for (unsigned int i = 1; i < n_warps_per_tile; ++i) {
CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
partial_dbias.data.elt[j] += tmp.data.elt[j];
}
}
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
}
/* warp tile amax reduce*/
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
}
}
template <int nvec_in, int nvec_out, typename Param>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_dbias_kernel_notaligned(const Param param,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
using IType = typename Param::InputType;
using OType = typename Param::OutputType;
using CType = typename Param::ComputeType;
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
using CVec = Vec<CType, nvec_in>;
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const unsigned int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) /
(nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = param.input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_c_tile = param.output_c + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_t_tile = param.output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
CType * const my_partial_dbias_tile = param.workspace +
(tile_id_x * (nvec_in * THREADS_PER_WARP) +
tile_id_y * row_length);
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP;
const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP;
const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP
: row_length_rest;
const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP
: row_height_rest;
OVec * const my_scratch = reinterpret_cast<OVec *>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
(THREADS_PER_WARP + 1);
CVec * const my_dbias_scratch = reinterpret_cast<CVec *>(scratch);
IVec in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space[n_iterations][nvec_in];
CVec partial_dbias;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
partial_dbias.clear();
{
const bool valid_load = my_place < tile_length &&
warp_id_in_tile * n_iterations < tile_height;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
if (valid_load) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
} else {
in[0][i].clear();
}
}
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
const bool valid_load = my_place_in < tile_length &&
warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
if (valid_load) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
} else {
in[current_in][j].clear();
}
}
}
OVec out_trans[nvec_in]; // NOLINT(*)
const bool valid_store = my_place < tile_length &&
warp_id_in_tile * n_iterations + i < tile_height;
cast_and_transpose_regs_partial_dbias<false>(in[current_in ^ 1], out_trans,
partial_dbias, my_output_c_tile,
current_place, stride, max, scale,
(my_id_in_warp + i +
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP,
valid_store);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space[i][j].data.vec = out_trans[j].data.vec;
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += nvec_out * stride;
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
const bool valid_store = my_place < tile_height;
if (valid_store) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile,
current_stride + my_place);
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
my_dbias_scratch[threadIdx.x] = partial_dbias;
__syncthreads();
// TODO(ptredak): check if the regular reduction is better
if (warp_id_in_tile == 0) {
#pragma unroll
for (unsigned int i = 1; i < n_warps_per_tile; ++i) {
CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
partial_dbias.data.elt[j] += tmp.data.elt[j];
}
}
if (my_id_in_warp < tile_length) {
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
}
}
/* warp tile amax reduce*/
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
}
}

void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /*cast*/
Tensor* workspace,
const int nvec_out) {
const size_t row_length = cast_output.data.shape[1];
const size_t num_rows = cast_output.data.shape[0];

const size_t tile_size_y = (nvec_out * THREADS_PER_WARP);
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");

const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y);

workspace->data.shape = {num_rows_partial_dbias, row_length};
workspace->data.dtype = DType::kFloat32;
}
constexpr size_t reduce_dbias_num_threads = 256;
template<int nvec, typename ComputeType, typename OutputType>
__global__ void
__launch_bounds__(reduce_dbias_num_threads)
@@ -456,7 +146,9 @@ reduce_dbias_kernel(OutputType* const dbias_output,
const int thread_id = blockIdx.x * blockDim.x + threadIdx.x;

if (thread_id * nvec >= row_length) return;
if (thread_id * nvec >= row_length) {
return;
}

const ComputeType* const thread_in_base = dbias_partial + thread_id * nvec;
OutputType* const thread_out_base = dbias_output + thread_id * nvec;
@@ -467,38 +159,26 @@ reduce_dbias_kernel(OutputType* const dbias_output,
ComputeVec acc_vec; acc_vec.clear();
for (int i = 0; i < num_rows; ++i) {
ldg_vec.load_from(thread_in_base, i * stride_in_vec);
#pragma unroll
for (int e = 0; e < nvec; ++e) {
acc_vec.data.elt[e] += ldg_vec.data.elt[e];
}
}
OutputVec stg_vec;
#pragma unroll
for (int e = 0; e < nvec; ++e) {
stg_vec.data.elt[e] = OutputType(acc_vec.data.elt[e]);
}
stg_vec.store_to(thread_out_base, 0);
}
void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /*cast*/
Tensor* workspace,
const int nvec_out) {
const size_t row_length = cast_output.data.shape[1];
const size_t num_rows = cast_output.data.shape[0];
const size_t tile_size_y = (nvec_out * THREADS_PER_WARP);
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y);
workspace->data.shape = {num_rows_partial_dbias, row_length};
workspace->data.dtype = DType::kFloat32;
}
template <typename InputType>
void reduce_dbias(const Tensor &workspace, Tensor *dbias,
const size_t row_length, const size_t num_rows, const int nvec_out,
void reduce_dbias(const Tensor &workspace,
Tensor *dbias,
const size_t row_length,
const size_t num_rows,
const int nvec_out,
cudaStream_t stream) {
constexpr int reduce_dbias_store_bytes = 8; // stg.64
constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(InputType);
@@ -507,127 +187,26 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias,
const size_t reduce_dbias_row_length = row_length;
const size_t reduce_dbias_num_rows = DIVUP(num_rows,
static_cast<size_t>(nvec_out *
THREADS_PER_WARP));
static_cast<size_t>(nvec_out * THREADS_PER_WARP));
const size_t reduce_dbias_num_blocks = DIVUP(row_length,
reduce_dbias_num_threads * reduce_dbias_nvec);
reduce_dbias_kernel<reduce_dbias_nvec, fp32, InputType>
<<<reduce_dbias_num_blocks,
reduce_dbias_num_threads,
0,
stream>>>(
reinterpret_cast<InputType *>(dbias->data.dptr),
<<<reduce_dbias_num_blocks, reduce_dbias_num_threads, 0, stream>>>
(reinterpret_cast<InputType *>(dbias->data.dptr),
reinterpret_cast<const fp32 *>(workspace.data.dptr),
reduce_dbias_row_length,
reduce_dbias_num_rows);
}
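Taken together, dbias is computed in two stages: the fused kernel writes fp32 rows of per-tile partial column sums into the workspace (shape {DIVUP(num_rows, nvec_out * THREADS_PER_WARP), row_length}, as set up in populate_cast_transpose_dbias_workspace_config), and reduce_dbias_kernel then sums the workspace column-wise into the final dbias. A rough host-side model of stage 2, for intuition only (names hypothetical):

    #include <cstddef>
    #include <vector>

    // workspace: num_partial_rows x row_length partial sums (fp32)
    // dbias:     row_length column totals
    void reduce_dbias_model(const std::vector<float>& workspace,
                            std::vector<float>& dbias,
                            size_t num_partial_rows, size_t row_length) {
      for (size_t j = 0; j < row_length; ++j) {
        float acc = 0.0f;  // accumulate in ComputeType (fp32)
        for (size_t i = 0; i < num_partial_rows; ++i) {
          acc += workspace[i * row_length + j];
        }
        dbias[j] = acc;  // the kernel casts back to the input dtype here
      }
    }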
void cast_transpose_dbias(const Tensor &input,
Tensor *cast_output,
Tensor *transposed_output,
Tensor *dbias,
Tensor *workspace,
cudaStream_t stream) {
if (workspace->data.dptr != nullptr) {
CheckInputTensor(input, "cast_transpose_dbias_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
CheckOutputTensor(*dbias, "dbias");
}
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
NVTE_CHECK(input.data.shape == cast_output->data.shape,
"Input and C output must have the same shape.");
const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0];
NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
"C and T outputs need to have the same type.");
NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
"C and T outputs need to share amax tensor.");
NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
"C and T outputs need to share scale tensor.");
NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
constexpr int itype_size = sizeof(InputType);
constexpr int otype_size = sizeof(OutputType);
constexpr int nvec_in = desired_load_size / itype_size;
constexpr int nvec_out = desired_store_size / otype_size;
if (workspace->data.dptr == nullptr) {
populate_cast_transpose_dbias_workspace_config(*cast_output, workspace, nvec_out);
return;
}
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
num_rows % (nvec_out * THREADS_PER_WARP) == 0;
using ComputeType = fp32;
constexpr size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) *
sizeof(Vec<OutputType, nvec_out>);
constexpr size_t shared_size_dbias = cast_transpose_num_threads *
sizeof(Vec<ComputeType, nvec_in>);
static_assert(shared_size_transpose >= shared_size_dbias);
using Param = CTDBiasParam<InputType, OutputType, ComputeType>;
Param param;
param.input = reinterpret_cast<const InputType *>(input.data.dptr);
param.output_c = reinterpret_cast<OutputType *>(cast_output->data.dptr);
param.output_t = reinterpret_cast<OutputType *>(transposed_output->data.dptr);
param.scale_ptr = reinterpret_cast<const ComputeType *>(cast_output->scale.dptr);
param.amax = reinterpret_cast<ComputeType *>(cast_output->amax.dptr);
param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
if (full_tile) {
cudaFuncSetAttribute(cast_transpose_dbias_kernel<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
cast_transpose_dbias_kernel<nvec_in, nvec_out, Param>
<<<n_blocks,
cast_transpose_num_threads,
shared_size_transpose,
stream>>>(param, row_length, num_rows, n_tiles);
} else {
cudaFuncSetAttribute(cast_transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
cast_transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>
<<<n_blocks,
cast_transpose_num_threads,
shared_size_transpose,
stream>>>(param, row_length, num_rows, n_tiles);
}
reduce_dbias<InputType>(*workspace, dbias, row_length, num_rows, nvec_out, stream);
); // NOLINT(*)
); // NOLINT(*)
}
// TODO Phuong: Change all the names in these generalized functions.
// For now, I keep the old names so that it is easier to do code review
template <typename ComputeType, typename ParamOP,
template <bool IS_DBIAS, bool IS_DACT, typename ComputeType, typename ParamOP,
int nvec_in, int nvec_out, typename Param,
ComputeType (*OP)(ComputeType, const ParamOP&)>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_dbias_dact_kernel(const Param param,
cast_transpose_fused_kernel(const Param param,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
@@ -648,26 +227,23 @@ cast_transpose_dbias_dact_kernel(const Param param,
// const size_t num_tiles_y = num_rows / (nvec * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
if (tile_id >= num_tiles) {
return;
}
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = param.input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
const IType2 * const my_act_input_tile = param.act_input +
(tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_c_tile = param.output_c + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_t_tile = param.output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
const size_t tile_offset_out = (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
const size_t tile_offset_in = (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
const IType * const my_input_tile = param.input + tile_offset_out;
const IType2 * const my_act_input_tile = param.act_input + tile_offset_out;
OType * const my_output_c_tile = param.output_c + tile_offset_out;
OType * const my_output_t_tile = param.output_t + tile_offset_in;
CType * const my_partial_dbias_tile = param.workspace +
(tile_id_x * (nvec_in * THREADS_PER_WARP) +
tile_id_y * row_length);
(tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length);
OVec * const my_scratch = reinterpret_cast<OVec *>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
@@ -685,53 +261,63 @@ cast_transpose_dbias_dact_kernel(const Param param,
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
warp_id_in_tile * n_iterations) %
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
if constexpr (IS_DBIAS) {
partial_dbias.clear();
}
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
act_in[0][i].load_from(my_act_input_tile, current_stride + my_place + stride * i);
const size_t ld_offset = current_stride + my_place + stride * i;
in[0][i].load_from(my_input_tile, ld_offset);
act_in[0][i].load_from(my_act_input_tile, ld_offset);
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
act_in[current_in][j].load_from(my_act_input_tile,
current_stride + my_place_in +
stride * (nvec_out + j));
const size_t ld_offset = current_stride + my_place_in + stride * (nvec_out + j);
in[current_in][j].load_from(my_input_tile, ld_offset);
act_in[current_in][j].load_from(my_act_input_tile, ld_offset);
}
}
CVec after_dact[nvec_out]; // NOLINT(*)
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
for (unsigned int k = 0; k < nvec_in; ++k) {
after_dact[j].data.elt[k] = OP(act_in[current_in ^ 1][j].data.elt[k], {}) *
CType(in[current_in ^ 1][j].data.elt[k]);
if constexpr (IS_DACT) {
after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
OP(act_in[current_in ^ 1][j].data.elt[k], {});
} else {
after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]);
}
}
}
OVec out_trans[nvec_in]; // NOLINT(*)
if constexpr (IS_DBIAS) {
const size_t dbias_shfl_src_lane =
(my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
cast_and_transpose_regs_partial_dbias<true>(after_dact, out_trans,
partial_dbias, my_output_c_tile,
current_place, stride, max, scale,
(my_id_in_warp + i +
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP,
dbias_shfl_src_lane,
true);
} else {
cast_and_transpose_regs<true>(after_dact, out_trans, my_output_c_tile,
current_place, stride, max, scale, true);
}
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space[i][j].data.vec = out_trans[j].data.vec;
}
@@ -740,7 +326,7 @@ cast_transpose_dbias_dact_kernel(const Param param,
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
@@ -759,37 +345,40 @@ cast_transpose_dbias_dact_kernel(const Param param,
__syncthreads();
}
if constexpr (IS_DBIAS) {
my_dbias_scratch[threadIdx.x] = partial_dbias;
__syncthreads();
// TODO(ptredak): check if the regular reduction is better
if (warp_id_in_tile == 0) {
#pragma unroll
for (unsigned int i = 1; i < n_warps_per_tile; ++i) {
CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
partial_dbias.data.elt[j] += tmp.data.elt[j];
}
}
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
}
}
/* warp tile amax reduce*/
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
if (param.amax != nullptr) {
atomicMaxFloat(param.amax, max);
}
}
}
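The amax handling above follows a common FP8 pattern: each thread tracks the largest magnitude it produced, reduce_max combines the per-warp values, and thread 0 folds the block result into the global amax. A minimal sketch of a float atomic max (a hypothetical helper, not this file's atomicMaxFloat; it is valid here because amax is non-negative, so the IEEE-754 bit pattern orders the same way as an unsigned integer):

    __device__ inline void atomic_max_nonneg_float(float *addr, float value) {
      // For non-negative floats, comparing bit patterns as unsigned ints
      // matches comparing the floats themselves.
      atomicMax(reinterpret_cast<unsigned int *>(addr), __float_as_uint(value));
    }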
template <typename ComputeType, typename ParamOP,
template <bool IS_DBIAS, bool IS_DACT, typename ComputeType, typename ParamOP,
int nvec_in, int nvec_out, typename Param,
ComputeType (*OP)(ComputeType, const ParamOP&)>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
cast_transpose_dbias_dact_kernel_notaligned(const Param param,
cast_transpose_fused_kernel_notaligned(const Param param,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
@@ -810,26 +399,24 @@ cast_transpose_dbias_dact_kernel_notaligned(const Param param,
(nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
if (tile_id >= num_tiles) {
return;
}
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = param.input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
const IType2 * const my_act_input_tile = param.act_input +
(tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_c_tile = param.output_c + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_t_tile = param.output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
const size_t tile_offset_out = (tile_id_x * nvec_in + tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
const size_t tile_offset_in = (tile_id_y * nvec_out + tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
const IType * const my_input_tile = param.input + tile_offset_out;
const IType2 * const my_act_input_tile = param.act_input + tile_offset_out;
OType * const my_output_c_tile = param.output_c + tile_offset_out;
OType * const my_output_t_tile = param.output_t + tile_offset_in;
CType * const my_partial_dbias_tile = param.workspace +
(tile_id_x * (nvec_in * THREADS_PER_WARP) +
tile_id_y * row_length);
(tile_id_x * (nvec_in * THREADS_PER_WARP) + tile_id_y * row_length);
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
@@ -860,23 +447,27 @@ cast_transpose_dbias_dact_kernel_notaligned(const Param param,
CType max = 0;
const CType scale = param.scale_ptr != nullptr ? *param.scale_ptr : 1;
if constexpr (IS_DBIAS) {
partial_dbias.clear();
}
{
const bool valid_load = my_place < tile_length &&
warp_id_in_tile * n_iterations < tile_height;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
if (valid_load) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
act_in[0][i].load_from(my_act_input_tile, current_stride + my_place + stride * i);
const size_t ld_offset = current_stride + my_place + stride * i;
in[0][i].load_from(my_input_tile, ld_offset);
act_in[0][i].load_from(my_act_input_tile, ld_offset);
} else {
in[0][i].clear();
act_in[0][i].clear();
}
}
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
@@ -884,14 +475,12 @@ cast_transpose_dbias_dact_kernel_notaligned(const Param param,
if (i < n_iterations - 1) {
const bool valid_load = my_place_in < tile_length &&
warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
if (valid_load) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
act_in[current_in][j].load_from(my_act_input_tile,
current_stride + my_place_in +
stride * (nvec_out + j));
const size_t ld_offset = current_stride + my_place_in + stride*(nvec_out + j);
in[current_in][j].load_from(my_input_tile, ld_offset);
act_in[current_in][j].load_from(my_act_input_tile, ld_offset);
} else {
in[current_in][j].clear();
act_in[current_in][j].clear();
@@ -899,26 +488,36 @@ cast_transpose_dbias_dact_kernel_notaligned(const Param param,
}
}
CVec after_dact[nvec_out]; // NOLINT(*)
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
for (unsigned int k = 0; k < nvec_in; ++k) {
after_dact[j].data.elt[k] = OP(act_in[current_in ^ 1][j].data.elt[k], {}) *
CType(in[current_in ^ 1][j].data.elt[k]);
if constexpr (IS_DACT) {
after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
OP(act_in[current_in ^ 1][j].data.elt[k], {});
} else {
after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]);
}
}
}
OVec out_trans[nvec_in]; // NOLINT(*)
const bool valid_store = my_place < tile_length &&
warp_id_in_tile * n_iterations + i < tile_height;
if constexpr (IS_DBIAS) {
const size_t dbias_shfl_src_lane =
(my_id_in_warp + i + warp_id_in_tile * n_iterations) % THREADS_PER_WARP;
cast_and_transpose_regs_partial_dbias<false>(after_dact, out_trans,
partial_dbias, my_output_c_tile,
current_place, stride, max, scale,
(my_id_in_warp + i +
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP,
dbias_shfl_src_lane,
valid_store);
} else {
cast_and_transpose_regs<false>(after_dact, out_trans, my_output_c_tile,
current_place, stride, max, scale, valid_store);
}
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space[i][j].data.vec = out_trans[j].data.vec;
}
@@ -927,7 +526,7 @@ cast_transpose_dbias_dact_kernel_notaligned(const Param param,
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i];
@@ -946,39 +545,172 @@ cast_transpose_dbias_dact_kernel_notaligned(const Param param,
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
my_dbias_scratch[threadIdx.x] = partial_dbias;
__syncthreads();
// TODO(ptredak): check if the regular reduction is better
if (warp_id_in_tile == 0) {
#pragma unroll
for (unsigned int i = 1; i < n_warps_per_tile; ++i) {
CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
partial_dbias.data.elt[j] += tmp.data.elt[j];
}
}
if (my_id_in_warp < tile_length) {
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
}
}
/* warp tile amax reduce*/
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (param.amax != nullptr) atomicMaxFloat(param.amax, max);
}
}
}
if constexpr (IS_DBIAS) {
my_dbias_scratch[threadIdx.x] = partial_dbias;
__syncthreads();
// TODO(ptredak): check if the regular reduction is better
if (warp_id_in_tile == 0) {
#pragma unroll
for (unsigned int i = 1; i < n_warps_per_tile; ++i) {
CVec tmp = my_dbias_scratch[threadIdx.x + i * THREADS_PER_WARP];
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
partial_dbias.data.elt[j] += tmp.data.elt[j];
}
}
if (my_id_in_warp < tile_length) {
partial_dbias.store_to(my_partial_dbias_tile, my_id_in_warp);
}
}
}
/* warp tile amax reduce*/
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (param.amax != nullptr) {
atomicMaxFloat(param.amax, max);
}
}
}

template <bool IS_DBIAS, bool IS_DACT, typename ComputeType, typename ParamOP,
ComputeType (*OP)(ComputeType, const ParamOP&)>
void cast_transpose_fused(const Tensor &input,
const Tensor &act_input,
Tensor *cast_output,
Tensor *transposed_output,
Tensor *dbias,
Tensor *workspace,
cudaStream_t stream) {
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
NVTE_CHECK(input.data.shape == cast_output->data.shape,
"Input and C output must have the same shape.");
const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0];
NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
"C and T outputs need to have the same type.");
NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
"C and T outputs need to share amax tensor.");
NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
"C and T outputs need to share scale tensor.");
if constexpr (IS_DBIAS) {
NVTE_CHECK(dbias->data.dtype == input.data.dtype,
"DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{ row_length },
"Wrong shape of DBias.");
}
if constexpr (IS_DACT) {
NVTE_CHECK(input.data.dtype == act_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(input.data.shape == act_input.data.shape, "Shapes of both inputs must match.");
}
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
using InputType2 = InputType;
/* dact fusion kernel uses more registers */
constexpr int load_size = (IS_DACT ? 4 : 8);
constexpr int store_size = (IS_DACT ? 4 : 8);
constexpr int itype_size = sizeof(InputType);
constexpr int otype_size = sizeof(OutputType);
constexpr int nvec_in = load_size / itype_size;
constexpr int nvec_out = store_size / otype_size;
if constexpr (IS_DBIAS) {
if (workspace->data.dptr == nullptr) {
populate_cast_transpose_dbias_workspace_config(*cast_output,
workspace, nvec_out);
return;
}
}
CheckInputTensor(input, "cast_transpose_fused_input");
CheckOutputTensor(*cast_output, "cast_output");
CheckOutputTensor(*transposed_output, "transposed_output");
if constexpr (IS_DBIAS) {
CheckOutputTensor(*dbias, "dbias");
}
if constexpr (IS_DACT) {
CheckInputTensor(act_input, "act_input");
}
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t n_tiles =
DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
num_rows % (nvec_out * THREADS_PER_WARP) == 0;
// using ComputeType = fp32;
constexpr size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>);
if constexpr (IS_DBIAS) {
constexpr size_t shared_size_dbias =
cast_transpose_num_threads * sizeof(Vec<ComputeType, nvec_in>);
static_assert(shared_size_transpose >= shared_size_dbias);
}
using Param = CTDBiasDGeluParam<InputType, InputType2, OutputType, ComputeType>;
Param param;
param.input = reinterpret_cast<const InputType *>(input.data.dptr);
param.output_c = reinterpret_cast<OutputType *>(cast_output->data.dptr);
param.output_t = reinterpret_cast<OutputType *>(transposed_output->data.dptr);
param.scale_ptr = reinterpret_cast<const ComputeType *>(cast_output->scale.dptr);
param.amax = reinterpret_cast<ComputeType *>(cast_output->amax.dptr);
if constexpr (IS_DBIAS) {
param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
}
if constexpr (IS_DACT) {
param.act_input = reinterpret_cast<const InputType2 *>(act_input.data.dptr);
}
if (full_tile) {
cudaFuncSetAttribute(
cast_transpose_fused_kernel
<IS_DBIAS, IS_DACT, ComputeType, Empty, nvec_in, nvec_out, Param, OP>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
cast_transpose_fused_kernel
<IS_DBIAS, IS_DACT, ComputeType, Empty, nvec_in, nvec_out, Param, OP>
<<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>
(param, row_length, num_rows, n_tiles);
} else {
cudaFuncSetAttribute(
cast_transpose_fused_kernel_notaligned
<IS_DBIAS, IS_DACT, ComputeType, Empty, nvec_in, nvec_out, Param, OP>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
cast_transpose_fused_kernel_notaligned
<IS_DBIAS, IS_DACT, ComputeType, Empty, nvec_in, nvec_out, Param, OP>
<<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>
(param, row_length, num_rows, n_tiles);
}
if constexpr (IS_DBIAS) {
reduce_dbias<InputType>(*workspace, dbias, row_length, num_rows, nvec_out, stream);
}
); // NOLINT(*)
); // NOLINT(*)
}
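For context, a hedged sketch of how this merged entry point might be instantiated (the real public wrappers and exact OP arguments live outside this hunk, so the template arguments below are illustrative only):

    // dbias only, no activation gradient fused in:
    //   cast_transpose_fused<true, false, fp32, Empty, nullptr>(
    //       input, input, &cast_out, &t_out, &dbias, &workspace, stream);
    // dbias + dGELU, matching the old cast_transpose_dbias_dgelu path:
    //   cast_transpose_fused<true, true, fp32, Empty, dgelu<fp32, Empty>>(
    //       grad, gelu_input, &cast_out, &t_out, &dbias, &workspace, stream);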
template <int nvec_in, int nvec_out,
typename CType, typename IType, typename OType,
typename ParamOP,
typename CType, typename IType, typename OType, typename ParamOP,
CType (*OP1)(CType, const ParamOP&),
CType (*OP2)(CType, const ParamOP&)>
__global__ void
@@ -1004,7 +736,10 @@ dgated_act_cast_transpose_kernel(const IType * const input,
const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
if (tile_id >= num_tiles) {
return;
}
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
...@@ -1051,19 +786,19 @@ dgated_act_cast_transpose_kernel(const IType * const input, ...@@ -1051,19 +786,19 @@ dgated_act_cast_transpose_kernel(const IType * const input,
size_t current_stride2 = warp_id_in_tile * n_iterations * nvec_out * stride2; size_t current_stride2 = warp_id_in_tile * n_iterations * nvec_out * stride2;
CType max = 0; CType max = 0;
const CType scale = scale_ptr != nullptr ? *scale_ptr : 1; const CType scale = scale_ptr != nullptr ? *scale_ptr : 1;
#pragma unroll #pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) { for (unsigned int i = 0; i < nvec_out; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
act_in[0][i].load_from(my_act_input_tile, current_stride2 + my_place + stride2 * i); act_in[0][i].load_from(my_act_input_tile, current_stride2 + my_place + stride2 * i);
gate_in[0][i].load_from(my_gate_input_tile, current_stride2 + my_place + stride2 * i); gate_in[0][i].load_from(my_gate_input_tile, current_stride2 + my_place + stride2 * i);
} }
#pragma unroll #pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) { for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride2 + my_place; const size_t current_place = current_stride2 + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2; const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) { if (i < n_iterations - 1) {
#pragma unroll #pragma unroll
       for (unsigned int j = 0; j < nvec_out; ++j) {
         in[current_in][j].load_from(my_input_tile,
                                     current_stride + my_place_in + stride * (nvec_out + j));
@@ -1075,9 +810,9 @@ dgated_act_cast_transpose_kernel(const IType * const input,
       }
       CVec after_dact[nvec_out];   // NOLINT(*)
       CVec after_dgate[nvec_out];  // NOLINT(*)
 #pragma unroll
       for (unsigned int j = 0; j < nvec_out; ++j) {
 #pragma unroll
         for (unsigned int k = 0; k < nvec_in; ++k) {
           after_dact[j].data.elt[k] = OP1(act_in[current_in ^ 1][j].data.elt[k], {}) *
                                       CType(in[current_in ^ 1][j].data.elt[k]) *
@@ -1092,7 +827,7 @@ dgated_act_cast_transpose_kernel(const IType * const input,
       OVec out_trans_1[nvec_in];  // NOLINT(*)
       cast_and_transpose_regs<true>(after_dgate, out_trans_1, my_output_c_tile_1,
                                     current_place, stride2, max, scale, true);
 #pragma unroll
       for (unsigned int j = 0; j < nvec_in; ++j) {
         out_space_0[i][j].data.vec = out_trans_0[j].data.vec;
         out_space_1[i][j].data.vec = out_trans_1[j].data.vec;
@@ -1103,7 +838,7 @@ dgated_act_cast_transpose_kernel(const IType * const input,
   }
   for (unsigned int i = 0; i < nvec_in; ++i) {
 #pragma unroll
     for (unsigned int j = 0; j < n_iterations; ++j) {
       my_scratch[(my_id_in_warp + THREADS_PER_WARP -
                   j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space_0[j][i];
@@ -1120,7 +855,7 @@ dgated_act_cast_transpose_kernel(const IType * const input,
     current_stride += output_stride * nvec_in;
   }
   __syncthreads();
 #pragma unroll
   for (unsigned int j = 0; j < n_iterations; ++j) {
     my_scratch[(my_id_in_warp + THREADS_PER_WARP -
                 j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space_1[j][i];
@@ -1144,8 +879,12 @@ dgated_act_cast_transpose_kernel(const IType * const input,
   if (threadIdx.x == 0) {
     static_assert(std::is_same<CType, float>::value);
-    if (amax != nullptr) atomicMaxFloat(amax, max);
-    if (scale_inv != nullptr) reciprocal<float>(scale_inv, scale);
+    if (amax != nullptr) {
+      atomicMaxFloat(amax, max);
+    }
+    if (scale_inv != nullptr) {
+      reciprocal<float>(scale_inv, scale);
+    }
   }
 }
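Both kernel epilogues reduce the per-thread maximum into the global amax with a float atomic max and then store the reciprocal of the scale. For reference, a minimal sketch of how such a float atomic max is commonly built (an illustration under assumptions, not necessarily the library's atomicMaxFloat):

// Sketch: float atomic max for the non-negative amax values produced above.
// For non-negative IEEE-754 floats, the integer bit patterns sort in the same
// order as the float values, so an integer atomicMax is sufficient.
__device__ inline void atomic_max_float_sketch(float *addr, float value) {
  atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));
}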
@@ -1235,12 +974,15 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
   {
     const bool valid_load = my_place < tile_length &&
                             warp_id_in_tile * n_iterations < tile_height;
 #pragma unroll
     for (unsigned int i = 0; i < nvec_out; ++i) {
       if (valid_load) {
-        in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
-        act_in[0][i].load_from(my_act_input_tile, current_stride2 + my_place + stride2 * i);
-        gate_in[0][i].load_from(my_gate_input_tile, current_stride2 + my_place + stride2 * i);
+        in[0][i].load_from(my_input_tile,
+                           current_stride + my_place + stride * i);
+        act_in[0][i].load_from(my_act_input_tile,
+                               current_stride2 + my_place + stride2 * i);
+        gate_in[0][i].load_from(my_gate_input_tile,
+                                current_stride2 + my_place + stride2 * i);
       } else {
         in[0][i].clear();
         act_in[0][i].clear();
@@ -1248,7 +990,7 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
       }
     }
   }
 #pragma unroll
   for (unsigned int i = 0; i < n_iterations; ++i) {
     const size_t current_place = current_stride2 + my_place;
     const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
@@ -1257,7 +999,7 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
     {
       const bool valid_load = my_place_in < tile_length &&
                               warp_id_in_tile * n_iterations + i + 1 < tile_height;
 #pragma unroll
       for (unsigned int j = 0; j < nvec_out; ++j) {
         if (valid_load) {
           in[current_in][j].load_from(my_input_tile,
@@ -1276,9 +1018,9 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
       }
       CVec after_dact[nvec_out];   // NOLINT(*)
       CVec after_dgate[nvec_out];  // NOLINT(*)
 #pragma unroll
       for (unsigned int j = 0; j < nvec_out; ++j) {
 #pragma unroll
         for (unsigned int k = 0; k < nvec_in; ++k) {
           after_dact[j].data.elt[k] = OP1(act_in[current_in ^ 1][j].data.elt[k], {}) *
                                       CType(in[current_in ^ 1][j].data.elt[k]) *
@@ -1295,7 +1037,7 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
                                      current_place, stride2, max, scale, valid_store);
       cast_and_transpose_regs<false>(after_dgate, out_trans_1, my_output_c_tile_1,
                                      current_place, stride2, max, scale, valid_store);
 #pragma unroll
       for (unsigned int j = 0; j < nvec_in; ++j) {
         out_space_0[i][j].data.vec = out_trans_0[j].data.vec;
         out_space_1[i][j].data.vec = out_trans_1[j].data.vec;
@@ -1306,7 +1048,7 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
   }
   for (unsigned int i = 0; i < nvec_in; ++i) {
 #pragma unroll
     for (unsigned int j = 0; j < n_iterations; ++j) {
       my_scratch[(my_id_in_warp + THREADS_PER_WARP -
                   j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space_0[j][i];
@@ -1326,7 +1068,7 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
     current_stride += output_stride * nvec_in;
   }
   __syncthreads();
 #pragma unroll
   for (unsigned int j = 0; j < n_iterations; ++j) {
     my_scratch[(my_id_in_warp + THREADS_PER_WARP -
                 j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space_1[j][i];
@@ -1353,124 +1095,13 @@ dgated_act_cast_transpose_kernel_notaligned(const IType * const input,
   if (threadIdx.x == 0) {
     static_assert(std::is_same<CType, float>::value);
-    if (amax != nullptr) atomicMaxFloat(amax, max);
-    if (scale_inv != nullptr) reciprocal<float>(scale_inv, scale);
+    if (amax != nullptr) {
+      atomicMaxFloat(amax, max);
+    }
+    if (scale_inv != nullptr) {
+      reciprocal<float>(scale_inv, scale);
+    }
   }
 }
 
-template <typename ComputeType, typename ParamOP,
-          ComputeType (*OP)(ComputeType, const ParamOP&)>
-void cast_transpose_dbias_dact(const Tensor &input,
-                               const Tensor &act_input,
-                               Tensor *cast_output,
-                               Tensor *transposed_output,
-                               Tensor *dbias,
-                               Tensor *workspace,
-                               cudaStream_t stream) {
-  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
-  NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
-  NVTE_CHECK(transposed_output->data.shape.size() == 2,
-             "T output must have 2 dimensions.");
-  NVTE_CHECK(input.data.shape == cast_output->data.shape,
-             "Input and C output must have the same shape.");
-  const size_t row_length = input.data.shape[1];
-  const size_t num_rows = input.data.shape[0];
-  NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output.");
-  NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
-  NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
-             "C and T outputs need to have the same type.");
-  NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
-             "C and T outputs need to share amax tensor.");
-  NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
-             "C and T outputs need to share scale tensor.");
-  NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input.");
-  NVTE_CHECK(dbias->data.shape == std::vector<size_t>{ row_length }, "Wrong shape of DBias.");
-  NVTE_CHECK(input.data.dtype == act_input.data.dtype, "Types of both inputs must match.");
-  NVTE_CHECK(input.data.shape == act_input.data.shape, "Shapes of both inputs must match.");
-  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
-    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
-      using InputType2 = InputType;
-      /* dact fusion kernel uses more registers */
-      constexpr int desired_load_size_dact = 4;
-      constexpr int desired_store_size_dact = 4;
-      constexpr int itype_size = sizeof(InputType);
-      constexpr int otype_size = sizeof(OutputType);
-      constexpr int nvec_in = desired_load_size_dact / itype_size;
-      constexpr int nvec_out = desired_store_size_dact / otype_size;
-      if (workspace->data.dptr == nullptr) {
-        populate_cast_transpose_dbias_workspace_config(*cast_output, workspace, nvec_out);
-        return;
-      }
-      CheckInputTensor(input, "cast_transpose_dbias_dact_input");
-      CheckInputTensor(act_input, "act_input");
-      CheckOutputTensor(*cast_output, "cast_output");
-      CheckOutputTensor(*transposed_output, "transposed_output");
-      CheckOutputTensor(*dbias, "dbias");
-      NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
-      NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
-      const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
-                             DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
-      const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
-      const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
-      const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
-                             num_rows % (nvec_out * THREADS_PER_WARP) == 0;
-      // using ComputeType = fp32;
-      constexpr size_t shared_size_transpose = cast_transpose_num_threads / n_warps_per_tile *
-                                               (THREADS_PER_WARP + 1) *
-                                               sizeof(Vec<OutputType, nvec_out>);
-      constexpr size_t shared_size_dbias = cast_transpose_num_threads *
-                                           sizeof(Vec<ComputeType, nvec_in>);
-      static_assert(shared_size_transpose >= shared_size_dbias);
-      using Param = CTDBiasDGeluParam<InputType, InputType2, OutputType, ComputeType>;
-      Param param;
-      param.input = reinterpret_cast<const InputType *>(input.data.dptr);
-      param.act_input = reinterpret_cast<const InputType2 *>(act_input.data.dptr);
-      param.output_c = reinterpret_cast<OutputType *>(cast_output->data.dptr);
-      param.output_t = reinterpret_cast<OutputType *>(transposed_output->data.dptr);
-      param.scale_ptr = reinterpret_cast<const ComputeType *>(cast_output->scale.dptr);
-      param.amax = reinterpret_cast<ComputeType *>(cast_output->amax.dptr);
-      param.workspace = reinterpret_cast<ComputeType *>(workspace->data.dptr);
-      if (full_tile) {
-        cudaFuncSetAttribute(
-            cast_transpose_dbias_dact_kernel<ComputeType, Empty,
-                                             nvec_in, nvec_out, Param, OP>,
-            cudaFuncAttributePreferredSharedMemoryCarveout,
-            100);
-        cast_transpose_dbias_dact_kernel<ComputeType, Empty,
-                                         nvec_in, nvec_out, Param, OP>
-            <<<n_blocks,
-               cast_transpose_num_threads,
-               shared_size_transpose,
-               stream>>>(param, row_length, num_rows, n_tiles);
-      } else {
-        cudaFuncSetAttribute(cast_transpose_dbias_dact_kernel_notaligned<
-                                 ComputeType, Empty,
-                                 nvec_in, nvec_out, Param, OP>,
-                             cudaFuncAttributePreferredSharedMemoryCarveout,
-                             100);
-        cast_transpose_dbias_dact_kernel_notaligned<
-            ComputeType, Empty,
-            nvec_in, nvec_out, Param, OP>
-            <<<n_blocks,
-               cast_transpose_num_threads,
-               shared_size_transpose,
-               stream>>>(param, row_length, num_rows, n_tiles);
-      }
-      reduce_dbias<InputType>(*workspace, dbias, row_length, num_rows, nvec_out, stream);
-    );  // NOLINT(*)
-  );  // NOLINT(*)
-}
 
 template <typename ComputeType, typename ParamOP,
@@ -1489,8 +1120,7 @@ void dgated_act_cast_transpose(const Tensor &input,
   NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
   NVTE_CHECK(gated_act_input.data.shape.size() == 2, "Input must have 2 dimensions.");
   NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
-  NVTE_CHECK(transposed_output->data.shape.size() == 2,
-             "T output must have 2 dimensions.");
+  NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions.");
   const size_t row_length = input.data.shape[1];
   const size_t num_rows = input.data.shape[0];
@@ -1525,27 +1155,26 @@ void dgated_act_cast_transpose(const Tensor &input,
   NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
   NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
-  const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
+  const size_t n_tiles =
+      DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
       DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
   const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
   const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
   const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
                          num_rows % (nvec_out * THREADS_PER_WARP) == 0;
+  const size_t shmem_size = cast_transpose_num_threads / n_warps_per_tile *
+                            (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>);
   if (full_tile) {
-    cudaFuncSetAttribute(dgated_act_cast_transpose_kernel<
-                             nvec_in, nvec_out,
-                             ComputeType, InputType, OutputType,
-                             Empty, OP1, OP2>,
+    cudaFuncSetAttribute(
+        dgated_act_cast_transpose_kernel
+            <nvec_in, nvec_out, ComputeType, InputType, OutputType, Empty, OP1, OP2>,
         cudaFuncAttributePreferredSharedMemoryCarveout,
         100);
-    dgated_act_cast_transpose_kernel< nvec_in, nvec_out,
-                                      ComputeType, InputType, OutputType, Empty, OP1, OP2>
-        <<<n_blocks,
-           cast_transpose_num_threads,
-           cast_transpose_num_threads / n_warps_per_tile *
-           (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>),
-           stream>>>(
+    dgated_act_cast_transpose_kernel
+        <nvec_in, nvec_out, ComputeType, InputType, OutputType, Empty, OP1, OP2>
+        <<<n_blocks, cast_transpose_num_threads, shmem_size, stream>>>(
         reinterpret_cast<const InputType *>(input.data.dptr),
         reinterpret_cast<const InputType *>(gated_act_input.data.dptr),
         reinterpret_cast<OutputType *>(cast_output->data.dptr),
@@ -1555,19 +1184,14 @@ void dgated_act_cast_transpose(const Tensor &input,
         reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr),
         row_length, num_rows, n_tiles);
   } else {
-    cudaFuncSetAttribute(dgated_act_cast_transpose_kernel_notaligned<
-                             nvec_in, nvec_out,
-                             ComputeType, InputType, OutputType,
-                             Empty, OP1, OP2>,
+    cudaFuncSetAttribute(
+        dgated_act_cast_transpose_kernel_notaligned
+            <nvec_in, nvec_out, ComputeType, InputType, OutputType, Empty, OP1, OP2>,
         cudaFuncAttributePreferredSharedMemoryCarveout,
         100);
-    dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out,
-                                                ComputeType, InputType, OutputType, Empty, OP1, OP2>
-        <<<n_blocks,
-           cast_transpose_num_threads,
-           cast_transpose_num_threads / n_warps_per_tile *
-           (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>),
-           stream>>>(
+    dgated_act_cast_transpose_kernel_notaligned
+        <nvec_in, nvec_out, ComputeType, InputType, OutputType, Empty, OP1, OP2>
+        <<<n_blocks, cast_transpose_num_threads, shmem_size, stream>>>(
         reinterpret_cast<const InputType *>(input.data.dptr),
         reinterpret_cast<const InputType *>(gated_act_input.data.dptr),
         reinterpret_cast<OutputType *>(cast_output->data.dptr),
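To make the launch geometry above concrete, here is a worked example under assumed (hypothetical) parameters: fp16 input with 4-byte loads giving nvec_in = 2, fp8 output with 4-byte stores giving nvec_out = 4, THREADS_PER_WARP = 32, n_warps_per_tile = 8, and 256 threads per block:

#include <cstddef>
constexpr size_t divup(size_t a, size_t b) { return (a + b - 1) / b; }
constexpr size_t row_length = 1024, num_rows = 2048;
// Each warp tile covers (nvec_in * 32) columns by (nvec_out * 32) rows.
constexpr size_t n_tiles = divup(row_length, 2 * 32) *  // 16 column tiles
                           divup(num_rows, 4 * 32);     // 16 row tiles -> 256
// 256 threads per block = 8 warps; each tile is processed by 8 warps.
constexpr size_t n_blocks = divup(n_tiles * 8, 256 / 32);  // -> 256 blocks
static_assert(n_tiles == 256 && n_blocks == 256, "launch-geometry check");

Both dimensions divide their tile sizes exactly in this example, so full_tile is true and the aligned kernel is selected.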
@@ -1591,7 +1215,15 @@ void nvte_cast_transpose_dbias(const NVTETensor input,
                                cudaStream_t stream) {
   NVTE_API_CALL(nvte_cast_transpose_dbias);
   using namespace transformer_engine;
-  cast_transpose_dbias(*reinterpret_cast<const Tensor*>(input),
+
+  constexpr bool IS_DBIAS = true;
+  constexpr bool IS_DACT = false;
+  constexpr const NVTETensor activation_input = nullptr;
+
+  cast_transpose_fused<IS_DBIAS, IS_DACT, fp32, Empty, nullptr>(
+      *reinterpret_cast<const Tensor*>(input),
+      *reinterpret_cast<const Tensor*>(activation_input),
       reinterpret_cast<Tensor*>(cast_output),
       reinterpret_cast<Tensor*>(transposed_output),
       reinterpret_cast<Tensor*>(dbias),
@@ -1608,7 +1240,13 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
                                      cudaStream_t stream) {
   NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu);
   using namespace transformer_engine;
-  cast_transpose_dbias_dact<fp32, Empty, dgelu<fp32, fp32>>(
+
+  constexpr bool IS_DBIAS = true;
+  constexpr bool IS_DACT = true;
+  constexpr auto dActivation = &dgelu<fp32, fp32>;
+
+  cast_transpose_fused<IS_DBIAS, IS_DACT, fp32, Empty, dActivation>(
       *reinterpret_cast<const Tensor*>(input),
       *reinterpret_cast<const Tensor*>(act_input),
       reinterpret_cast<Tensor*>(cast_output),
@@ -1618,33 +1256,49 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
       stream);
 }
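As the deleted launcher above shows (the workspace->data.dptr == nullptr branch), the dbias fusions use a two-phase workspace protocol, which the new cast_transpose_fused path presumably preserves: a first call with an unallocated workspace only records the required workspace shape and returns. A hedged host-side sketch of that calling pattern, with tensor creation and allocation elided (assumed usage, not a documented API contract):

// Usage sketch (assumed calling pattern; tensor setup elided):
void run_dbias_dgelu(NVTETensor input, NVTETensor act_input,
                     NVTETensor cast_out, NVTETensor t_out,
                     NVTETensor dbias, NVTETensor workspace,
                     cudaStream_t stream) {
  // Pass 1: workspace data pointer is nullptr, so this only populates the
  // workspace tensor's shape/dtype and returns without launching work.
  nvte_cast_transpose_dbias_dgelu(input, act_input, cast_out, t_out,
                                  dbias, workspace, stream);
  // ... allocate workspace storage of the reported size, set its pointer ...
  // Pass 2: the fused cast + transpose + dbias + dgelu actually runs.
  nvte_cast_transpose_dbias_dgelu(input, act_input, cast_out, t_out,
                                  dbias, workspace, stream);
}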
-void nvte_dgeglu_cast_transpose(const NVTETensor input,
-                                const NVTETensor gated_act_input,
+void nvte_cast_transpose_dbias_dsilu(const NVTETensor input,
+                                     const NVTETensor silu_input,
                                      NVTETensor cast_output,
                                      NVTETensor transposed_output,
+                                     NVTETensor dbias,
+                                     NVTETensor workspace,
                                      cudaStream_t stream) {
-  NVTE_API_CALL(nvte_dgeglu_cast_transpose);
+  NVTE_API_CALL(nvte_cast_transpose_dbias_dsilu);
   using namespace transformer_engine;
-  dgated_act_cast_transpose<fp32, Empty, dgelu<fp32, fp32>, gelu<fp32, fp32>>(
+
+  constexpr bool IS_DBIAS = true;
+  constexpr bool IS_DACT = true;
+  constexpr auto dActivation = &dsilu<fp32, fp32>;
+
+  cast_transpose_fused<IS_DBIAS, IS_DACT, fp32, Empty, dActivation>(
       *reinterpret_cast<const Tensor*>(input),
-      *reinterpret_cast<const Tensor*>(gated_act_input),
+      *reinterpret_cast<const Tensor*>(silu_input),
       reinterpret_cast<Tensor*>(cast_output),
       reinterpret_cast<Tensor*>(transposed_output),
+      reinterpret_cast<Tensor*>(dbias),
+      reinterpret_cast<Tensor*>(workspace),
       stream);
 }
 
-void nvte_cast_transpose_dbias_dsilu(const NVTETensor input,
-                                     const NVTETensor silu_input,
+void nvte_cast_transpose_dbias_drelu(const NVTETensor input,
+                                     const NVTETensor relu_input,
                                      NVTETensor cast_output,
                                      NVTETensor transposed_output,
                                      NVTETensor dbias,
                                      NVTETensor workspace,
                                      cudaStream_t stream) {
-  NVTE_API_CALL(nvte_cast_transpose_dbias_dsilu);
+  NVTE_API_CALL(nvte_cast_transpose_dbias_drelu);
   using namespace transformer_engine;
-  cast_transpose_dbias_dact<fp32, Empty, dsilu<fp32, fp32>>(
+
+  constexpr bool IS_DBIAS = true;
+  constexpr bool IS_DACT = true;
+  constexpr auto dActivation = &drelu<fp32, fp32>;
+
+  cast_transpose_fused<IS_DBIAS, IS_DACT, fp32, Empty, dActivation>(
       *reinterpret_cast<const Tensor*>(input),
-      *reinterpret_cast<const Tensor*>(silu_input),
+      *reinterpret_cast<const Tensor*>(relu_input),
       reinterpret_cast<Tensor*>(cast_output),
       reinterpret_cast<Tensor*>(transposed_output),
       reinterpret_cast<Tensor*>(dbias),
@@ -1652,33 +1306,49 @@ void nvte_cast_transpose_dbias_dsilu(const NVTETensor input,
       stream);
 }
-void nvte_dswiglu_cast_transpose(const NVTETensor input,
-                                 const NVTETensor swiglu_input,
+void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input,
+                                      const NVTETensor srelu_input,
                                       NVTETensor cast_output,
                                       NVTETensor transposed_output,
+                                      NVTETensor dbias,
+                                      NVTETensor workspace,
                                       cudaStream_t stream) {
-  NVTE_API_CALL(nvte_dswiglu_cast_transpose);
+  NVTE_API_CALL(nvte_cast_transpose_dbias_dsrelu);
   using namespace transformer_engine;
-  dgated_act_cast_transpose<fp32, Empty, dsilu<fp32, fp32>, silu<fp32, fp32>>(
+
+  constexpr bool IS_DBIAS = true;
+  constexpr bool IS_DACT = true;
+  constexpr auto dActivation = &dsrelu<fp32, fp32>;
+
+  cast_transpose_fused<IS_DBIAS, IS_DACT, fp32, Empty, dActivation>(
       *reinterpret_cast<const Tensor*>(input),
-      *reinterpret_cast<const Tensor*>(swiglu_input),
+      *reinterpret_cast<const Tensor*>(srelu_input),
       reinterpret_cast<Tensor*>(cast_output),
       reinterpret_cast<Tensor*>(transposed_output),
+      reinterpret_cast<Tensor*>(dbias),
+      reinterpret_cast<Tensor*>(workspace),
       stream);
 }
 
-void nvte_cast_transpose_dbias_drelu(const NVTETensor input,
-                                     const NVTETensor relu_input,
+void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input,
+                                      const NVTETensor qgelu_input,
                                       NVTETensor cast_output,
                                       NVTETensor transposed_output,
                                       NVTETensor dbias,
                                       NVTETensor workspace,
                                       cudaStream_t stream) {
-  NVTE_API_CALL(nvte_cast_transpose_dbias_drelu);
+  NVTE_API_CALL(nvte_cast_transpose_dbias_dqgelu);
   using namespace transformer_engine;
-  cast_transpose_dbias_dact<fp32, Empty, drelu<fp32, fp32>>(
+
+  constexpr bool IS_DBIAS = true;
+  constexpr bool IS_DACT = true;
+  constexpr auto dActivation = &dqgelu<fp32, fp32>;
+
+  cast_transpose_fused<IS_DBIAS, IS_DACT, fp32, Empty, dActivation>(
       *reinterpret_cast<const Tensor*>(input),
-      *reinterpret_cast<const Tensor*>(relu_input),
+      *reinterpret_cast<const Tensor*>(qgelu_input),
       reinterpret_cast<Tensor*>(cast_output),
       reinterpret_cast<Tensor*>(transposed_output),
       reinterpret_cast<Tensor*>(dbias),
@@ -1686,14 +1356,18 @@ void nvte_cast_transpose_dbias_drelu(const NVTETensor input,
      stream);
 }
       stream);
 }
 
-void nvte_dreglu_cast_transpose(const NVTETensor input,
+void nvte_dgeglu_cast_transpose(const NVTETensor input,
                                 const NVTETensor gated_act_input,
                                 NVTETensor cast_output,
                                 NVTETensor transposed_output,
                                 cudaStream_t stream) {
-  NVTE_API_CALL(nvte_dreglu_cast_transpose);
+  NVTE_API_CALL(nvte_dgeglu_cast_transpose);
   using namespace transformer_engine;
-  dgated_act_cast_transpose<fp32, Empty, drelu<fp32, fp32>, relu<fp32, fp32>>(
+
+  constexpr auto dActivation = &dgelu<fp32, fp32>;
+  constexpr auto Activation = &gelu<fp32, fp32>;
+
+  dgated_act_cast_transpose<fp32, Empty, dActivation, Activation>(
       *reinterpret_cast<const Tensor*>(input),
       *reinterpret_cast<const Tensor*>(gated_act_input),
       reinterpret_cast<Tensor*>(cast_output),
@@ -1701,33 +1375,37 @@ void nvte_dreglu_cast_transpose(const NVTETensor input,
       stream);
 }
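Per element, the gated kernels compute two gradients from one incoming gradient: the first output half applies the activation derivative and multiplies by the gate, the second applies the activation itself (see the after_dact and after_dgate expressions in the kernels above). A reference sketch instantiated for SwiGLU, under the assumption OP1 = dsilu and OP2 = silu:

#include <cmath>
// silu(x) = x * sigmoid(x); dsilu is its derivative.
inline float sigmoid_ref(float x) { return 1.0f / (1.0f + std::exp(-x)); }
inline float silu_ref(float x) { return x * sigmoid_ref(x); }
inline float dsilu_ref(float x) {
  const float s = sigmoid_ref(x);
  return s + x * s * (1.0f - s);  // d/dx [x * sigmoid(x)]
}
// Gradient w.r.t. the activation half (mirrors after_dact above).
inline float grad_act_ref(float grad, float act, float gate) {
  return dsilu_ref(act) * grad * gate;
}
// Gradient w.r.t. the gate half (mirrors after_dgate above).
inline float grad_gate_ref(float grad, float act) {
  return silu_ref(act) * grad;
}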
-void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input,
-                                      const NVTETensor srelu_input,
+void nvte_dswiglu_cast_transpose(const NVTETensor input,
+                                 const NVTETensor swiglu_input,
                                  NVTETensor cast_output,
                                  NVTETensor transposed_output,
-                                 NVTETensor dbias,
-                                 NVTETensor workspace,
                                  cudaStream_t stream) {
-  NVTE_API_CALL(nvte_cast_transpose_dbias_dsrelu);
+  NVTE_API_CALL(nvte_dswiglu_cast_transpose);
   using namespace transformer_engine;
-  cast_transpose_dbias_dact<fp32, Empty, dsrelu<fp32, fp32>>(
+
+  constexpr auto dActivation = &dsilu<fp32, fp32>;
+  constexpr auto Activation = &silu<fp32, fp32>;
+
+  dgated_act_cast_transpose<fp32, Empty, dActivation, Activation>(
       *reinterpret_cast<const Tensor*>(input),
-      *reinterpret_cast<const Tensor*>(srelu_input),
+      *reinterpret_cast<const Tensor*>(swiglu_input),
       reinterpret_cast<Tensor*>(cast_output),
       reinterpret_cast<Tensor*>(transposed_output),
-      reinterpret_cast<Tensor*>(dbias),
-      reinterpret_cast<Tensor*>(workspace),
       stream);
 }
 
-void nvte_dsreglu_cast_transpose(const NVTETensor input,
+void nvte_dreglu_cast_transpose(const NVTETensor input,
                                 const NVTETensor gated_act_input,
                                 NVTETensor cast_output,
                                 NVTETensor transposed_output,
                                 cudaStream_t stream) {
-  NVTE_API_CALL(nvte_dsreglu_cast_transpose);
+  NVTE_API_CALL(nvte_dreglu_cast_transpose);
   using namespace transformer_engine;
-  dgated_act_cast_transpose<fp32, Empty, dsrelu<fp32, fp32>, srelu<fp32, fp32>>(
+
+  constexpr auto dActivation = &drelu<fp32, fp32>;
+  constexpr auto Activation = &relu<fp32, fp32>;
+
+  dgated_act_cast_transpose<fp32, Empty, dActivation, Activation>(
       *reinterpret_cast<const Tensor*>(input),
       *reinterpret_cast<const Tensor*>(gated_act_input),
       reinterpret_cast<Tensor*>(cast_output),
@@ -1735,22 +1413,22 @@ void nvte_dsreglu_cast_transpose(const NVTETensor input,
       stream);
 }
 
-void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input,
-                                      const NVTETensor qgelu_input,
+void nvte_dsreglu_cast_transpose(const NVTETensor input,
+                                 const NVTETensor gated_act_input,
                                  NVTETensor cast_output,
                                  NVTETensor transposed_output,
-                                 NVTETensor dbias,
-                                 NVTETensor workspace,
                                  cudaStream_t stream) {
-  NVTE_API_CALL(nvte_cast_transpose_dbias_dqgelu);
+  NVTE_API_CALL(nvte_dsreglu_cast_transpose);
   using namespace transformer_engine;
-  cast_transpose_dbias_dact<fp32, Empty, dqgelu<fp32, fp32>>(
+
+  constexpr auto dActivation = &dsrelu<fp32, fp32>;
+  constexpr auto Activation = &srelu<fp32, fp32>;
+
+  dgated_act_cast_transpose<fp32, Empty, dActivation, Activation>(
      *reinterpret_cast<const Tensor*>(input),
-      *reinterpret_cast<const Tensor*>(qgelu_input),
+      *reinterpret_cast<const Tensor*>(gated_act_input),
      reinterpret_cast<Tensor*>(cast_output),
      reinterpret_cast<Tensor*>(transposed_output),
-      reinterpret_cast<Tensor*>(dbias),
-      reinterpret_cast<Tensor*>(workspace),
      stream);
 }
@@ -1761,7 +1439,11 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input,
                                  cudaStream_t stream) {
   NVTE_API_CALL(nvte_dqgeglu_cast_transpose);
   using namespace transformer_engine;
-  dgated_act_cast_transpose<fp32, Empty, dqgelu<fp32, fp32>, qgelu<fp32, fp32>>(
+
+  constexpr auto dActivation = &dqgelu<fp32, fp32>;
+  constexpr auto Activation = &qgelu<fp32, fp32>;
+
+  dgated_act_cast_transpose<fp32, Empty, dActivation, Activation>(
       *reinterpret_cast<const Tensor*>(input),
       *reinterpret_cast<const Tensor*>(gated_act_input),
       reinterpret_cast<Tensor*>(cast_output),
...
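For reference, the gated entry points in this file differ only in the (dActivation, Activation) pair they pass to dgated_act_cast_transpose, as compiled from the calls above:

  API entry point                dActivation          Activation
  nvte_dgeglu_cast_transpose     dgelu<fp32, fp32>    gelu<fp32, fp32>
  nvte_dswiglu_cast_transpose    dsilu<fp32, fp32>    silu<fp32, fp32>
  nvte_dreglu_cast_transpose     drelu<fp32, fp32>    relu<fp32, fp32>
  nvte_dsreglu_cast_transpose    dsrelu<fp32, fp32>   srelu<fp32, fp32>
  nvte_dqgeglu_cast_transpose    dqgelu<fp32, fp32>   qgelu<fp32, fp32>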