Unverified Commit eed1fa26 authored by zlsh80826, committed by GitHub

Add GeGLU and the corresponding gradient kernels (#47)



* Add GeGLU and DGeGLU
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add DGeGLUCT
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update copyright year
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine shape check
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Code refine
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 37cc3625
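For reference, the math implemented by this change, in my own notation rather than the patch's (the kernels use the tanh approximation of GELU; the [N, 2H] input x splits into an activation half a and a gate half b):

\mathrm{GeGLU}(x) = \mathrm{GELU}(a) \odot b, \qquad a = x_{[:,\,:H]}, \quad b = x_{[:,\,H:]}
\mathrm{GELU}(a) \approx \tfrac{1}{2}\,a\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(a + 0.044715\,a^{3}\bigr)\right)\right)
\frac{\partial L}{\partial a} = g \odot b \odot \mathrm{GELU}'(a), \qquad \frac{\partial L}{\partial b} = g \odot \mathrm{GELU}(a), \quad g \text{ being the incoming } [N, H] \text{ gradient.}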
@@ -8,7 +8,10 @@ add_executable(test_operator
test_transpose.cu
test_cast_transpose_dbias.cu
test_cast_transpose_dbias_dgelu.cu
test_cast_transpose_dgeglu.cu
test_gelu.cu
test_geglu.cu
test_dgeglu.cu
test_layernorm.cu
test_rmsnorm.cu
test_multi_cast_transpose.cu
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/logging.h>
#include <transformer_engine/transpose.h>
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename CType, typename IType>
inline CType gelu(const IType val) {
CType cval = val;
return cval * (0.5f + 0.5f * tanhf(cval * (0.79788456f + 0.03567741f * cval * cval)));
}
template <typename CType, typename IType>
inline CType dgelu(const IType val) {
CType cval = val;
const CType tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
return 0.5f * cval * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * cval * cval)) +
0.5f * (1.f + tanh_out);
}
template <typename IT, typename OT, typename CT>
void compute_ref_cast_transpose_dgated_gelu(const IT *grad_h, const IT *input_h, const CT scale,
OT *output_c_h, OT *output_t_h, CT *amax_h,
const size_t N, const size_t H) {
CT amax = 0.;
const size_t col = H * 2;
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT grad_elt = CT(grad_h[i * H + j]);
CT gelu_elt = CT(input_h[i * col + j]);
CT gate_elt = CT(input_h[i * col + H + j]);
CT after_dgelu = dgelu<CT, CT>(gelu_elt) * grad_elt * gate_elt;
CT after_dgate = grad_elt * gelu<CT, CT>(gelu_elt);
amax = std::abs(after_dgelu) > amax ? std::abs(after_dgelu) : amax;
amax = std::abs(after_dgate) > amax ? std::abs(after_dgate) : amax;
output_c_h[i * col + j] = static_cast<OT>(scale * after_dgelu);
output_c_h[i * col + H + j] = static_cast<OT>(scale * after_dgate);
output_t_h[j * N + i] = static_cast<OT>(scale * after_dgelu);
output_t_h[(j + H) * N + i] = static_cast<OT>(scale * after_dgate);
}
}
*amax_h = amax;
}
template <typename IType, typename OType>
void performTest(const size_t N, const size_t H) {
using namespace test;
using CType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
Tensor grad({N, H}, itype);
Tensor input({N, H * 2}, itype);
Tensor output_c({N, H * 2}, otype);
Tensor output_t({H * 2, N}, otype);
fillUniform(&grad);
fillUniform(&input);
setRandomScale(&output_c);
output_t.shareFP8Meta(output_c);
std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N * H * 2);
std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N * H * 2);
nvte_dgeglu_cast_transpose(grad.data(), input.data(), output_c.data(), output_t.data(), 0);
CType ref_amax;
compute_ref_cast_transpose_dgated_gelu(grad.cpu_dptr<IType>(), input.cpu_dptr<IType>(),
output_c.scale(), ref_output_c.get(), ref_output_t.get(),
&ref_amax, N, H);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output_c.scale();
compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
compareResults("output_t", output_t, ref_output_t.get(), atol, rtol);
}
std::vector<std::pair<size_t, size_t>> test_cases = {{64, 400}, {4096, 2048}, {768, 2816},
{256, 5120}, {128, 10240}, {256, 256}};
} // namespace
class DGeGLUCTTestSuite
: public ::testing::TestWithParam<std::tuple<
transformer_engine::DType, transformer_engine::DType, std::pair<size_t, size_t>>> {};
TEST_P(DGeGLUCTTestSuite, TestDGeGLUCT) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
output_type, OutputType, performTest<InputType, OutputType>(size.first, size.second);););
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest, DGeGLUCTTestSuite,
::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<DGeGLUCTTestSuite::ParamType> &info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param).first) + "X" +
std::to_string(std::get<2>(info.param).second);
return name;
});
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/logging.h>
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <type_traits>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename CType, typename IType>
inline CType gelu(const IType val) {
CType cval = val;
return cval * (0.5f + 0.5f * tanhf(cval * (0.79788456f + 0.03567741f * cval * cval)));
}
template <typename CType, typename IType>
inline CType dgelu(const IType val) {
CType cval = val;
const CType tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
return 0.5f * cval * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * cval * cval)) +
0.5f * (1.f + tanh_out);
}
template <typename IT, typename OT, typename CT>
void compute_ref_dgeglu(const IT *grad_h, const IT *input_h, OT *output_h, const size_t N,
const size_t H) {
const size_t col = H * 2;
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT grad_elt = CT(grad_h[i * H + j]);
CT gelu_elt = CT(input_h[i * col + j]);
CT gate_elt = CT(input_h[i * col + H + j]);
CT after_dgelu = dgelu<CT, CT>(gelu_elt) * grad_elt * gate_elt;
CT after_dgate = grad_elt * gelu<CT, CT>(gelu_elt);
output_h[i * col + j] = OT(after_dgelu);
output_h[i * col + H + j] = OT(after_dgate);
}
}
}
template <typename IType, typename OType>
void performTestDGeGLU(const size_t N, const size_t H) {
using namespace test;
using CType = fp32;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
Tensor grad({N, H}, itype);
Tensor input({N, H * 2}, itype);
Tensor output({N, H * 2}, otype);
fillUniform(&grad);
fillUniform(&input);
std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N * H * 2);
nvte_dgeglu(grad.data(), input.data(), output.data(), 0);
compute_ref_dgeglu<IType, OType, CType>(grad.cpu_dptr<IType>(), input.cpu_dptr<IType>(),
ref_output.get(), N, H);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol, rtol] = getTolerances(otype);
compareResults("output_dgelu", output, ref_output.get(), atol, rtol);
}
std::vector<std::pair<size_t, size_t>> test_cases = {
{4096, 2048}, {768, 2816}, {256, 5120}, {128, 10240}, {256, 256}, {257, 259}, {128, 128 + 1}};
} // namespace
class DGeGLUTestSuite
: public ::testing::TestWithParam<std::tuple<
transformer_engine::DType, transformer_engine::DType, std::pair<size_t, size_t>>> {};
TEST_P(DGeGLUTestSuite, TestDGeGLU) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
output_type, OutputType,
performTestDGeGLU<InputType, OutputType>(size.first, size.second);););
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest, DGeGLUTestSuite,
::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<DGeGLUTestSuite::ParamType> &info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param).first) + "X" +
std::to_string(std::get<2>(info.param).second);
return name;
});
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/logging.h>
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <type_traits>
#include "../test_common.h"
using namespace transformer_engine;
template <typename IT, typename OT, typename CT>
void compute_ref_geglu_cast(const IT *input_h, OT *output_h, const CT scale, CT *amax_h,
const size_t N, const size_t H) {
CT amax = 0.;
const size_t col = H * 2;
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT gelu_elt = CT(input_h[i * col + j]);
gelu_elt = 0.5f * gelu_elt *
(1.0f + tanhf(0.79788456F * gelu_elt * (1.0f + 0.044715f * gelu_elt * gelu_elt)));
CT gate_elt = CT(input_h[i * col + H + j]);
CT elt = gelu_elt * gate_elt;
output_h[i * H + j] = OT(scale * elt);
amax = std::abs(elt) > amax ? std::abs(elt) : amax;
}
}
*amax_h = amax;
}
template <typename IType, typename OType>
void performTestGEGLU(const size_t N, const size_t H) {
using namespace test;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
Tensor input({N, H * 2}, itype);
Tensor output({N, H}, otype);
fillUniform(&input);
setRandomScale(&output);
std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N * H);
nvte_geglu(input.data(), output.data(), 0);
float ref_amax;
compute_ref_geglu_cast(input.cpu_dptr<IType>(), ref_output.get(), output.scale(), &ref_amax, N,
H);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (isFp8Type(otype)) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", output.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_gelu", output, ref_output.get(), atol, rtol);
}
class GeGLUTestSuite
: public ::testing::TestWithParam<std::tuple<
transformer_engine::DType, transformer_engine::DType, std::pair<size_t, size_t>>> {};
TEST_P(GeGLUTestSuite, TestGeGLU) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
output_type, OutputType,
performTestGEGLU<InputType, OutputType>(size.first, size.second);););
}
namespace {
std::vector<std::pair<size_t, size_t>> test_cases = {
{4096, 2048}, {768, 2816}, {256, 5120}, {128, 10240}, {256, 256}, {257, 259}, {128, 128 + 1}};
} // namespace
INSTANTIATE_TEST_SUITE_P(
OperatorTest, GeGLUTestSuite,
::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::ValuesIn(test::all_fp_types), ::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<GeGLUTestSuite::ParamType> &info) {
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
std::to_string(std::get<2>(info.param).first) + "X" +
std::to_string(std::get<2>(info.param).second);
return name;
});
@@ -12,6 +12,7 @@
#include "../common.h"
#include <cstdlib>
#include "../util/vectorized_pointwise.h"
#include "../util/math.h"
namespace transformer_engine {
@@ -51,6 +52,65 @@ void gelu_cast(const Tensor &input,
); // NOLINT(*)
}
void geglu_cast(const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "geglu_input");
CheckOutputTensor(*output, "geglu_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
"Input shape[0] must be equal to output shape[0].");
NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
"Input shape[1] must be twice output shape[1].");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
GatedActivationKernelLauncher<nvec, fp32, gelu<fp32, fp32>>(
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->scale_inv.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
output->data.shape[0],
output->data.shape[1],
stream);
); // NOLINT(*)
); // NOLINT(*)
}
void dgeglu(const Tensor &grad,
const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(grad, "dgeglu_grad");
CheckInputTensor(input, "dgeglu_input");
CheckOutputTensor(*output, "dgeglu_output");
NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions.");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
NVTE_CHECK(output->data.shape[0] == grad.data.shape[0],
"Output shape[0] must be equal to grad shape[0].");
NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
"Output shape[1] must be twice grad shape[1].");
NVTE_CHECK(input.data.shape == output->data.shape,
"Input and output shapes must match.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
DGatedActivationKernelLauncher<nvec, fp32, gelu<fp32, fp32>, dgelu<fp32, fp32>>(
reinterpret_cast<const IType*>(grad.data.dptr),
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
grad.data.shape[0],
grad.data.shape[1],
stream);
); // NOLINT(*)
); // NOLINT(*)
}
} // namespace transformer_engine
void nvte_gelu(const NVTETensor input,
@@ -61,3 +121,23 @@ void nvte_gelu(const NVTETensor input,
reinterpret_cast<Tensor*>(output),
stream);
}
void nvte_geglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
using namespace transformer_engine;
geglu_cast(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}
void nvte_dgeglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
using namespace transformer_engine;
dgeglu(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}
@@ -27,6 +27,28 @@ void nvte_gelu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute GeGLU of the input.
*
* \param[in] input Input tensor of shape [N, H * 2].
* Computes GELU(input[:, :H]) * input[:, H:].
* \param[in,out] output Output tensor of shape [N, H].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_geglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
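As a plain-C++ illustration of the semantics documented above (a minimal sketch, not the CUDA kernel; geglu_reference and its float buffers are hypothetical names):

#include <cmath>
#include <cstddef>

// Hypothetical reference for nvte_geglu: input is [N, 2H] row-major, output is
// [N, H]; the first H columns are the GELU half, the last H columns the gate.
void geglu_reference(const float *input, float *output, size_t N, size_t H) {
  for (size_t i = 0; i < N; ++i) {
    for (size_t j = 0; j < H; ++j) {
      const float a = input[i * 2 * H + j];      // activation half
      const float b = input[i * 2 * H + H + j];  // gate half
      const float gelu_a =                       // tanh approximation of GELU
          0.5f * a * (1.f + tanhf(0.79788456f * a * (1.f + 0.044715f * a * a)));
      output[i * H + j] = gelu_a * b;
    }
  }
}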
/*! \brief Compute GeGLU gradient.
* \param[in] grad Incoming gradient tensor of shape [N, H].
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H * 2].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dgeglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
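A minimal call pattern for the backward, mirroring performTestDGeGLU above (the Tensor helper comes from the unit-test harness, so treat this as a sketch rather than canonical usage; N and H are example sizes):

// grad is the [N, H] incoming gradient, input the saved [N, 2H] forward
// input, and dinput receives the [N, 2H] gradient w.r.t. that input.
const size_t N = 256, H = 1024;
test::Tensor grad({N, H}, DType::kBFloat16);
test::Tensor input({N, H * 2}, DType::kBFloat16);
test::Tensor dinput({N, H * 2}, DType::kBFloat16);
nvte_dgeglu(grad.data(), input.data(), dinput.data(), /*stream=*/0);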
#ifdef __cplusplus
} // extern "C"
#endif
@@ -118,6 +118,28 @@ void nvte_multi_cast_transpose(size_t num_tensors,
NVTETensor* transposed_output_list,
cudaStream_t stream);
/*! \brief Compute DGeGLU of the input, and additionally cast and transpose the result.
*
* This function produces 2 results:
* - `cast_output` is the result of the cast
* - `transposed_output` is the transposed result of the cast.
*
* \param[in] input Incoming gradient tensor of shape [N, H].
* \param[in] geglu_input Tensor used as input to the forward GeGLU operation.
* Shape [N, H * 2].
* \param[in,out] cast_output Result of the cast. Shape: [N, H * 2].
* \param[in,out] transposed_output Result of the cast and transpose. Shape: [H * 2, N].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dgeglu_cast_transpose(const NVTETensor input,
const NVTETensor geglu_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream);
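The shape and FP8-metadata contract can be seen in the DGeGLUCT test earlier in this change; a condensed sketch (test-harness Tensor type assumed, sizes are examples):

// The FP8 cast output and its transpose must share amax/scale/scale_inv;
// the test harness arranges this with shareFP8Meta.
const size_t N = 256, H = 1024;
test::Tensor grad({N, H}, DType::kBFloat16);
test::Tensor input({N, H * 2}, DType::kBFloat16);
test::Tensor cast_out({N, H * 2}, DType::kFloat8E4M3);
test::Tensor trans_out({H * 2, N}, DType::kFloat8E4M3);
trans_out.shareFP8Meta(cast_out);
nvte_dgeglu_cast_transpose(grad.data(), input.data(),
                           cast_out.data(), trans_out.data(), /*stream=*/0);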
#ifdef __cplusplus
} // extern "C"
#endif
@@ -11,9 +11,41 @@
#include <type_traits>
#include "../utils.cuh"
#include "../common.h"
#include "../util/math.h"
namespace transformer_engine {
template <bool full_tile, int nvec_in, int nvec_out, typename IVec, typename OVec, typename CType>
inline __device__ void cast_and_transpose_regs(const IVec (&in)[nvec_out],
OVec (&out_trans)[nvec_in],
typename OVec::type *output_cast_tile,
const size_t current_place,
const size_t stride,
CType &max, // NOLINT(*)
const CType scale,
const bool valid_store) {
using T = typename OVec::type;
using OVecC = Vec<T, nvec_in>;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
OVecC out_cast;
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
const CType tmp = static_cast<CType>(in[i].data.elt[j]);
const T elt_o = T(scale * tmp);
out_cast.data.elt[j] = elt_o;
out_trans[j].data.elt[i] = elt_o; // thread tile transpose
__builtin_assume(max >= 0);
max = fmaxf(fabsf(tmp), max);
}
if (full_tile || valid_store) {
out_cast.store_to(output_cast_tile, current_place + stride * i);
}
}
}
template <bool full_tile, int nvec_in, int nvec_out,
typename IVec, typename OVec, typename CVec, typename CType>
inline __device__ void cast_and_transpose_regs_partial_dbias(const IVec (&in)[nvec_out],
@@ -593,19 +625,6 @@ void cast_transpose_dbias(const Tensor &input,
); // NOLINT(*)
}
template <int nvec_in, int nvec_out, typename Param>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
@@ -958,6 +977,380 @@ cast_transpose_dbias_dgelu_kernel_notaligned(const Param param,
}
}
template <int nvec_in, int nvec_out, typename CType, typename IType, typename OType>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
dgeglu_cast_transpose_kernel(const IType * const input,
const IType * const gelu_input,
OType * const output_c,
OType * const output_t,
const CType * const scale_ptr,
CType * const amax,
CType * const scale_inv,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
using CVec = Vec<CType, nvec_in>;
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
const IType * const my_gelu_input_tile = gelu_input + (tile_id_x * nvec_in +
tile_id_y * row_length * 2 * nvec_out) *
THREADS_PER_WARP;
const IType * const my_gate_input_tile = gelu_input + (tile_id_x * nvec_in +
tile_id_y * row_length * 2 * nvec_out) *
THREADS_PER_WARP + row_length;
OType * const my_output_c_tile_0 = output_c + (tile_id_x * nvec_in +
tile_id_y * row_length * 2 * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_c_tile_1 = output_c + (tile_id_x * nvec_in +
tile_id_y * row_length * 2 * nvec_out) *
THREADS_PER_WARP + row_length;
OType * const my_output_t_tile_0 = output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
OType * const my_output_t_tile_1 = output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP + row_length * num_rows;
OVec * const my_scratch = reinterpret_cast<OVec*>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
(THREADS_PER_WARP + 1);
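// Double buffering: the [2] staging buffers let iteration i+1's global loads
// be issued while iteration i's math runs on the other buffer.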
IVec in[2][nvec_out];
IVec gelu_in[2][nvec_out];
IVec gate_in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space_0[n_iterations][nvec_in];
OVec out_space_1[n_iterations][nvec_in];
const size_t stride = row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
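// gelu_input and output_c are 2 * row_length wide, hence the doubled stride.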
const size_t stride2 = 2 * row_length / nvec_in;
size_t current_stride2 = warp_id_in_tile * n_iterations * nvec_out * stride2;
CType max = 0;
const CType scale = scale_ptr != nullptr ? *scale_ptr : 1;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
gelu_in[0][i].load_from(my_gelu_input_tile, current_stride2 + my_place + stride2 * i);
gate_in[0][i].load_from(my_gate_input_tile, current_stride2 + my_place + stride2 * i);
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride2 + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
gelu_in[current_in][j].load_from(my_gelu_input_tile,
current_stride2 + my_place_in + stride2 * (nvec_out + j));
gate_in[current_in][j].load_from(my_gate_input_tile,
current_stride2 + my_place_in + stride2 * (nvec_out + j));
}
}
CVec after_dgelu[nvec_out]; // NOLINT(*)
CVec after_dgate[nvec_out]; // NOLINT(*)
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
for (unsigned int k = 0; k < nvec_in; ++k) {
after_dgelu[j].data.elt[k] = dgelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k]) *
CType(in[current_in ^ 1][j].data.elt[k]) *
CType(gate_in[current_in ^ 1][j].data.elt[k]);
after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
gelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k]);
}
}
OVec out_trans_0[nvec_in]; // NOLINT(*)
cast_and_transpose_regs<true>(after_dgelu, out_trans_0, my_output_c_tile_0,
current_place, stride2, max, scale, true);
OVec out_trans_1[nvec_in]; // NOLINT(*)
cast_and_transpose_regs<true>(after_dgate, out_trans_1, my_output_c_tile_1,
current_place, stride2, max, scale, true);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space_0[i][j].data.vec = out_trans_0[j].data.vec;
out_space_1[i][j].data.vec = out_trans_1[j].data.vec;
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += nvec_out * stride;
current_stride2 += nvec_out * stride2;
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space_0[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_0,
current_stride + my_place);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space_1[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_1,
current_stride + my_place);
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
/* warp tile amax reduce */
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (amax != nullptr) atomicMaxFloat(amax, max);
if (scale_inv != nullptr) reciprocal<float>(scale_inv, scale);
}
}
template <int nvec_in, int nvec_out, typename CType, typename IType, typename OType>
__global__ void
__launch_bounds__(cast_transpose_num_threads)
dgeglu_cast_transpose_kernel_notaligned(const IType * const input,
const IType * const gelu_input,
OType * const output_c,
OType * const output_t,
const CType * const scale_ptr,
CType * const amax,
CType * const scale_inv,
const size_t row_length,
const size_t num_rows,
const size_t num_tiles) {
using IVec = Vec<IType, nvec_in>;
using OVec = Vec<OType, nvec_out>;
using CVec = Vec<CType, nvec_in>;
extern __shared__ char scratch[];
const int warp_id = threadIdx.x / THREADS_PER_WARP;
const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) /
(nvec_in * THREADS_PER_WARP);
const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
warp_id / n_warps_per_tile;
if (tile_id >= num_tiles) return;
const size_t tile_id_x = tile_id % num_tiles_x;
const size_t tile_id_y = tile_id / num_tiles_x;
const IType * const my_input_tile = input + (tile_id_x * nvec_in +
tile_id_y * row_length * nvec_out) *
THREADS_PER_WARP;
const IType * const my_gelu_input_tile = gelu_input + (tile_id_x * nvec_in +
tile_id_y * row_length * 2 * nvec_out) *
THREADS_PER_WARP;
const IType * const my_gate_input_tile = gelu_input + (tile_id_x * nvec_in +
tile_id_y * row_length * 2 * nvec_out) *
THREADS_PER_WARP + row_length;
OType * const my_output_c_tile_0 = output_c + (tile_id_x * nvec_in +
tile_id_y * row_length * 2 * nvec_out) *
THREADS_PER_WARP;
OType * const my_output_c_tile_1 = output_c + (tile_id_x * nvec_in +
tile_id_y * row_length * 2 * nvec_out) *
THREADS_PER_WARP + row_length;
OType * const my_output_t_tile_0 = output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP;
OType * const my_output_t_tile_1 = output_t + (tile_id_y * nvec_out +
tile_id_x * num_rows * nvec_in) *
THREADS_PER_WARP + row_length * num_rows;
const size_t stride = row_length / nvec_in;
const size_t stride2 = 2 * row_length / nvec_in;
const size_t output_stride = num_rows / nvec_out;
const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP;
const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP;
const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP
: row_length_rest;
const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP
: row_height_rest;
OVec * const my_scratch = reinterpret_cast<OVec*>(scratch) +
(my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
(THREADS_PER_WARP + 1);
IVec in[2][nvec_out];
IVec gelu_in[2][nvec_out];
IVec gate_in[2][nvec_out];
const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
OVec out_space_0[n_iterations][nvec_in];
OVec out_space_1[n_iterations][nvec_in];
size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride;
size_t current_stride2 = warp_id_in_tile * n_iterations * nvec_out * stride2;
unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP -
warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
CType max = 0;
const CType scale = scale_ptr != nullptr ? *scale_ptr : 1;
{
const bool valid_load = my_place < tile_length &&
warp_id_in_tile * n_iterations < tile_height;
#pragma unroll
for (unsigned int i = 0; i < nvec_out; ++i) {
if (valid_load) {
in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i);
gelu_in[0][i].load_from(my_gelu_input_tile, current_stride2 + my_place + stride2 * i);
gate_in[0][i].load_from(my_gate_input_tile, current_stride2 + my_place + stride2 * i);
} else {
in[0][i].clear();
gelu_in[0][i].clear();
gate_in[0][i].clear();
}
}
}
#pragma unroll
for (unsigned int i = 0; i < n_iterations; ++i) {
const size_t current_place = current_stride2 + my_place;
const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
const unsigned int current_in = (i + 1) % 2;
if (i < n_iterations - 1) {
{
const bool valid_load = my_place_in < tile_length &&
warp_id_in_tile * n_iterations + i + 1 < tile_height;
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
if (valid_load) {
in[current_in][j].load_from(my_input_tile,
current_stride + my_place_in + stride * (nvec_out + j));
gelu_in[current_in][j].load_from(my_gelu_input_tile,
current_stride2 + my_place_in + stride2 * (nvec_out + j));
gate_in[current_in][j].load_from(my_gate_input_tile,
current_stride2 + my_place_in + stride2 * (nvec_out + j));
} else {
in[current_in][j].clear();
gelu_in[current_in][j].clear();
gate_in[current_in][j].clear();
}
}
}
}
CVec after_dgelu[nvec_out]; // NOLINT(*)
CVec after_dgate[nvec_out]; // NOLINT(*)
#pragma unroll
for (unsigned int j = 0; j < nvec_out; ++j) {
#pragma unroll
for (unsigned int k = 0; k < nvec_in; ++k) {
after_dgelu[j].data.elt[k] = dgelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k]) *
CType(in[current_in ^ 1][j].data.elt[k]) *
CType(gate_in[current_in ^ 1][j].data.elt[k]);
after_dgate[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) *
gelu<CType>(gelu_in[current_in ^ 1][j].data.elt[k]);
}
}
OVec out_trans_0[nvec_in]; // NOLINT(*)
OVec out_trans_1[nvec_in]; // NOLINT(*)
const bool valid_store = my_place < tile_length &&
warp_id_in_tile * n_iterations + i < tile_height;
cast_and_transpose_regs<false>(after_dgelu, out_trans_0, my_output_c_tile_0,
current_place, stride2, max, scale, valid_store);
cast_and_transpose_regs<false>(after_dgate, out_trans_1, my_output_c_tile_1,
current_place, stride2, max, scale, valid_store);
#pragma unroll
for (unsigned int j = 0; j < nvec_in; ++j) {
out_space_0[i][j].data.vec = out_trans_0[j].data.vec;
out_space_1[i][j].data.vec = out_trans_1[j].data.vec;
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += nvec_out * stride;
current_stride2 += nvec_out * stride2;
}
for (unsigned int i = 0; i < nvec_in; ++i) {
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space_0[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
const bool valid_store = my_place < tile_height;
if (valid_store) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_0,
current_stride + my_place);
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
#pragma unroll
for (unsigned int j = 0; j < n_iterations; ++j) {
my_scratch[(my_id_in_warp + THREADS_PER_WARP -
j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space_1[j][i];
}
__syncthreads();
my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) %
THREADS_PER_WARP;
current_stride = i * output_stride +
warp_id_in_tile * n_iterations * output_stride * nvec_in;
for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) {
const bool valid_store = my_place < tile_height;
if (valid_store) {
my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile_1,
current_stride + my_place);
}
my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP;
current_stride += output_stride * nvec_in;
}
__syncthreads();
}
/* warp tile amax reduce */
max = reduce_max<cast_transpose_num_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
if (amax != nullptr) atomicMaxFloat(amax, max);
if (scale_inv != nullptr) reciprocal<float>(scale_inv, scale);
}
}
void cast_transpose_dbias_dgelu(const Tensor &input,
const Tensor &gelu_input,
Tensor *cast_output,
@@ -1066,6 +1459,105 @@ void cast_transpose_dbias_dgelu(const Tensor &input,
); // NOLINT(*)
}
void dgeglu_cast_transpose(const Tensor &input,
const Tensor &geglu_input,
Tensor *cast_output,
Tensor *transposed_output,
cudaStream_t stream) {
CheckInputTensor(input, "dgeglu_cast_transpose_input");
CheckInputTensor(geglu_input, "dgeglu_cast_transpose_geglu_input");
CheckOutputTensor(*cast_output, "dgeglu_cast_transpose_cast_output");
CheckOutputTensor(*transposed_output, "dgeglu_cast_transpose_transposed_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
NVTE_CHECK(geglu_input.data.shape.size() == 2, "GeGLU input must have 2 dimensions.");
NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions.");
NVTE_CHECK(transposed_output->data.shape.size() == 2,
"T output must have 2 dimensions.");
const size_t row_length = input.data.shape[1];
const size_t num_rows = input.data.shape[0];
NVTE_CHECK(geglu_input.data.shape[0] == num_rows, "Wrong dimension of geglu_input.");
NVTE_CHECK(geglu_input.data.shape[1] == row_length * 2, "Wrong dimension of geglu_input.");
NVTE_CHECK(cast_output->data.shape[0] == num_rows, "Wrong dimension of output.");
NVTE_CHECK(cast_output->data.shape[1] == row_length * 2, "Wrong dimension of output.");
NVTE_CHECK(transposed_output->data.shape[0] == row_length * 2, "Wrong dimension of T output.");
NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output.");
NVTE_CHECK(input.data.dtype == geglu_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype,
"C and T outputs need to have the same type.");
NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr,
"C and T outputs need to share amax tensor.");
NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr,
"C and T outputs need to share scale tensor.");
NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr,
"C and T outputs need to share scale inverse tensor.");
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType,
/* dgeglu fusion kernel uses more registers */
constexpr int desired_load_size_dgeglu = 4;
constexpr int desired_store_size_dgeglu = 4;
constexpr int itype_size = sizeof(InputType);
constexpr int otype_size = sizeof(OutputType);
constexpr int nvec_in = desired_load_size_dgeglu / itype_size;
constexpr int nvec_out = desired_store_size_dgeglu / otype_size;
NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
const size_t n_tiles = DIVUP(row_length, static_cast<size_t>(nvec_in * THREADS_PER_WARP)) *
DIVUP(num_rows, static_cast<size_t>(nvec_out * THREADS_PER_WARP));
const size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP;
const size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block);
const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 &&
num_rows % (nvec_out * THREADS_PER_WARP) == 0;
if (full_tile) {
cudaFuncSetAttribute(dgeglu_cast_transpose_kernel<nvec_in, nvec_out, fp32,
InputType, OutputType>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
dgeglu_cast_transpose_kernel<nvec_in, nvec_out, fp32, InputType, OutputType>
<<<n_blocks,
cast_transpose_num_threads,
cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>),
stream>>>(
reinterpret_cast<const InputType *>(input.data.dptr),
reinterpret_cast<const InputType *>(geglu_input.data.dptr),
reinterpret_cast<OutputType *>(cast_output->data.dptr),
reinterpret_cast<OutputType *>(transposed_output->data.dptr),
reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
reinterpret_cast<fp32 *>(cast_output->amax.dptr),
reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr),
row_length, num_rows, n_tiles);
} else {
cudaFuncSetAttribute(dgeglu_cast_transpose_kernel_notaligned<nvec_in, nvec_out, fp32,
InputType, OutputType>,
cudaFuncAttributePreferredSharedMemoryCarveout,
100);
dgeglu_cast_transpose_kernel_notaligned<nvec_in, nvec_out, fp32, InputType, OutputType>
<<<n_blocks,
cast_transpose_num_threads,
cast_transpose_num_threads / n_warps_per_tile *
(THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>),
stream>>>(
reinterpret_cast<const InputType *>(input.data.dptr),
reinterpret_cast<const InputType *>(geglu_input.data.dptr),
reinterpret_cast<OutputType *>(cast_output->data.dptr),
reinterpret_cast<OutputType *>(transposed_output->data.dptr),
reinterpret_cast<const fp32 *>(cast_output->scale.dptr),
reinterpret_cast<fp32 *>(cast_output->amax.dptr),
reinterpret_cast<fp32 *>(cast_output->scale_inv.dptr),
row_length, num_rows, n_tiles);
}
); // NOLINT(*)
); // NOLINT(*)
}
} // namespace transformer_engine
void nvte_cast_transpose_dbias(const NVTETensor input,
@@ -1099,3 +1591,16 @@ void nvte_cast_transpose_dbias_dgelu(const NVTETensor input,
reinterpret_cast<Tensor*>(workspace),
stream);
}
void nvte_dgeglu_cast_transpose(const NVTETensor input,
const NVTETensor geglu_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream) {
using namespace transformer_engine;
dgeglu_cast_transpose(*reinterpret_cast<const Tensor*>(input),
*reinterpret_cast<const Tensor*>(geglu_input),
reinterpret_cast<Tensor*>(cast_output),
reinterpret_cast<Tensor*>(transposed_output),
stream);
}
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_
namespace transformer_engine {
namespace {
template <typename OType, typename IType>
__device__ inline OType gelu(const IType val) {
const float cval = val;
return cval * (0.5F + 0.5F * tanhf(cval * (0.79788456F + 0.03567741F * cval * cval)));
}
template <typename OType, typename IType>
__device__ inline OType dgelu(const IType val) {
const float cval = val;
const float tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
return 0.5f * cval * ((1.f - tanh_out * tanh_out) *
(0.79788456f + 0.1070322243f * cval * cval)) +
0.5f * (1.f + tanh_out);
}
} // namespace
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_
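For the record, the numeric literals in gelu/dgelu follow from differentiating the tanh approximation; a quick derivation check (my arithmetic, not part of the patch), with u(x) = \sqrt{2/\pi}\,(x + 0.044715\,x^{3}):

\sqrt{2/\pi} \approx 0.79788456, \qquad 0.79788456 \times 0.044715 \approx 0.03567741 \;\;(\text{the constant folded into } gelu)
u'(x) = 0.79788456 + 3 \times 0.044715 \times 0.79788456\, x^{2} \approx 0.79788456 + 0.1070322243\, x^{2}
\frac{d}{dx}\Bigl[\tfrac{1}{2}x\,(1+\tanh u)\Bigr] = \tfrac{1}{2}x\,(1-\tanh^{2}u)\,u'(x) + \tfrac{1}{2}(1+\tanh u)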
@@ -258,40 +258,28 @@ inline int CalcAlignment(const void *ptr, const int size) {
\param lead_dim Leading dimension of the tensors.
\param other_dim The size of the other dimensions of the tensors.
\param nvec Length of the vector.
\param ptrs Inputs and outputs of the operator.
*/
template <typename... T>
Alignment CheckAlignment(const size_t lead_dim,
const int nvec,
const T... ptrs) {
std::vector<int> alignments;
alignments.reserve(sizeof...(T));
// Calculate the alignment of each pointer and collect the results.
(..., alignments.push_back(CalcAlignment(ptrs, sizeof(*ptrs) * nvec)));
bool all_same = std::all_of(alignments.cbegin(), alignments.cend(),
[&alignments](int val) { return val == alignments.front(); });
if (!all_same) {
return Alignment::DIFFERENT;
}
if (alignments.front() == 0 &&
lead_dim % nvec == 0) {
// all alignments are 0
return Alignment::SAME_ALIGNED;
} else {
return Alignment::SAME_UNALIGNED;
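The rewrite above relies on a C++17 comma fold to visit an arbitrary pack of pointers; a self-contained sketch of the same pattern (calc_alignment and collect_alignments are hypothetical names, not the library's):

#include <cstdint>
#include <cstdio>
#include <vector>

// Byte offset of ptr within a size-byte vector access; 0 means fully aligned.
static int calc_alignment(const void *ptr, int size) {
  return static_cast<int>(reinterpret_cast<uintptr_t>(ptr) % size);
}

template <typename... T>
std::vector<int> collect_alignments(int nvec, const T *...ptrs) {
  std::vector<int> out;
  out.reserve(sizeof...(T));
  // Comma fold: expands to one push_back per pointer, left to right.
  (..., out.push_back(calc_alignment(ptrs, sizeof(*ptrs) * nvec)));
  return out;
}

int main() {
  float a[8], b[8];
  for (int v : collect_alignments(4, a, b, a + 1)) std::printf("%d\n", v);
}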
@@ -341,6 +329,191 @@ void VectorizedUnaryKernelLauncher(const InputType *input,
}
}
template <int nvec, bool aligned,
typename ComputeType,
ComputeType (*Activation)(ComputeType),
typename InputType,
typename OutputType>
__launch_bounds__(unary_kernel_threads)
__global__ void gated_act_kernel(const InputType *input,
OutputType *output,
const ComputeType *scale,
ComputeType *scale_inv,
ComputeType *amax,
const size_t m,
const size_t n,
const size_t num_aligned_elements) {
const size_t M = num_aligned_elements * m;
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
tid < M;
tid += gridDim.x * blockDim.x) {
const size_t id_x = tid % num_aligned_elements;
const size_t id_y = tid / num_aligned_elements;
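// Each input row is 2*n wide: the first n columns feed the activation,
// the last n columns are the gate.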
VectorizedLoader<InputType, nvec, aligned> loader0(input + id_y * n * 2, n);
VectorizedLoader<InputType, nvec, aligned> loader1(input + id_y * n * 2 + n, n);
VectorizedStorer<OutputType, nvec, aligned> storer(output + id_y * n, n);
ComputeType max = 0;
ComputeType s = 0;
if constexpr (is_fp8<OutputType>::value) {
if (scale != nullptr) s = *scale;
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<ComputeType>(scale_inv, s);
}
}
const int warp_id = threadIdx.x / THREADS_PER_WARP;
loader0.load(id_x, n);
loader1.load(id_x, n);
#pragma unroll
for (int i = 0; i < nvec; ++i) {
const ComputeType val = static_cast<ComputeType>(loader0.separate()[i]);
const ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);
ComputeType temp = static_cast<ComputeType>(Activation(val) * val2);
if constexpr (is_fp8<OutputType>::value) {
__builtin_assume(max >= 0);
max = fmaxf(fabsf(temp), max);
temp = temp * s;
}
storer.separate()[i] = static_cast<OutputType>(temp);
}
storer.store(id_x, n);
if constexpr (is_fp8<OutputType>::value) {
/* warp tile amax reduce */
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0 && amax != nullptr) {
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
}
}
}
}
template <int nvec,
typename ComputeType,
ComputeType (*Activation)(ComputeType),
typename InputType,
typename OutputType>
void GatedActivationKernelLauncher(const InputType *input,
OutputType *output,
const fp32 *scale,
fp32 *scale_inv,
fp32 *amax,
const size_t m,
const size_t n,
cudaStream_t stream) {
if (m != 0 && n != 0) {
size_t num_aligned_elements = get_num_aligned_elements(input, n, nvec, sizeof(InputType));
constexpr size_t threads = unary_kernel_threads;
size_t num_blocks = DIVUP(num_aligned_elements * m, threads);
constexpr size_t max_blocks = 65535;
num_blocks = std::min(num_blocks, max_blocks);
switch (auto align = CheckAlignment(n, nvec, input, input + n, output)) {
case Alignment::SAME_ALIGNED:
gated_act_kernel<nvec, true, ComputeType, Activation><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, m, n, num_aligned_elements);
break;
case Alignment::SAME_UNALIGNED:
gated_act_kernel<nvec, false, ComputeType, Activation><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, m, n, num_aligned_elements);
break;
case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize
gated_act_kernel<1, true, ComputeType, Activation><<<num_blocks, threads, 0, stream>>>(
input, output, scale, scale_inv, amax, m, n, n);
break;
}
}
}
}
template <int nvec, bool aligned,
typename ComputeType,
ComputeType (*Activation)(ComputeType),
ComputeType (*Dactivation)(ComputeType),
typename InputType,
typename OutputType>
__launch_bounds__(unary_kernel_threads)
__global__ void dgated_act_kernel(const InputType *grad,
const InputType *input,
OutputType *output,
const size_t m,
const size_t n,
const size_t num_aligned_elements) {
const size_t M = num_aligned_elements * m;
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
tid < M;
tid += gridDim.x * blockDim.x) {
const size_t id_x = tid % num_aligned_elements;
const size_t id_y = tid / num_aligned_elements;
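// grad rows are n wide; input and output rows are 2*n wide
// (activation half, then gate half).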
VectorizedLoader<InputType, nvec, aligned> grad_loader(grad + id_y * n, n);
VectorizedLoader<InputType, nvec, aligned> input_loader0(input + id_y * n * 2, n);
VectorizedLoader<InputType, nvec, aligned> input_loader1(input + id_y * n * 2 + n, n);
VectorizedStorer<OutputType, nvec, aligned> storer0(output + id_y * n * 2, n);
VectorizedStorer<OutputType, nvec, aligned> storer1(output + id_y * n * 2 + n, n);
grad_loader.load(id_x, n);
input_loader0.load(id_x, n);
input_loader1.load(id_x, n);
#pragma unroll
for (int i = 0; i < nvec; ++i) {
const ComputeType grad_val = static_cast<ComputeType>(grad_loader.separate()[i]);
const ComputeType gelu_in = static_cast<ComputeType>(input_loader0.separate()[i]);
const ComputeType gate_in = static_cast<ComputeType>(input_loader1.separate()[i]);
ComputeType after_dgelu = Dactivation(gelu_in) * grad_val * gate_in;
ComputeType after_dgate = grad_val * Activation(gelu_in);
storer0.separate()[i] = static_cast<OutputType>(after_dgelu);
storer1.separate()[i] = static_cast<OutputType>(after_dgate);
}
storer0.store(id_x, n);
storer1.store(id_x, n);
}
}
template <int nvec,
typename ComputeType,
ComputeType (*Activation)(ComputeType),
ComputeType (*Dactivation)(ComputeType),
typename InputType,
typename OutputType>
void DGatedActivationKernelLauncher(const InputType *grad,
const InputType *input,
OutputType *output,
const size_t m,
const size_t n,
cudaStream_t stream) {
if (m != 0 && n != 0) {
size_t num_aligned_elements = get_num_aligned_elements(grad, n, nvec,
sizeof(InputType));
constexpr size_t threads = unary_kernel_threads;
size_t num_blocks = DIVUP(num_aligned_elements * m, threads);
constexpr size_t max_blocks = 65535;
num_blocks = std::min(num_blocks, max_blocks);
switch (auto align = CheckAlignment(n, nvec, input, input + n, output, output + n)) {
case Alignment::SAME_ALIGNED:
dgated_act_kernel<nvec, true, ComputeType, Activation, Dactivation>
<<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, num_aligned_elements);
break;
case Alignment::SAME_UNALIGNED:
dgated_act_kernel<nvec, false, ComputeType, Activation, Dactivation>
<<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, num_aligned_elements);
break;
case Alignment::DIFFERENT: {
// If the pointers are aligned differently we cannot vectorize
dgated_act_kernel<1, true, ComputeType, Activation, Dactivation>
<<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, n);
break;
}
}
}
}
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_VECTORIZED_POINTWISE_H_