/************************************************************************* * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #include #include #include #include #include #include #include #include #include #include #include #include "../test_common.h" using namespace transformer_engine; namespace { // forward float gelu(const float x) { return 0.5f * x * (1.0f + tanhf(0.79788456F * x * (1.0f + 0.044715f * x * x))); } float silu(const float x) { return x / (1 + expf(-x)); } float relu(const float x) { return x > 0 ? x : 0; } float srelu(const float x) { return x > 0 ? x * x : 0; } float qgelu(const float x) { return x / (1 + expf(-1.702f * x)); } // backward float dgelu(const float x) { const float tanh_out = tanhf(0.79788456f * x * (1.f + 0.044715f * x * x)); return 0.5f * x * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * x * x)) + 0.5f * (1.f + tanh_out); } float dsilu(const float x) { const float sigmoid = 1.f / (1 + expf(-x)); return x * sigmoid * (1.f - sigmoid) + sigmoid; } float drelu(const float x) { return x > 0.f ? 1.f : 0.f; } float dsrelu(const float x) { return fmaxf(2.f * x, 0.f); } float dqgelu(const float x) { const float sigmoid = 1.f / (1 + expf(-1.702f * x)); return 1.702f * x * sigmoid * (1.f - sigmoid) + sigmoid; } } // namespace template void compute_ref_act_cast(const IT *input_h, OT *output_h, const CT scale, CT *amax_h, const size_t N, const size_t H) { CT amax = 0.; for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < H; j++) { CT elt = static_cast(input_h[i * H + j]); elt = act(elt); output_h[i * H + j] = static_cast(scale * elt); amax = std::abs(elt) > amax ? std::abs(elt) : amax; } } *amax_h = amax; } template void compute_ref_dact_cast(const IT *input_h, const IT *grad_h, OT *output_h, const size_t N, const size_t H) { using CT = float; for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < H; j++) { CT elt = static_cast(input_h[i * H + j]); elt = dact(elt); CT grad = static_cast(grad_h[i * H + j]); output_h[i * H + j] = static_cast(grad * elt); } } } template void compute_ref_glu_act_cast(const IT *input_h, OT *output_h, const CT scale, CT *amax_h, const size_t N, const size_t H) { CT amax = 0.; const int col = H * 2; for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < H; j++) { CT gelu_elt = static_cast(input_h[i * col + j]); gelu_elt = act(gelu_elt); CT gate_elt = static_cast(input_h[i * col + H + j]); CT elt = gelu_elt * gate_elt; output_h[i * H + j] = static_cast(scale * elt); amax = std::abs(elt) > amax ? std::abs(elt) : amax; } } *amax_h = amax; } template void compute_ref_dglu_act_cast(const IT *input_h, const IT *grad_h, OT *output_h, const size_t N, const size_t H) { const int col = H * 2; using CT = float; for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < H; j++) { CT grad = static_cast(grad_h[i * H + j]); CT gelu_elt = static_cast(input_h[i * col + j]); CT gate_elt = static_cast(input_h[i * col + H + j]); output_h[i * col + H + j] = static_cast(grad * act(gelu_elt)); gelu_elt = dact(gelu_elt); CT elt = gelu_elt * gate_elt; output_h[i * col + j] = static_cast(grad * elt); } } } template void performTest(const size_t N, const size_t H) { using namespace test; DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; Tensor input({ N, H }, itype); Tensor output({ N, H }, otype); Tensor igrad({ N, H }, itype); Tensor ograd({ N, H }, itype); fillUniform(&input); fillUniform(&ograd); setRandomScale(&output); std::unique_ptr ref_output = std::make_unique(N*H); std::unique_ptr ref_igrad = std::make_unique(N*H); nvte_act(input.data(), output.data(), 0); float ref_amax; compute_ref_act_cast(input.cpu_dptr(), ref_output.get(), output.scale(), &ref_amax, N, H); cudaDeviceSynchronize(); auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); compareResults("output_act", output, ref_output.get(), atol, rtol); nvte_dact(ograd.data(), input.data(), igrad.data(), 0); compute_ref_dact_cast(input.cpu_dptr(), ograd.cpu_dptr(), ref_igrad.get(), N, H); cudaDeviceSynchronize(); err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); { auto [atol, rtol] = getTolerances(otype); compareResults("igrad_act", igrad, ref_igrad.get(), atol, rtol); } } template void performTestGLU(const size_t N, const size_t H) { using namespace test; DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; Tensor input({N, H * 2}, itype); Tensor output({N, H}, otype); Tensor igrad({ N, H * 2 }, itype); Tensor ograd({ N, H }, itype); fillUniform(&input); fillUniform(&ograd); setRandomScale(&output); std::unique_ptr ref_output = std::make_unique(N * H); std::unique_ptr ref_igrad = std::make_unique(2 * N * H); nvte_act(input.data(), output.data(), 0); float ref_amax; compute_ref_glu_act_cast(input.cpu_dptr(), ref_output.get(), output.scale(), &ref_amax, N, H); cudaDeviceSynchronize(); auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); compareResults("output_gelu", output, ref_output.get(), atol, rtol); nvte_dact(ograd.data(), input.data(), igrad.data(), 0); compute_ref_dglu_act_cast(input.cpu_dptr(), ograd.cpu_dptr(), ref_igrad.get(), N, H); cudaDeviceSynchronize(); err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); { auto [atol, rtol] = getTolerances(otype); compareResults("igrad_act", igrad, ref_igrad.get(), atol, rtol); } } class ActTestSuite : public ::testing::TestWithParam>> {}; TEST_P(ActTestSuite, TestGELU) { using namespace transformer_engine; using namespace test; const DType input_type = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); const auto size = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, performTest(size.first, size.second); ); ); } TEST_P(ActTestSuite, TestSILU) { using namespace transformer_engine; using namespace test; const DType input_type = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); const auto size = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, performTest(size.first, size.second); ); ); } TEST_P(ActTestSuite, TestRELU) { using namespace transformer_engine; using namespace test; const DType input_type = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); const auto size = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, performTest(size.first, size.second); ); ); } TEST_P(ActTestSuite, TestQGELU) { using namespace transformer_engine; using namespace test; const DType input_type = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); const auto size = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, performTest(size.first, size.second); ); ); } TEST_P(ActTestSuite, TestSRELU) { using namespace transformer_engine; using namespace test; const DType input_type = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); const auto size = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, performTest(size.first, size.second); ); ); } TEST_P(ActTestSuite, TestGeGLU) { using namespace transformer_engine; using namespace test; const DType input_type = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); const auto size = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( output_type, OutputType, performTestGLU(size.first, size.second););); } TEST_P(ActTestSuite, TestReGLU) { using namespace transformer_engine; using namespace test; const DType input_type = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); const auto size = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( output_type, OutputType, performTestGLU(size.first, size.second););); } TEST_P(ActTestSuite, TestSwiGLU) { using namespace transformer_engine; using namespace test; const DType input_type = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); const auto size = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( output_type, OutputType, performTestGLU(size.first, size.second););); } TEST_P(ActTestSuite, TestQGeGLU) { using namespace transformer_engine; using namespace test; const DType input_type = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); const auto size = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( output_type, OutputType, performTestGLU(size.first, size.second););); } TEST_P(ActTestSuite, TestSReGLU) { using namespace transformer_engine; using namespace test; const DType input_type = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); const auto size = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( output_type, OutputType, performTestGLU(size.first, size.second););); } namespace { std::vector> act_test_cases = {{2048, 12288}, {768, 2816}, {256, 65536}, {65536, 128}, {256, 256}, {257, 259}, {128, 128+1}}; } // namespace INSTANTIATE_TEST_SUITE_P( OperatorTest, ActTestSuite, ::testing::Combine( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::ValuesIn(test::all_fp_types), ::testing::ValuesIn(act_test_cases)), [](const testing::TestParamInfo& info) { std::string name = test::typeName(std::get<0>(info.param)) + "X" + test::typeName(std::get<1>(info.param)) + "X" + std::to_string(std::get<2>(info.param).first) + "X" + std::to_string(std::get<2>(info.param).second); return name; });