test_dgeglu.cu 4.12 KB
Newer Older
1
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <type_traits>
Tim Moon's avatar
Tim Moon committed
14
15
16
17
18
19

#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <transformer_engine/activation.h>
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#include "../test_common.h"

using namespace transformer_engine;

namespace {

template <typename CType, typename IType>
inline CType gelu(const IType val) {
  CType cval = val;
  return cval * (0.5f + 0.5f * tanhf(cval * (0.79788456f + 0.03567741f * cval * cval)));
}

template <typename CType, typename IType>
inline CType dgelu(const IType val) {
  CType cval = val;
  const CType tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
  return 0.5f * cval * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * cval * cval)) +
         0.5f * (1.f + tanh_out);
}

template <typename IT, typename OT, typename CT>
void compute_ref_dgeglu(const IT *grad_h, const IT *input_h, OT *output_h, const size_t N,
                        const size_t H) {
  const size_t col = H * 2;

  for (size_t i = 0; i < N; i++) {
    for (size_t j = 0; j < H; j++) {
      CT grad_elt = CT(grad_h[i * H + j]);
      CT gelu_elt = CT(input_h[i * col + j]);
      CT gate_elt = CT(input_h[i * col + H + j]);

      CT after_dgelu = dgelu<CT, CT>(gelu_elt) * grad_elt * gate_elt;
      CT after_dgate = grad_elt * gelu<CT, CT>(gelu_elt);

      output_h[i * col + j] = OT(after_dgelu);
      output_h[i * col + H + j] = OT(after_dgate);
    }
  }
}

template <typename IType, typename OType>
void performTestDGeGLU(const size_t N, const size_t H) {
  using namespace test;

  using CType = fp32;

  DType itype = TypeInfo<IType>::dtype;
  DType otype = TypeInfo<OType>::dtype;

  Tensor grad({N, H}, itype);
  Tensor input({N, H * 2}, itype);
  Tensor output({N, H * 2}, otype);

  fillUniform(&grad);
  fillUniform(&input);

  std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N * H * 2);

  nvte_dgeglu(grad.data(), input.data(), output.data(), 0);

  compute_ref_dgeglu<IType, OType, CType>(grad.cpu_dptr<IType>(), input.cpu_dptr<IType>(),
                                          ref_output.get(), N, H);

  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_dgelu", output, ref_output.get(), atol, rtol);
}

std::vector<std::pair<size_t, size_t>> test_cases = {
    {4096, 2048}, {768, 2816}, {256, 5120}, {128, 10240}, {256, 256}, {257, 259}, {128, 128 + 1}};

}  // namespace

// Value-parameterized fixture for the dgeglu tests.
// Each parameter is a tuple of:
//   0: dtype of the gradient and input tensors,
//   1: dtype of the output tensor,
//   2: {N, H}, where the gradient is {N, H} and the input/output are {N, 2*H}.
class DGeGLUTestSuite
    : public ::testing::TestWithParam<
          std::tuple<transformer_engine::DType,  // input / gradient dtype
                     transformer_engine::DType,  // output dtype
                     std::pair<size_t, size_t>   // {N, H}
                     >> {};

// Maps the runtime (input dtype, output dtype) pair onto the matching
// performTestDGeGLU<InputType, OutputType> instantiation and runs it for {N, H}.
TEST_P(DGeGLUTestSuite, TestDGeGLU) {
  using namespace transformer_engine;
  using namespace test;

  const auto &param = GetParam();
  const DType input_type = std::get<0>(param);
  const DType output_type = std::get<1>(param);
  const size_t rows = std::get<2>(param).first;
  const size_t hidden = std::get<2>(param).second;

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
      input_type, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
          output_type, OutputType,
          performTestDGeGLU<InputType, OutputType>(rows, hidden);););
}

// Cartesian product of {fp32, bf16, fp16} input dtypes x {fp32, bf16, fp16}
// output dtypes x every shape in test_cases.
// Test names are "<in>X<out>X<N>X<H>".
INSTANTIATE_TEST_SUITE_P(
    OperatorTest, DGeGLUTestSuite,
    ::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
                       ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
                       ::testing::ValuesIn(test_cases)),
    [](const testing::TestParamInfo<DGeGLUTestSuite::ParamType> &info) {
      const auto &param = info.param;
      const auto &dims = std::get<2>(param);
      std::string label = test::typeName(std::get<0>(param));
      label += "X";
      label += test::typeName(std::get<1>(param));
      label += "X";
      label += std::to_string(dims.first);
      label += "X";
      label += std::to_string(dims.second);
      return label;
    });