Unverified Commit eed1fa26 authored by zlsh80826's avatar zlsh80826 Committed by GitHub
Browse files

Add GeGLU and the corresponding gradient kernels (#47)



* Add GeGLU and DGeGLU
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add DGeGLUCT
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Update copyright year
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Refine shape check
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Code refine
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
parent 37cc3625
...@@ -8,7 +8,10 @@ add_executable(test_operator ...@@ -8,7 +8,10 @@ add_executable(test_operator
test_transpose.cu test_transpose.cu
test_cast_transpose_dbias.cu test_cast_transpose_dbias.cu
test_cast_transpose_dbias_dgelu.cu test_cast_transpose_dbias_dgelu.cu
test_cast_transpose_dgeglu.cu
test_gelu.cu test_gelu.cu
test_geglu.cu
test_dgeglu.cu
test_layernorm.cu test_layernorm.cu
test_rmsnorm.cu test_rmsnorm.cu
test_multi_cast_transpose.cu test_multi_cast_transpose.cu
......
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/logging.h>
#include <transformer_engine/transpose.h>
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename CType, typename IType>
inline CType gelu(const IType val) {
CType cval = val;
return cval * (0.5f + 0.5f * tanhf(cval * (0.79788456f + 0.03567741f * cval * cval)));
}
template <typename CType, typename IType>
inline CType dgelu(const IType val) {
CType cval = val;
const CType tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
return 0.5f * cval * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * cval * cval)) +
0.5f * (1.f + tanh_out);
}
template <typename IT, typename OT, typename CT>
void compute_ref_cast_transpose_dgated_gelu(const IT *grad_h, const IT *input_h, const CT scale,
                                            OT *output_c_h, OT *output_t_h, CT *amax_h,
                                            const size_t N, const size_t H) {
  // CPU reference for dGeGLU + cast + transpose.  Each input row packs both
  // halves side by side: [activation half | gate half], width H * 2.  The
  // scaled gradients are written in row-major (output_c_h) and transposed
  // (output_t_h) layouts; *amax_h receives the max |unscaled| gradient.
  const size_t row_len = H * 2;
  CT running_amax = 0.;
  for (size_t r = 0; r < N; ++r) {
    for (size_t c = 0; c < H; ++c) {
      const CT g = CT(grad_h[r * H + c]);
      const CT act_in = CT(input_h[r * row_len + c]);
      const CT gate_in = CT(input_h[r * row_len + H + c]);
      const CT d_act = dgelu<CT, CT>(act_in) * g * gate_in;
      const CT d_gate = g * gelu<CT, CT>(act_in);
      const CT abs_act = std::abs(d_act);
      if (abs_act > running_amax) running_amax = abs_act;
      const CT abs_gate = std::abs(d_gate);
      if (abs_gate > running_amax) running_amax = abs_gate;
      output_c_h[r * row_len + c] = static_cast<OT>(scale * d_act);
      output_c_h[r * row_len + H + c] = static_cast<OT>(scale * d_gate);
      output_t_h[c * N + r] = static_cast<OT>(scale * d_act);
      output_t_h[(c + H) * N + r] = static_cast<OT>(scale * d_gate);
    }
  }
  *amax_h = running_amax;
}
// End-to-end check of nvte_dgeglu_cast_transpose against the CPU reference:
// the fused kernel computes dGeGLU of `grad` w.r.t. `input`, casts to OType
// and writes both row-major and transposed results.
template <typename IType, typename OType>
void performTest(const size_t N, const size_t H) {
  using namespace test;
  using CType = fp32;  // reference math is done in fp32

  DType itype = TypeInfo<IType>::dtype;
  DType otype = TypeInfo<OType>::dtype;

  // grad: [N, H]; input packs both GeGLU halves side by side: [N, H * 2].
  Tensor grad({N, H}, itype);
  Tensor input({N, H * 2}, itype);
  Tensor output_c({N, H * 2}, otype);
  Tensor output_t({H * 2, N}, otype);

  fillUniform(&grad);
  fillUniform(&input);
  setRandomScale(&output_c);
  // Both outputs must quantize against the same FP8 scale/amax metadata.
  output_t.shareFP8Meta(output_c);

  std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N * H * 2);
  std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N * H * 2);

  // Launch the fused kernel on the default stream (0).
  nvte_dgeglu_cast_transpose(grad.data(), input.data(), output_c.data(), output_t.data(), 0);

  // CPU reference is computed between launch and sync.
  CType ref_amax;
  compute_ref_cast_transpose_dgated_gelu(grad.cpu_dptr<IType>(), input.cpu_dptr<IType>(),
                                         output_c.scale(), ref_output_c.get(), ref_output_t.get(),
                                         &ref_amax, N, H);

  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  if (isFp8Type(otype)) {
    // amax/scale_inv metadata is only produced for FP8 outputs.
    auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
    compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
    float ref_scale_inv = 1.f / output_c.scale();
    compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
  }
  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
  compareResults("output_t", output_t, ref_output_t.get(), atol, rtol);
}
// (N, H) problem sizes: N rows, H columns per gated half.
std::vector<std::pair<size_t, size_t>> test_cases = {{64, 400}, {4096, 2048}, {768, 2816},
                                                     {256, 5120}, {128, 10240}, {256, 256}};
} // namespace
// Suite parameterized over (input dtype, output dtype, (N, H)).
class DGeGLUCTTestSuite
    : public ::testing::TestWithParam<std::tuple<
          transformer_engine::DType, transformer_engine::DType, std::pair<size_t, size_t>>> {};

TEST_P(DGeGLUCTTestSuite, TestDGeGLUCT) {
  using namespace transformer_engine;
  using namespace test;

  const DType input_type = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());
  const auto size = std::get<2>(GetParam());

  // Dispatch the templated test body over the runtime dtype pair.
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
      input_type, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
          output_type, OutputType, performTest<InputType, OutputType>(size.first, size.second);););
}

// Output dtypes are FP8 only (this test exercises the cast path); each case is
// named "<in>X<out>X<N>X<H>" so a failure identifies the configuration.
INSTANTIATE_TEST_SUITE_P(
    OperatorTest, DGeGLUCTTestSuite,
    ::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
                       ::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3),
                       ::testing::ValuesIn(test_cases)),
    [](const testing::TestParamInfo<DGeGLUCTTestSuite::ParamType> &info) {
      std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                         test::typeName(std::get<1>(info.param)) + "X" +
                         std::to_string(std::get<2>(info.param).first) + "X" +
                         std::to_string(std::get<2>(info.param).second);
      return name;
    });
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/logging.h>
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <type_traits>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
template <typename CType, typename IType>
inline CType gelu(const IType val) {
CType cval = val;
return cval * (0.5f + 0.5f * tanhf(cval * (0.79788456f + 0.03567741f * cval * cval)));
}
template <typename CType, typename IType>
inline CType dgelu(const IType val) {
CType cval = val;
const CType tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
return 0.5f * cval * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * cval * cval)) +
0.5f * (1.f + tanh_out);
}
template <typename IT, typename OT, typename CT>
void compute_ref_dgeglu(const IT *grad_h, const IT *input_h, OT *output_h, const size_t N,
const size_t H) {
const size_t col = H * 2;
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT grad_elt = CT(grad_h[i * H + j]);
CT gelu_elt = CT(input_h[i * col + j]);
CT gate_elt = CT(input_h[i * col + H + j]);
CT after_dgelu = dgelu<CT, CT>(gelu_elt) * grad_elt * gate_elt;
CT after_dgate = grad_elt * gelu<CT, CT>(gelu_elt);
output_h[i * col + j] = OT(after_dgelu);
output_h[i * col + H + j] = OT(after_dgate);
}
}
}
// Runs nvte_dgeglu on the GPU and checks it against compute_ref_dgeglu.
template <typename IType, typename OType>
void performTestDGeGLU(const size_t N, const size_t H) {
  using namespace test;
  using CType = fp32;  // reference math is done in fp32

  DType itype = TypeInfo<IType>::dtype;
  DType otype = TypeInfo<OType>::dtype;

  // grad: [N, H]; input/output pack both halves side by side: [N, H * 2].
  Tensor grad({N, H}, itype);
  Tensor input({N, H * 2}, itype);
  Tensor output({N, H * 2}, otype);

  fillUniform(&grad);
  fillUniform(&input);

  std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N * H * 2);

  // Launch on the default stream (0); CPU reference runs before the sync.
  nvte_dgeglu(grad.data(), input.data(), output.data(), 0);

  compute_ref_dgeglu<IType, OType, CType>(grad.cpu_dptr<IType>(), input.cpu_dptr<IType>(),
                                          ref_output.get(), N, H);

  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_dgelu", output, ref_output.get(), atol, rtol);
}
// (N, H) problem sizes; includes odd widths such as {257, 259} and {128, 129}.
std::vector<std::pair<size_t, size_t>> test_cases = {
    {4096, 2048}, {768, 2816}, {256, 5120}, {128, 10240}, {256, 256}, {257, 259}, {128, 128 + 1}};
} // namespace
// Suite parameterized over (input dtype, output dtype, (N, H)).
class DGeGLUTestSuite
    : public ::testing::TestWithParam<std::tuple<
          transformer_engine::DType, transformer_engine::DType, std::pair<size_t, size_t>>> {};

TEST_P(DGeGLUTestSuite, TestDGeGLU) {
  using namespace transformer_engine;
  using namespace test;

  const DType input_type = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());
  const auto size = std::get<2>(GetParam());

  // Dispatch the templated test body over the runtime dtype pair.
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
      input_type, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
          output_type, OutputType,
          performTestDGeGLU<InputType, OutputType>(size.first, size.second);););
}

// Non-FP8 output dtypes only (no cast path); cases named "<in>X<out>X<N>X<H>".
INSTANTIATE_TEST_SUITE_P(
    OperatorTest, DGeGLUTestSuite,
    ::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
                       ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
                       ::testing::ValuesIn(test_cases)),
    [](const testing::TestParamInfo<DGeGLUTestSuite::ParamType> &info) {
      std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                         test::typeName(std::get<1>(info.param)) + "X" +
                         std::to_string(std::get<2>(info.param).first) + "X" +
                         std::to_string(std::get<2>(info.param).second);
      return name;
    });
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/logging.h>
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <type_traits>
#include "../test_common.h"
using namespace transformer_engine;
// CPU reference for GeGLU forward with cast.  Each input row holds
// [activation half | gate half] (width H * 2); the output row (width H) is
// scale * GELU(act) * gate.  *amax_h receives the max |unscaled| value.
template <typename IT, typename OT, typename CT>
void compute_ref_geglu_cast(const IT *input_h, OT *output_h, const CT scale, CT *amax_h,
                            const size_t N, const size_t H) {
  CT amax = 0.;
  // size_t (was int): avoids signed/unsigned mixing with the size_t loop
  // indices and overflow of the row width for very large H.
  const size_t col = H * 2;
  for (size_t i = 0; i < N; i++) {
    for (size_t j = 0; j < H; j++) {
      CT gelu_elt = CT(input_h[i * col + j]);
      // tanh-approximation GELU, evaluated inline in the compute type.
      gelu_elt = 0.5f * gelu_elt *
                 (1.0f + tanhf(0.79788456F * gelu_elt * (1.0f + 0.044715f * gelu_elt * gelu_elt)));
      CT gate_elt = CT(input_h[i * col + H + j]);
      CT elt = gelu_elt * gate_elt;
      output_h[i * H + j] = OT(scale * elt);
      amax = std::abs(elt) > amax ? std::abs(elt) : amax;
    }
  }
  *amax_h = amax;
}
template <typename IType, typename OType>
void performTestGEGLU(const size_t N, const size_t H) {
using namespace test;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
Tensor input({N, H * 2}, itype);
Tensor output({N, H}, otype);
fillUniform(&input);
setRandomScale(&output);
std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N * H);
nvte_geglu(input.data(), output.data(), 0);
float ref_amax;
compute_ref_geglu_cast(input.cpu_dptr<IType>(), ref_output.get(), output.scale(), &ref_amax, N,
H);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
float ref_scale_inv = 1.f / output.scale();
compareResults("scale_inv", output.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_gelu", output, ref_output.get(), atol, rtol);
}
// Suite parameterized over (input dtype, output dtype, (N, H)).
class GeGLUTestSuite
    : public ::testing::TestWithParam<std::tuple<
          transformer_engine::DType, transformer_engine::DType, std::pair<size_t, size_t>>> {};

TEST_P(GeGLUTestSuite, TestGeGLU) {
  using namespace transformer_engine;
  using namespace test;

  const DType input_type = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());
  const auto size = std::get<2>(GetParam());

  // Dispatch the templated test body over the runtime dtype pair.
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
      input_type, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
          output_type, OutputType,
          performTestGEGLU<InputType, OutputType>(size.first, size.second);););
}

namespace {
// (N, H) problem sizes; includes odd widths such as {257, 259} and {128, 129}.
std::vector<std::pair<size_t, size_t>> test_cases = {
    {4096, 2048}, {768, 2816}, {256, 5120}, {128, 10240}, {256, 256}, {257, 259}, {128, 128 + 1}};
}  // namespace

// Output covers all floating-point types (FP8 included); cases named
// "<in>X<out>X<N>X<H>".
INSTANTIATE_TEST_SUITE_P(
    OperatorTest, GeGLUTestSuite,
    ::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
                       ::testing::ValuesIn(test::all_fp_types), ::testing::ValuesIn(test_cases)),
    [](const testing::TestParamInfo<GeGLUTestSuite::ParamType> &info) {
      std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                         test::typeName(std::get<1>(info.param)) + "X" +
                         std::to_string(std::get<2>(info.param).first) + "X" +
                         std::to_string(std::get<2>(info.param).second);
      return name;
    });
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "../common.h" #include "../common.h"
#include <cstdlib> #include <cstdlib>
#include <../util/vectorized_pointwise.h> #include <../util/vectorized_pointwise.h>
#include "../util/math.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -51,6 +52,65 @@ void gelu_cast(const Tensor &input, ...@@ -51,6 +52,65 @@ void gelu_cast(const Tensor &input,
); // NOLINT(*) ); // NOLINT(*)
} }
// Forward GeGLU with cast: out[i][j] = GELU(in[i][j]) * in[i][j + H], where the
// output has H columns and the input packs both halves as [N, H * 2].  FP8
// quantization metadata (scale, scale_inv, amax) is read from / written to the
// output tensor.
void geglu_cast(const Tensor &input,
                Tensor *output,
                cudaStream_t stream) {
  CheckInputTensor(input, "geglu_input");
  CheckOutputTensor(*output, "geglu_output");
  // Shape contract: input [N, H * 2] -> output [N, H].
  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
  NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
             "Input shape[0] must be equal to output shape[0].");
  NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
             "Input shape[1] must be twice than output shape[1].");

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      // Vector width: 32 bytes of input elements per load.
      constexpr int nvec = 32 / sizeof(IType);
      GatedActivationKernelLauncher<nvec, fp32, gelu<fp32, fp32>>(
          reinterpret_cast<const IType*>(input.data.dptr),
          reinterpret_cast<OType*>(output->data.dptr),
          reinterpret_cast<const fp32*>(output->scale.dptr),
          reinterpret_cast<fp32*>(output->scale_inv.dptr),
          reinterpret_cast<fp32*>(output->amax.dptr),
          output->data.shape[0],
          output->data.shape[1],
          stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}
// Backward of GeGLU: given grad [N, H] and the forward input [N, H * 2],
// writes the gradient w.r.t. both halves into output [N, H * 2].
void dgeglu(const Tensor &grad,
            const Tensor &input,
            Tensor *output,
            cudaStream_t stream) {
  CheckInputTensor(grad, "dgeglu_grad");
  CheckInputTensor(input, "dgeglu_input");
  CheckOutputTensor(*output, "dgeglu_output");
  // Shape contract: grad [N, H]; input and output both [N, H * 2].
  NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions.");
  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
  NVTE_CHECK(output->data.shape[0] == grad.data.shape[0],
             "Output shape[0] must be equal to grad shape[0].");
  NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
             "Output shape[1] must be twice than grad shape[1].");
  NVTE_CHECK(input.data.shape == output->data.shape,
             "Input and output shapes must match.");

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
      // Vector width: 32 bytes of input elements per load.
      constexpr int nvec = 32 / sizeof(IType);
      DGatedActivationKernelLauncher<nvec, fp32, gelu<fp32, fp32>, dgelu<fp32, fp32>>(
          reinterpret_cast<const IType*>(grad.data.dptr),
          reinterpret_cast<const IType*>(input.data.dptr),
          reinterpret_cast<OType*>(output->data.dptr),
          grad.data.shape[0],
          grad.data.shape[1],
          stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}
} // namespace transformer_engine } // namespace transformer_engine
void nvte_gelu(const NVTETensor input, void nvte_gelu(const NVTETensor input,
...@@ -61,3 +121,23 @@ void nvte_gelu(const NVTETensor input, ...@@ -61,3 +121,23 @@ void nvte_gelu(const NVTETensor input,
reinterpret_cast<Tensor*>(output), reinterpret_cast<Tensor*>(output),
stream); stream);
} }
// C API entry point: unwraps the opaque NVTETensor handles and forwards to the
// internal geglu_cast implementation.
void nvte_geglu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream) {
  using namespace transformer_engine;
  geglu_cast(*reinterpret_cast<const Tensor*>(input),
             reinterpret_cast<Tensor*>(output),
             stream);
}
// C API entry point: unwraps the opaque NVTETensor handles and forwards to the
// internal dgeglu implementation.
void nvte_dgeglu(const NVTETensor grad,
                 const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream) {
  using namespace transformer_engine;
  dgeglu(*reinterpret_cast<const Tensor*>(grad),
         *reinterpret_cast<const Tensor*>(input),
         reinterpret_cast<Tensor*>(output),
         stream);
}
...@@ -27,6 +27,28 @@ void nvte_gelu(const NVTETensor input, ...@@ -27,6 +27,28 @@ void nvte_gelu(const NVTETensor input,
NVTETensor output, NVTETensor output,
cudaStream_t stream); cudaStream_t stream);
/*! \brief Compute GeGLU of the input.
*
* \param[in] input Input tensor of shape [N, H * 2].
* It computes GELU([N, :H]) x [N, H:]
* \param[in,out] output Output tensor of shape [N, H].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_geglu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
/*! \brief Compute GeGLU gradient.
* \param[in] grad Input tensor of shape [N, H].
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H * 2].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dgeglu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
...@@ -118,6 +118,28 @@ void nvte_multi_cast_transpose(size_t num_tensors, ...@@ -118,6 +118,28 @@ void nvte_multi_cast_transpose(size_t num_tensors,
NVTETensor* transposed_output_list, NVTETensor* transposed_output_list,
cudaStream_t stream); cudaStream_t stream);
/*! \brief Compute DGeGLU of the input, and additionally cast and transpose the DGeGLU output.
*
* This function produces 2 results:
* - `cast_output` is the result of the cast
* - `transposed_output` is the transposed result of the cast.
*
 * This function does not take a workspace tensor; both outputs are written directly.
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] geglu_input Tensor used as input to the forward of GeGLU operation.
* Shape [N, H * 2].
* \param[in,out] cast_output Result of the cast. Shape: [N, H * 2].
* \param[in,out] transposed_output Result of the cast and transpose. Shape: [H * 2, N].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dgeglu_cast_transpose(const NVTETensor input,
const NVTETensor geglu_input,
NVTETensor cast_output,
NVTETensor transposed_output,
cudaStream_t stream);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_
namespace transformer_engine {
namespace {

// Device-side GELU using the tanh approximation:
//   gelu(x) = x/2 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// The inner polynomial is pre-multiplied: 0.03567741 = 0.79788456 * 0.044715.
// Math is carried out in float; the result converts to OType on return.
template <typename OType, typename IType>
__device__ inline OType gelu(const IType val) {
  const float cval = val;
  return cval * (0.5F + 0.5F * tanhf(cval * (0.79788456F + 0.03567741F * cval * cval)));
}

// Derivative of the tanh-approximation GELU above, also computed in float.
// 0.1070322243 = 3 * 0.79788456 * 0.044715.
template <typename OType, typename IType>
__device__ inline OType dgelu(const IType val) {
  const float cval = val;
  const float tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
  return 0.5f * cval * ((1.f - tanh_out * tanh_out) *
                        (0.79788456f + 0.1070322243f * cval * cval)) +
         0.5f * (1.f + tanh_out);
}

}  // namespace
}  // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_MATH_H_
...@@ -258,40 +258,28 @@ inline int CalcAlignment(const void *ptr, const int size) { ...@@ -258,40 +258,28 @@ inline int CalcAlignment(const void *ptr, const int size) {
\param lead_dim Leading dimension of the tensors. \param lead_dim Leading dimension of the tensors.
\param other_dim The size of the other dimensions of the tensors. \param other_dim The size of the other dimensions of the tensors.
\param nvec Length of the vector. \param nvec Length of the vector.
\param inputs Inputs to the operator. \param ptrs Inputs and Outputs to the operator.
\param outputs Outputs of the operator.
*/ */
template <typename InputType, typename OutputType> template <typename... T>
Alignment CheckAlignment(const size_t lead_dim, Alignment CheckAlignment(const size_t lead_dim,
const int nvec, const int nvec,
const InputType *input, const T... ptrs
const OutputType *output) { ) {
int align = -1; std::vector<int> alignments;
alignments.reserve(sizeof...(T));
if (input != nullptr) {
int new_align = CalcAlignment(input, sizeof(InputType) * nvec); // calculate the alignments of all ptrs and store them into alignments
if (align == -1) { (..., alignments.push_back(CalcAlignment(ptrs, sizeof(*ptrs) * nvec)));
align = new_align;
} else { bool all_same = std::all_of(alignments.cbegin(), alignments.cend(),
if (align != new_align) { [alignments](int val) {return val == alignments.front();});
return Alignment::DIFFERENT; if (!all_same) {
} return Alignment::DIFFERENT;
}
}
if (output != nullptr) {
int new_align = CalcAlignment(output, sizeof(OutputType) * nvec);
if (align == -1) {
align = new_align;
} else {
if (align != new_align) {
return Alignment::DIFFERENT;
}
}
} }
if ((align == 0) && if (alignments.front() == 0 &&
(lead_dim % nvec == 0)) { lead_dim % nvec == 0) {
// all alignment are 0
return Alignment::SAME_ALIGNED; return Alignment::SAME_ALIGNED;
} else { } else {
return Alignment::SAME_UNALIGNED; return Alignment::SAME_UNALIGNED;
...@@ -341,6 +329,191 @@ void VectorizedUnaryKernelLauncher(const InputType *input, ...@@ -341,6 +329,191 @@ void VectorizedUnaryKernelLauncher(const InputType *input,
} }
} }
// Gated activation forward kernel: out[row][j] = Activation(in[row][j]) * in[row][j + n],
// where input rows are [act half | gate half] (width 2n) and output rows have width n.
// Loads/stores are vectorized nvec-wide; a grid-stride loop covers all row-vectors.
// For FP8 outputs the kernel additionally applies *scale, publishes 1/scale to
// scale_inv (thread 0 of block 0 only) and atomically reduces the global amax.
template <int nvec, bool aligned,
          typename ComputeType,
          ComputeType (*Activation)(ComputeType),
          typename InputType,
          typename OutputType>
__launch_bounds__(unary_kernel_threads)
__global__ void gated_act_kernel(const InputType *input,
                                 OutputType *output,
                                 const ComputeType *scale,
                                 ComputeType *scale_inv,
                                 ComputeType *amax,
                                 const size_t m,
                                 const size_t n,
                                 const size_t num_aligned_elements) {
  // Total work items: nvec-wide vectors per row times number of rows.
  const size_t M = num_aligned_elements * m;
  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
       tid < M;
       tid += gridDim.x * blockDim.x) {
    const size_t id_x = tid % num_aligned_elements;  // vector index within the row
    const size_t id_y = tid / num_aligned_elements;  // row index
    // Two loaders per input row (activation half, gate half), one storer per output row.
    VectorizedLoader<InputType, nvec, aligned> loader0(input + id_y * n * 2, n);
    VectorizedLoader<InputType, nvec, aligned> loader1(input + id_y * n * 2 + n, n);
    VectorizedStorer<OutputType, nvec, aligned> storer(output + id_y * n, n);
    ComputeType max = 0;
    ComputeType s = 0;
    if constexpr (is_fp8<OutputType>::value) {
      if (scale != nullptr) s = *scale;
      // Exactly one thread in the grid writes scale_inv = 1 / scale.
      if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
        reciprocal<ComputeType>(scale_inv, s);
      }
    }
    const int warp_id = threadIdx.x / THREADS_PER_WARP;
    loader0.load(id_x, n);
    loader1.load(id_x, n);
#pragma unroll
    for (int i = 0; i < nvec; ++i) {
      const ComputeType val = static_cast<ComputeType>(loader0.separate()[i]);
      const ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);
      ComputeType temp = static_cast<ComputeType>(Activation(val) * val2);
      if constexpr (is_fp8<OutputType>::value) {
        // amax is tracked on the unscaled value; scaling is applied afterwards.
        __builtin_assume(max >= 0);
        max = fmaxf(fabsf(temp), max);
        temp = temp * s;
      }
      storer.separate()[i] = static_cast<OutputType>(static_cast<ComputeType>(temp));
    }
    storer.store(id_x, n);
    if constexpr (is_fp8<OutputType>::value) {
      /* warp tile amax reduce*/
      max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
      if (threadIdx.x == 0 && amax != nullptr) {
        static_assert(std::is_same<ComputeType, float>::value);  // atomicMaxFloat needs fp32
        atomicMaxFloat(amax, max);
      }
    }
  }
}
// Host launcher for gated_act_kernel.  Picks the vectorized-aligned,
// vectorized-unaligned, or scalar (nvec = 1) instantiation based on pointer
// alignment of both input halves and the output.
// m: number of rows; n: output columns (each input row is 2n wide).
template <int nvec,
          typename ComputeType,
          ComputeType (*Activation)(ComputeType),
          typename InputType,
          typename OutputType>
void GatedActivationKernelLauncher(const InputType *input,
                                   OutputType *output,
                                   const fp32 *scale,
                                   fp32 *scale_inv,
                                   fp32 *amax,
                                   const size_t m,
                                   const size_t n,
                                   cudaStream_t stream) {
  if (m != 0 && n != 0) {  // empty problem: nothing to launch
    size_t num_aligned_elements = get_num_aligned_elements(input, n, nvec, sizeof(InputType));
    constexpr size_t threads = unary_kernel_threads;
    size_t num_blocks = DIVUP(num_aligned_elements * m, threads);
    // Grid-stride loop in the kernel handles work beyond the block cap.
    constexpr size_t max_blocks = 65535;
    num_blocks = std::min(num_blocks, max_blocks);
    // Check input, the second (gate) half at input + n, and output together.
    switch (auto align = CheckAlignment(n, nvec, input, input + n, output)) {
      case Alignment::SAME_ALIGNED:
        gated_act_kernel<nvec, true, ComputeType, Activation><<<num_blocks, threads, 0, stream>>>(
            input, output, scale, scale_inv, amax, m, n, num_aligned_elements);
        break;
      case Alignment::SAME_UNALIGNED:
        gated_act_kernel<nvec, false, ComputeType, Activation><<<num_blocks, threads, 0, stream>>>(
            input, output, scale, scale_inv, amax, m, n, num_aligned_elements);
        break;
      case Alignment::DIFFERENT: {
        // If the pointers are aligned differently we cannot vectorize
        gated_act_kernel<1, true, ComputeType, Activation><<<num_blocks, threads, 0, stream>>>(
            input, output, scale, scale_inv, amax, m, n, n);
        break;
      }
    }
  }
}
// Gated activation backward kernel.  For each output row (width 2n):
//   out[j]     = Dactivation(act_in) * grad * gate_in   (gradient w.r.t. act half)
//   out[j + n] = grad * Activation(act_in)              (gradient w.r.t. gate half)
// grad rows have width n; input/output rows pack both halves (width 2n).
// Loads/stores are vectorized nvec-wide; a grid-stride loop covers all row-vectors.
template <int nvec, bool aligned,
          typename ComputeType,
          ComputeType (*Activation)(ComputeType),
          ComputeType (*Dactivation)(ComputeType),
          typename InputType,
          typename OutputType>
__launch_bounds__(unary_kernel_threads)
__global__ void dgated_act_kernel(const InputType *grad,
                                  const InputType *input,
                                  OutputType *output,
                                  const size_t m,
                                  const size_t n,
                                  const size_t num_aligned_elements) {
  // Total work items: nvec-wide vectors per row times number of rows.
  const size_t M = num_aligned_elements * m;
  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
       tid < M;
       tid += gridDim.x * blockDim.x) {
    const size_t id_x = tid % num_aligned_elements;  // vector index within the row
    const size_t id_y = tid / num_aligned_elements;  // row index
    VectorizedLoader<InputType, nvec, aligned> grad_loader(grad + id_y * n, n);
    // Forward-input halves and the matching output halves.
    VectorizedLoader<InputType, nvec, aligned> input_loader0(input + id_y * n * 2, n);
    VectorizedLoader<InputType, nvec, aligned> input_loader1(input + id_y * n * 2 + n, n);
    VectorizedStorer<OutputType, nvec, aligned> storer0(output + id_y * n * 2, n);
    VectorizedStorer<OutputType, nvec, aligned> storer1(output + id_y * n * 2 + n, n);
    grad_loader.load(id_x, n);
    input_loader0.load(id_x, n);
    input_loader1.load(id_x, n);
#pragma unroll
    for (int i = 0; i < nvec; ++i) {
      const ComputeType grad_val = static_cast<ComputeType>(grad_loader.separate()[i]);
      const ComputeType gelu_in = static_cast<ComputeType>(input_loader0.separate()[i]);
      const ComputeType gate_in = static_cast<ComputeType>(input_loader1.separate()[i]);
      ComputeType after_dgelu = Dactivation(gelu_in) * grad_val * gate_in;
      ComputeType after_dgate = grad_val * Activation(gelu_in);
      storer0.separate()[i] = static_cast<OutputType>(after_dgelu);
      storer1.separate()[i] = static_cast<OutputType>(after_dgate);
    }
    storer0.store(id_x, n);
    storer1.store(id_x, n);
  }
}
// Host launcher for dgated_act_kernel.  Picks the vectorized-aligned,
// vectorized-unaligned, or scalar (nvec = 1) instantiation based on pointer
// alignment.  m: number of rows; n: grad columns (input/output rows are 2n wide).
template <int nvec,
          typename ComputeType,
          ComputeType (*Activation)(ComputeType),
          ComputeType (*Dactivation)(ComputeType),
          typename InputType,
          typename OutputType>
void DGatedActivationKernelLauncher(const InputType *grad,
                                    const InputType *input,
                                    OutputType *output,
                                    const size_t m,
                                    const size_t n,
                                    cudaStream_t stream) {
  if (m != 0 && n != 0) {  // empty problem: nothing to launch
    size_t num_aligned_elements = get_num_aligned_elements(grad, n, nvec,
                                                           sizeof(InputType));
    constexpr size_t threads = unary_kernel_threads;
    size_t num_blocks = DIVUP(num_aligned_elements * m, threads);
    // Grid-stride loop in the kernel handles work beyond the block cap.
    constexpr size_t max_blocks = 65535;
    num_blocks = std::min(num_blocks, max_blocks);
    // NOTE(review): grad is not included in the alignment check even though
    // grad_loader vectorizes over it — this appears to rely on grad sharing
    // the allocator's base alignment with input/output; confirm or add grad.
    switch (auto align = CheckAlignment(n, nvec, input, input + n, output, output + n)) {
      case Alignment::SAME_ALIGNED:
        dgated_act_kernel<nvec, true, ComputeType, Activation, Dactivation>
          <<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, num_aligned_elements);
        break;
      case Alignment::SAME_UNALIGNED:
        dgated_act_kernel<nvec, false, ComputeType, Activation, Dactivation>
          <<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, num_aligned_elements);
        break;
      case Alignment::DIFFERENT: {
        // If the pointers are aligned differently we cannot vectorize
        dgated_act_kernel<1, true, ComputeType, Activation, Dactivation>
          <<<num_blocks, threads, 0, stream>>>(grad, input, output, m, n, n);
        break;
      }
    }
  }
}
} // namespace transformer_engine } // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_VECTORIZED_POINTWISE_H_ #endif // TRANSFORMER_ENGINE_COMMON_UTIL_VECTORIZED_POINTWISE_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment