Unverified Commit 91a16a3f authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

Add more C++ tests for activations (#1049)



* Added tests for silu/relu/swiglu/reglu
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fixes
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Added other activations/backwards and fixed dqgelu
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix 2
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Actually adding srelu and qgelu tests
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix glu backward test
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Pruning unnecessary test configurations
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
parent e113bf84
...@@ -9,8 +9,7 @@ add_executable(test_operator ...@@ -9,8 +9,7 @@ add_executable(test_operator
test_cast_transpose_dbias.cu test_cast_transpose_dbias.cu
test_cast_transpose_dbias_dgelu.cu test_cast_transpose_dbias_dgelu.cu
test_cast_transpose_dgeglu.cu test_cast_transpose_dgeglu.cu
test_gelu.cu test_act.cu
test_geglu.cu
test_dgeglu.cu test_dgeglu.cu
test_layernorm.cu test_layernorm.cu
test_rmsnorm.cu test_rmsnorm.cu
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <type_traits>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {

// fp32 CPU references for the activations under test. Forward formulas
// mirror the device kernels (tanh-approximated GELU, sigmoid-based
// SiLU/QGELU); each d* function is the analytic derivative of its
// forward counterpart.

// --------------------------------- forward ---------------------------------

// Tanh-approximated GELU: 0.5x(1 + tanh(0.79788456 x (1 + 0.044715 x^2))).
float gelu(const float x) {
  const float inner = 1.0f + 0.044715f * x * x;
  return 0.5f * x * (1.0f + tanhf(0.79788456F * x * inner));
}

// SiLU (a.k.a. swish): x * sigmoid(x).
float silu(const float x) {
  const float denom = 1 + expf(-x);
  return x / denom;
}

// Standard ReLU.
float relu(const float x) {
  if (x > 0) {
    return x;
  }
  return 0;
}

// Squared ReLU: x^2 for positive inputs, 0 otherwise.
float srelu(const float x) {
  if (x > 0) {
    return x * x;
  }
  return 0;
}

// "Quick" GELU: x * sigmoid(1.702 x).
float qgelu(const float x) {
  const float denom = 1 + expf(-1.702f * x);
  return x / denom;
}

// --------------------------------- backward --------------------------------

// Derivative of the tanh-approximated GELU.
float dgelu(const float x) {
  const float tanh_out = tanhf(0.79788456f * x * (1.f + 0.044715f * x * x));
  const float sech_sq = 1.f - tanh_out * tanh_out;
  return 0.5f * x * (sech_sq * (0.79788456f + 0.1070322243f * x * x)) +
         0.5f * (1.f + tanh_out);
}

// Derivative of SiLU.
float dsilu(const float x) {
  const float sig = 1.f / (1 + expf(-x));
  return x * sig * (1.f - sig) + sig;
}

// Derivative of ReLU (subgradient 0 at x == 0).
float drelu(const float x) {
  if (x > 0.f) {
    return 1.f;
  }
  return 0.f;
}

// Derivative of squared ReLU: 2x for positive inputs, 0 otherwise.
float dsrelu(const float x) {
  return fmaxf(2.f * x, 0.f);
}

// Derivative of quick GELU.
float dqgelu(const float x) {
  const float sig = 1.f / (1 + expf(-1.702f * x));
  return 1.702f * x * sig * (1.f - sig) + sig;
}

}  // namespace
// Applies act() elementwise to an N x H row-major buffer, writes
// scale * act(x) into output_h, and records the absolute maximum of the
// *unscaled* activation in *amax_h. All math is done in CT precision.
template <float (*act)(const float), typename IT, typename OT, typename CT>
void compute_ref_act_cast(const IT *input_h,
                          OT *output_h,
                          const CT scale,
                          CT *amax_h,
                          const size_t N,
                          const size_t H) {
  CT current_max = 0.;
  for (size_t row = 0; row < N; ++row) {
    for (size_t col = 0; col < H; ++col) {
      const size_t idx = row * H + col;
      const CT val = act(static_cast<CT>(input_h[idx]));
      output_h[idx] = static_cast<OT>(scale * val);
      const CT mag = std::abs(val);
      if (mag > current_max) {
        current_max = mag;
      }
    }
  }
  *amax_h = current_max;
}
// CPU reference for elementwise activation backward:
// output = grad * dact(input), computed in fp32 and cast to OT.
template <float (*dact)(const float), typename IT, typename OT>
void compute_ref_dact_cast(const IT *input_h,
                           const IT *grad_h,
                           OT *output_h,
                           const size_t N,
                           const size_t H) {
  using CT = float;
  for (size_t row = 0; row < N; ++row) {
    for (size_t col = 0; col < H; ++col) {
      const size_t idx = row * H + col;
      const CT deriv = dact(static_cast<CT>(input_h[idx]));
      const CT grad = static_cast<CT>(grad_h[idx]);
      output_h[idx] = static_cast<OT>(grad * deriv);
    }
  }
}
// CPU reference for GLU-style activation forward. Each input row is
// 2*H wide: the first H columns go through act(), the second H columns
// act as the gate. output = scale * act(x) * gate (N x H); the absolute
// maximum of the *unscaled* product is written to *amax_h.
template <float (*act)(const float), typename IT, typename OT, typename CT>
void compute_ref_glu_act_cast(const IT *input_h, OT *output_h, const CT scale, CT *amax_h,
                              const size_t N, const size_t H) {
  CT amax = 0.;
  // Fixed: the row stride was `const int col = H * 2`, which narrows a
  // size_t and can overflow for very large H.
  const size_t col = H * 2;
  for (size_t i = 0; i < N; i++) {
    for (size_t j = 0; j < H; j++) {
      CT act_elt = static_cast<CT>(input_h[i * col + j]);
      act_elt = act(act_elt);
      CT gate_elt = static_cast<CT>(input_h[i * col + H + j]);
      CT elt = act_elt * gate_elt;
      output_h[i * H + j] = static_cast<OT>(scale * elt);
      amax = std::abs(elt) > amax ? std::abs(elt) : amax;
    }
  }
  *amax_h = amax;
}
// CPU reference for GLU-style activation backward. Input rows are 2*H
// wide: [x | gate]. The output has the same 2*H layout:
//   d(x)    at column j     = grad * dact(x) * gate
//   d(gate) at column H + j = grad * act(x)
template <float (*dact)(const float), float (*act)(const float),
          typename IT, typename OT>
void compute_ref_dglu_act_cast(const IT *input_h, const IT *grad_h, OT *output_h,
                               const size_t N, const size_t H) {
  // Fixed: was `const int col = H * 2` — narrowing a size_t into int can
  // overflow for very large H.
  const size_t col = H * 2;
  using CT = float;
  for (size_t i = 0; i < N; i++) {
    for (size_t j = 0; j < H; j++) {
      CT grad = static_cast<CT>(grad_h[i * H + j]);
      CT act_elt = static_cast<CT>(input_h[i * col + j]);
      CT gate_elt = static_cast<CT>(input_h[i * col + H + j]);
      // Gradient w.r.t. the gate half.
      output_h[i * col + H + j] = static_cast<OT>(grad * act(act_elt));
      // Gradient w.r.t. the activation half.
      act_elt = dact(act_elt);
      CT elt = act_elt * gate_elt;
      output_h[i * col + j] = static_cast<OT>(grad * elt);
    }
  }
}
// Runs one forward + backward elementwise activation test.
//
// Template parameters:
//   ref_act / ref_dact   - fp32 CPU reference forward/backward functions.
//   nvte_act / nvte_dact - Transformer Engine device kernels under test.
//   IType / OType        - input and (possibly FP8) output element types.
//
// Device results (and, for FP8 outputs, the recorded amax) are compared
// against the CPU reference within per-dtype tolerances.
template <float (*ref_act)(const float),
          float (*ref_dact)(const float),
          void (*nvte_act)(const NVTETensor, NVTETensor, cudaStream_t),
          void (*nvte_dact)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t),
          typename IType, typename OType>
void performTest(const size_t N, const size_t H) {
  using namespace test;

  DType itype = TypeInfo<IType>::dtype;
  DType otype = TypeInfo<OType>::dtype;
  Tensor input({ N, H }, itype);
  Tensor output({ N, H }, otype);
  Tensor igrad({ N, H }, itype);
  Tensor ograd({ N, H }, itype);

  fillUniform(&input);
  fillUniform(&ograd);
  setRandomScale(&output);

  std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N * H);
  std::unique_ptr<IType[]> ref_igrad = std::make_unique<IType[]>(N * H);

  // Forward: launch the device kernel on the default stream, then run
  // the CPU reference on the host copy of the input.
  nvte_act(input.data(), output.data(), 0);

  float ref_amax;
  compute_ref_act_cast<ref_act>(input.cpu_dptr<IType>(), ref_output.get(),
                                output.scale(), &ref_amax, N, H);

  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  // amax is only tracked for FP8 outputs.
  if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
    auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
    compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
  }
  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_act", output, ref_output.get(), atol, rtol);

  // Backward: dact consumes the incoming gradient and the forward input.
  nvte_dact(ograd.data(), input.data(), igrad.data(), 0);
  compute_ref_dact_cast<ref_dact>(input.cpu_dptr<IType>(), ograd.cpu_dptr<IType>(),
                                  ref_igrad.get(), N, H);

  cudaDeviceSynchronize();
  err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  {
    // igrad holds IType data, so compare with the *input* dtype's
    // tolerances. (The original used otype here, which is far too lax
    // when OType is FP8 and IType is a higher-precision type.)
    auto [atol, rtol] = getTolerances(itype);
    compareResults("igrad_act", igrad, ref_igrad.get(), atol, rtol);
  }
}
template <float (*ref_act)(const float),
float (*ref_dact)(const float),
void (*nvte_act)(const NVTETensor, NVTETensor, cudaStream_t),
void (*nvte_dact)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t),
typename IType, typename OType>
void performTestGLU(const size_t N, const size_t H) {
using namespace test;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
Tensor input({N, H * 2}, itype);
Tensor output({N, H}, otype);
Tensor igrad({ N, H * 2 }, itype);
Tensor ograd({ N, H }, itype);
fillUniform(&input);
fillUniform(&ograd);
setRandomScale(&output);
std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N * H);
std::unique_ptr<IType[]> ref_igrad = std::make_unique<IType[]>(2 * N * H);
nvte_act(input.data(), output.data(), 0);
float ref_amax;
compute_ref_glu_act_cast<ref_act>(input.cpu_dptr<IType>(), ref_output.get(),
output.scale(), &ref_amax, N, H);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_gelu", output, ref_output.get(), atol, rtol);
nvte_dact(ograd.data(), input.data(), igrad.data(), 0);
compute_ref_dglu_act_cast<ref_dact, ref_act>(input.cpu_dptr<IType>(), ograd.cpu_dptr<IType>(),
ref_igrad.get(), N, H);
cudaDeviceSynchronize();
err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
{
auto [atol, rtol] = getTolerances(otype);
compareResults("igrad_act", igrad, ref_igrad.get(), atol, rtol);
}
}
// Test suite parameterized over (input dtype, output dtype, (rows, cols)).
class ActTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                                transformer_engine::DType,
                                                                std::pair<size_t, size_t>>> {};
// GELU forward + backward over every parameterized dtype/shape combo.
TEST_P(ActTestSuite, TestGELU) {
  using namespace transformer_engine;
  using namespace test;
  const DType in_dtype = std::get<0>(GetParam());
  const DType out_dtype = std::get<1>(GetParam());
  const auto [rows, cols] = std::get<2>(GetParam());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(in_dtype, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(out_dtype, OutputType,
      performTest<gelu, dgelu, nvte_gelu, nvte_dgelu,
                  InputType, OutputType>(rows, cols);
    );
  );
}
// SiLU forward + backward over every parameterized dtype/shape combo.
TEST_P(ActTestSuite, TestSILU) {
  using namespace transformer_engine;
  using namespace test;
  const DType in_dtype = std::get<0>(GetParam());
  const DType out_dtype = std::get<1>(GetParam());
  const auto [rows, cols] = std::get<2>(GetParam());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(in_dtype, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(out_dtype, OutputType,
      performTest<silu, dsilu, nvte_silu, nvte_dsilu,
                  InputType, OutputType>(rows, cols);
    );
  );
}
// ReLU forward + backward over every parameterized dtype/shape combo.
TEST_P(ActTestSuite, TestRELU) {
  using namespace transformer_engine;
  using namespace test;
  const DType in_dtype = std::get<0>(GetParam());
  const DType out_dtype = std::get<1>(GetParam());
  const auto [rows, cols] = std::get<2>(GetParam());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(in_dtype, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(out_dtype, OutputType,
      performTest<relu, drelu, nvte_relu, nvte_drelu,
                  InputType, OutputType>(rows, cols);
    );
  );
}
// Quick-GELU forward + backward over every parameterized dtype/shape combo.
TEST_P(ActTestSuite, TestQGELU) {
  using namespace transformer_engine;
  using namespace test;
  const DType in_dtype = std::get<0>(GetParam());
  const DType out_dtype = std::get<1>(GetParam());
  const auto [rows, cols] = std::get<2>(GetParam());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(in_dtype, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(out_dtype, OutputType,
      performTest<qgelu, dqgelu, nvte_qgelu, nvte_dqgelu,
                  InputType, OutputType>(rows, cols);
    );
  );
}
// Squared-ReLU forward + backward over every parameterized dtype/shape combo.
TEST_P(ActTestSuite, TestSRELU) {
  using namespace transformer_engine;
  using namespace test;
  const DType in_dtype = std::get<0>(GetParam());
  const DType out_dtype = std::get<1>(GetParam());
  const auto [rows, cols] = std::get<2>(GetParam());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(in_dtype, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(out_dtype, OutputType,
      performTest<srelu, dsrelu, nvte_srelu, nvte_dsrelu,
                  InputType, OutputType>(rows, cols);
    );
  );
}
// GeGLU (GELU-gated linear unit) forward + backward.
TEST_P(ActTestSuite, TestGeGLU) {
  using namespace transformer_engine;
  using namespace test;
  const DType in_dtype = std::get<0>(GetParam());
  const DType out_dtype = std::get<1>(GetParam());
  const auto [rows, cols] = std::get<2>(GetParam());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(in_dtype, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(out_dtype, OutputType,
      performTestGLU<gelu, dgelu, nvte_geglu, nvte_dgeglu,
                     InputType, OutputType>(rows, cols);
    );
  );
}
// ReGLU (ReLU-gated linear unit) forward + backward.
TEST_P(ActTestSuite, TestReGLU) {
  using namespace transformer_engine;
  using namespace test;
  const DType in_dtype = std::get<0>(GetParam());
  const DType out_dtype = std::get<1>(GetParam());
  const auto [rows, cols] = std::get<2>(GetParam());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(in_dtype, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(out_dtype, OutputType,
      performTestGLU<relu, drelu, nvte_reglu, nvte_dreglu,
                     InputType, OutputType>(rows, cols);
    );
  );
}
// SwiGLU (SiLU-gated linear unit) forward + backward.
TEST_P(ActTestSuite, TestSwiGLU) {
  using namespace transformer_engine;
  using namespace test;
  const DType in_dtype = std::get<0>(GetParam());
  const DType out_dtype = std::get<1>(GetParam());
  const auto [rows, cols] = std::get<2>(GetParam());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(in_dtype, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(out_dtype, OutputType,
      performTestGLU<silu, dsilu, nvte_swiglu, nvte_dswiglu,
                     InputType, OutputType>(rows, cols);
    );
  );
}
// QGeGLU (quick-GELU-gated linear unit) forward + backward.
TEST_P(ActTestSuite, TestQGeGLU) {
  using namespace transformer_engine;
  using namespace test;
  const DType in_dtype = std::get<0>(GetParam());
  const DType out_dtype = std::get<1>(GetParam());
  const auto [rows, cols] = std::get<2>(GetParam());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(in_dtype, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(out_dtype, OutputType,
      performTestGLU<qgelu, dqgelu, nvte_qgeglu, nvte_dqgeglu,
                     InputType, OutputType>(rows, cols);
    );
  );
}
// SReGLU (squared-ReLU-gated linear unit) forward + backward.
TEST_P(ActTestSuite, TestSReGLU) {
  using namespace transformer_engine;
  using namespace test;
  const DType in_dtype = std::get<0>(GetParam());
  const DType out_dtype = std::get<1>(GetParam());
  const auto [rows, cols] = std::get<2>(GetParam());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(in_dtype, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(out_dtype, OutputType,
      performTestGLU<srelu, dsrelu, nvte_sreglu, nvte_dsreglu,
                     InputType, OutputType>(rows, cols);
    );
  );
}
namespace {
// (rows, cols) test shapes: a few large, model-realistic sizes plus small
// and odd sizes (e.g. 257x259, 128x129) to exercise kernel tail handling.
std::vector<std::pair<size_t, size_t>> act_test_cases = {{2048, 12288},
                                                         {768, 2816},
                                                         {256, 65536},
                                                         {65536, 128},
                                                         {256, 256},
                                                         {257, 259},
                                                         {128, 128+1}};
}  // namespace
// Instantiate the suite over all (input dtype, output dtype, shape)
// combinations; the lambda builds a test name from the dtype names and
// the two dimensions, joined with 'X'.
INSTANTIATE_TEST_SUITE_P(
    OperatorTest, ActTestSuite,
    ::testing::Combine(
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
        ::testing::ValuesIn(test::all_fp_types),
        ::testing::ValuesIn(act_test_cases)),
    [](const testing::TestParamInfo<ActTestSuite::ParamType> &info) {
      const auto &[itype, otype, shape] = info.param;
      return test::typeName(itype) + "X" + test::typeName(otype) + "X" +
             std::to_string(shape.first) + "X" + std::to_string(shape.second);
    });
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <type_traits>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include "../test_common.h"
using namespace transformer_engine;
// CPU reference for GeGLU forward. Each input row is 2*H wide: the first
// H columns go through tanh-approximated GELU, the second H columns act
// as the gate. output = scale * gelu(x) * gate (N x H); the absolute
// maximum of the *unscaled* product is written to *amax_h.
template <typename IT, typename OT, typename CT>
void compute_ref_geglu_cast(const IT *input_h, OT *output_h, const CT scale, CT *amax_h,
                            const size_t N, const size_t H) {
  CT amax = 0.;
  // Fixed: the row stride was `const int col = H * 2`, which narrows a
  // size_t and can overflow for very large H.
  const size_t col = H * 2;
  for (size_t i = 0; i < N; i++) {
    for (size_t j = 0; j < H; j++) {
      CT gelu_elt = CT(input_h[i * col + j]);
      gelu_elt = 0.5f * gelu_elt *
                 (1.0f + tanhf(0.79788456F * gelu_elt * (1.0f + 0.044715f * gelu_elt * gelu_elt)));
      CT gate_elt = CT(input_h[i * col + H + j]);
      CT elt = gelu_elt * gate_elt;
      output_h[i * H + j] = OT(scale * elt);
      amax = std::abs(elt) > amax ? std::abs(elt) : amax;
    }
  }
  *amax_h = amax;
}
template <typename IType, typename OType>
void performTestGEGLU(const size_t N, const size_t H) {
using namespace test;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
Tensor input({N, H * 2}, itype);
Tensor output({N, H}, otype);
fillUniform(&input);
setRandomScale(&output);
std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N * H);
nvte_geglu(input.data(), output.data(), 0);
float ref_amax;
compute_ref_geglu_cast(input.cpu_dptr<IType>(), ref_output.get(), output.scale(), &ref_amax, N,
H);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_gelu", output, ref_output.get(), atol, rtol);
}
// Test suite parameterized over (input dtype, output dtype, (rows, cols)).
class GeGLUTestSuite
    : public ::testing::TestWithParam<std::tuple<
          transformer_engine::DType, transformer_engine::DType, std::pair<size_t, size_t>>> {};
// Forward-only GeGLU check over every parameterized dtype/shape combo.
TEST_P(GeGLUTestSuite, TestGeGLU) {
  using namespace transformer_engine;
  using namespace test;
  const DType in_dtype = std::get<0>(GetParam());
  const DType out_dtype = std::get<1>(GetParam());
  const auto [rows, cols] = std::get<2>(GetParam());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
      in_dtype, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
          out_dtype, OutputType,
          performTestGEGLU<InputType, OutputType>(rows, cols);););
}
namespace {
// (rows, cols) shapes for GeGLU; cols is the *output* width (the input
// is rows x 2*cols). Includes odd sizes to exercise tail handling.
std::vector<std::pair<size_t, size_t>> test_cases = {
    {4096, 2048}, {768, 2816}, {256, 5120}, {128, 10240}, {256, 256}, {257, 259}, {128, 128 + 1}};
}  // namespace
// Instantiate over all (input dtype, output dtype, shape) combinations;
// the lambda builds a test name from dtype names and dimensions.
INSTANTIATE_TEST_SUITE_P(
    OperatorTest, GeGLUTestSuite,
    ::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
                       ::testing::ValuesIn(test::all_fp_types), ::testing::ValuesIn(test_cases)),
    [](const testing::TestParamInfo<GeGLUTestSuite::ParamType> &info) {
      const auto &[itype, otype, shape] = info.param;
      return test::typeName(itype) + "X" + test::typeName(otype) + "X" +
             std::to_string(shape.first) + "X" + std::to_string(shape.second);
    });
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>
#include <type_traits>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include "../test_common.h"
using namespace transformer_engine;
// CPU reference: applies tanh-approximated GELU elementwise to an N x H
// row-major buffer, writes scale * gelu(x) into output_h, and records
// the absolute maximum of the *unscaled* activation in *amax_h.
template <typename IT, typename OT, typename CT>
void compute_ref_gelu_cast(const IT *input_h,
                           OT *output_h,
                           const CT scale,
                           CT *amax_h,
                           const size_t N,
                           const size_t H) {
  CT amax = 0.;
  for (size_t row = 0; row < N; ++row) {
    for (size_t col = 0; col < H; ++col) {
      const size_t idx = row * H + col;
      CT elt = CT(input_h[idx]);
      // 0.5x(1 + tanh(0.79788456 x (1 + 0.044715 x^2))), same grouping
      // as the device kernel.
      elt = 0.5f * elt * (1.0f + tanhf(0.79788456F * elt *
                                       (1.0f + 0.044715f * elt * elt)));
      output_h[idx] = OT(scale * elt);
      const CT mag = std::abs(elt);
      if (mag > amax) {
        amax = mag;
      }
    }
  }
  *amax_h = amax;
}
template <typename IType, typename OType>
void performTestGelu(const size_t N, const size_t H) {
using namespace test;
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;
Tensor input({ N, H }, itype);
Tensor output({ N, H }, otype);
fillUniform(&input);
setRandomScale(&output);
std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N*H);
nvte_gelu(input.data(), output.data(), 0);
float ref_amax;
compute_ref_gelu_cast(input.cpu_dptr<IType>(), ref_output.get(),
output.scale(), &ref_amax, N, H);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
}
auto [atol, rtol] = getTolerances(otype);
compareResults("output_gelu", output, ref_output.get(), atol, rtol);
}
// Test suite parameterized over (input dtype, output dtype, (rows, cols)).
class GELUTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                                 transformer_engine::DType,
                                                                 std::pair<size_t, size_t>>> {};
// Forward-only GELU check over every parameterized dtype/shape combo.
TEST_P(GELUTestSuite, TestGELU) {
  using namespace transformer_engine;
  using namespace test;
  const DType in_dtype = std::get<0>(GetParam());
  const DType out_dtype = std::get<1>(GetParam());
  const auto [rows, cols] = std::get<2>(GetParam());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(in_dtype, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(out_dtype, OutputType,
      performTestGelu<InputType, OutputType>(rows, cols);
    );
  );
}
namespace {
// (rows, cols) test shapes: large, model-realistic sizes plus small and
// odd sizes (e.g. 257x259, 128x129) to exercise kernel tail handling.
std::vector<std::pair<size_t, size_t>> gelu_test_cases = {{2048, 12288},
                                                          {768, 1024},
                                                          {256, 65536},
                                                          {65536, 128},
                                                          {256, 256},
                                                          {257, 259},
                                                          {128, 128+1}};
}  // namespace
// Instantiate over all (input dtype, output dtype, shape) combinations;
// the lambda builds a test name from dtype names and dimensions.
INSTANTIATE_TEST_SUITE_P(
    OperatorTest, GELUTestSuite,
    ::testing::Combine(
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
        ::testing::ValuesIn(test::all_fp_types),
        ::testing::ValuesIn(gelu_test_cases)),
    [](const testing::TestParamInfo<GELUTestSuite::ParamType> &info) {
      const auto &[itype, otype, shape] = info.param;
      return test::typeName(itype) + "X" + test::typeName(otype) + "X" +
             std::to_string(shape.first) + "X" + std::to_string(shape.second);
    });
...@@ -47,7 +47,8 @@ __device__ inline OType qgelu(const IType val, const Empty& e) { ...@@ -47,7 +47,8 @@ __device__ inline OType qgelu(const IType val, const Empty& e) {
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType dqgelu(const IType val, const Empty& e) { __device__ inline OType dqgelu(const IType val, const Empty& e) {
const float cval = val; const float cval = val;
return cval * dsigmoid<float, float>(1.702f * cval, e) + sigmoid<float, float>(1.702f * cval, e); return 1.702f * cval * dsigmoid<float, float>(1.702f * cval, e) +
sigmoid<float, float>(1.702f * cval, e);
} }
template <typename OType, typename IType> template <typename OType, typename IType>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment