/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <transformer_engine/rmsnorm.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"

using namespace transformer_engine;
using namespace test;

namespace {

// Effective scale applied to column j: gamma[j] converted to float, plus 1 when
// the zero-centered-gamma variant (nvte_rmsnorm1p_*) is under test.
template <typename WeightType>
float effective_gamma(const WeightType *gamma, const size_t j, const bool zero_centered_gamma) {
  const float g = static_cast<float>(gamma[j]);
  return zero_centered_gamma ? g + 1.f : g;
}

// CPU reference for the RMS statistics.
//
// `data` is indexed row-major as N rows of H elements. For each row i:
//   rsigma[i] = rsqrt(mean_j(data[i, j]^2) + epsilon)
// The sum of squares is accumulated in float.
template <typename InputType>
void compute_ref_stats(const InputType *data, float *rsigma, const size_t N, const size_t H,
                       const double epsilon) {
  using compute_t = float;
  for (size_t i = 0; i < N; ++i) {
    compute_t sum = 0;
    for (size_t j = 0; j < H; ++j) {
      const compute_t current = static_cast<compute_t>(data[i * H + j]);
      sum += current * current;
    }
    sum = sum / H;
    rsigma[i] = rsqrtf(sum + epsilon);
  }
}

// CPU reference for the forward output.
//
// For each element: tmp = data[i, j] * rsigma[i] * effective_gamma(j), and
// output[i, j] = OutputType(tmp * scale). *amax receives max |tmp|, i.e. the
// absolute maximum of the normalized values *before* `scale` is applied
// (0 when N * H == 0).
template <typename InputType, typename OutputType>
void compute_ref_output(const InputType *data, const InputType *gamma, OutputType *output,
                        const float *rsigma, const size_t N, const size_t H, float *amax,
                        float scale, const bool zero_centered_gamma) {
  using compute_t = float;
  // amax is a maximum of absolute values, so 0 is the identity element. A huge
  // negative double sentinel would not be representable as float (UB on conversion).
  compute_t current_max = 0.f;
  for (size_t i = 0; i < N; ++i) {
    for (size_t j = 0; j < H; ++j) {
      const compute_t current = static_cast<compute_t>(data[i * H + j]);
      const compute_t tmp = current * rsigma[i] * effective_gamma(gamma, j, zero_centered_gamma);
      output[i * H + j] = static_cast<OutputType>(tmp * scale);
      current_max = fmaxf(current_max, fabsf(tmp));
    }
  }
  *amax = current_max;
}

// CPU reference for the backward pass.
//
// With y = x * rsigma[i], dz = output_grad, and dy = effective_gamma(j) * dz:
//   data_grad[i, j] = rsigma[i] * (dy - mean_j(dy * y) * y)
//   gamma_grad[j]   = sum_i(y * dz)
// All accumulation is done in float; results are cast to InputType at the end.
template <typename InputType, typename OutputType>
void compute_ref_backward(const OutputType *output_grad, const InputType *data, const float *rsigma,
                          const InputType *gamma, InputType *data_grad, InputType *gamma_grad,
                          const size_t N, const size_t H, const bool zero_centered_gamma) {
  using compute_t = float;
  std::vector<compute_t> dgamma(H, 0.f);

  for (size_t i = 0; i < N; ++i) {
    // Row reduction: mdyy = mean_j(dy * y); also accumulate the dgamma partials.
    compute_t mdyy = 0;
    for (size_t j = 0; j < H; ++j) {
      const compute_t y = static_cast<compute_t>(data[i * H + j]) * rsigma[i];
      const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
      const compute_t dy = effective_gamma(gamma, j, zero_centered_gamma) * dz;
      dgamma[j] += y * dz;
      mdyy += dy * y;
    }
    mdyy /= H;

    // Input grads
    for (size_t j = 0; j < H; ++j) {
      const compute_t y = static_cast<compute_t>(data[i * H + j]) * rsigma[i];
      const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
      const compute_t dy = effective_gamma(gamma, j, zero_centered_gamma) * dz;
      const compute_t dx = rsigma[i] * (dy - mdyy * y);
      data_grad[i * H + j] = static_cast<InputType>(dx);
    }
  }

  // Weight grads
  for (size_t j = 0; j < H; ++j) {
    gamma_grad[j] = static_cast<InputType>(dgamma[j]);
  }
}

// Runs one RMSNorm forward + backward case on the GPU and compares the results
// against the CPU references above.
//
// N: number of rows. H: row length (hidden size).
// zero_centered_gamma: selects the nvte_rmsnorm1p_* entry points instead of
//   nvte_rmsnorm_*.
//
// Configurations the kernels do not support are skipped rather than failed:
// OutputType wider than InputType, and mixing Float16 with BFloat16.
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) {
  // GTEST_SKIP() returns from the enclosing void function by itself.
  if (sizeof(InputType) < sizeof(OutputType)) {
    GTEST_SKIP() << "RMSNorm kernel does not support OutputType > InputType";
  }
  using WeightType = InputType;
  const DType itype = TypeInfo<InputType>::dtype;
  const DType wtype = TypeInfo<WeightType>::dtype;
  const DType otype = TypeInfo<OutputType>::dtype;

  if ((itype == DType::kBFloat16 && otype == DType::kFloat16) ||
      (itype == DType::kFloat16 && otype == DType::kBFloat16)) {
    GTEST_SKIP() << "RMSNorm kernel does not support mixing Float16 and BFloat16";
  }

  Tensor input({N, H}, itype);
  Tensor z({N, H}, otype);
  Tensor gamma({H}, wtype);
  Tensor rsigma({N}, DType::kFloat32);
  Tensor dz({N, H}, wtype);
  Tensor dx({N, H}, itype);
  Tensor dgamma({H}, wtype);
  // Scratch tensors: each kernel is called once with these empty, they are then
  // re-created from the shape/dtype left in them, and the kernel is called again.
  // NOTE(review): presumably the first call only queries the required sizes — confirm.
  Tensor workspace, barrier, dgamma_part;

  fillUniform(&input);
  fillUniform(&gamma);
  fillUniform(&dz);
  setRandomScale(&z);

  std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N * H);
  std::unique_ptr<float[]> ref_rsigma = std::make_unique<float[]>(N);
  std::unique_ptr<InputType[]> ref_dx = std::make_unique<InputType[]>(N * H);
  std::unique_ptr<WeightType[]> ref_dgamma = std::make_unique<WeightType[]>(H);

  cudaDeviceProp prop;
  const cudaError_t prop_err = cudaGetDeviceProperties(&prop, 0);
  ASSERT_EQ(prop_err, cudaSuccess) << cudaGetErrorString(prop_err);

  // Forward kernel
  const float epsilon = 1e-5;
  auto fwd_function = zero_centered_gamma ? nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd;
  fwd_function(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0,
               prop.multiProcessorCount, workspace.data(), barrier.data());
  workspace = Tensor(workspace.shape(), workspace.dtype());
  barrier = Tensor(barrier.shape(), barrier.dtype());
  fwd_function(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0,
               prop.multiProcessorCount, workspace.data(), barrier.data());

  // Backward kernel
  auto bwd_function = zero_centered_gamma ? nvte_rmsnorm1p_bwd : nvte_rmsnorm_bwd;
  bwd_function(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(),
               dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(),
               barrier.data());
  workspace = Tensor(workspace.shape(), workspace.dtype());
  barrier = Tensor(barrier.shape(), barrier.dtype());
  dgamma_part = Tensor(dgamma_part.shape(), dgamma_part.dtype());
  bwd_function(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(),
               dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(),
               barrier.data());

  // Surface launch/execution errors before any GPU result is read back.
  cudaError_t err = cudaDeviceSynchronize();
  if (err == cudaSuccess) {
    err = cudaGetLastError();
  }
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  // Reference implementations
  // use the GPU stats to tighten the tolerances
  rsigma.to_cpu();
  compute_ref_stats(input.cpu_dptr<InputType>(), ref_rsigma.get(), N, H, epsilon);
  float ref_amax = 0.f;
  const float ref_scale = isFp8Type(otype) ? z.scale() : 1.f;
  compute_ref_output(input.cpu_dptr<InputType>(), gamma.cpu_dptr<WeightType>(), ref_output.get(),
                     rsigma.cpu_dptr<float>(), N, H, &ref_amax, ref_scale,
                     zero_centered_gamma);
  compute_ref_backward(dz.cpu_dptr<WeightType>(), input.cpu_dptr<InputType>(),
                       rsigma.cpu_dptr<float>(), gamma.cpu_dptr<WeightType>(), ref_dx.get(),
                       ref_dgamma.get(), N, H, zero_centered_gamma);

  auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
  if (isFp8Type(otype)) {
    compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
    const float ref_scale_inv = 1.f / z.scale();
    compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
  }

  auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
  rtol_stats = 5e-5;
  compareResults("rsigma", rsigma, ref_rsigma.get(), atol_stats, rtol_stats);

  auto [atol, rtol] = getTolerances(otype);
  atol = 1e-8;
  compareResults("output", z, ref_output.get(), atol, rtol);

  const double atol_bwd = 5e-6;
  const double rtol_bwd = 1e-4;
  compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd);
  compareResults("dgamma", dgamma, ref_dgamma.get(), atol_bwd, rtol_bwd);
}

// {N, H} problem sizes. Besides power-of-two shapes, includes prime-sized rows
// and hidden dims (comment gives the index of each prime) to hit tail handling.
const std::vector<std::pair<size_t, size_t>> test_cases = {
    {2048, 4096}, {768, 2048}, {256, 1024}, {128, 768}, {64, 512}, {173, 409},  // Primes 40, 80
    {71, 3571},                                                                 // Primes 20, 500
    {29, 17389}};                                                               // Primes 10, 2000

}  // namespace

// Parameter tuple for RMSNormTestSuite, in order:
//   (input dtype, output dtype, {N rows, H hidden size}, zero_centered_gamma).
using RMSNormTestParam = std::tuple<transformer_engine::DType,     // input type
                                    transformer_engine::DType,     // output type
                                    std::pair<size_t, size_t>,     // {N, H}
                                    bool>;                         // zero_centered_gamma

// Value-parameterized GoogleTest fixture for the RMSNorm forward/backward tests.
class RMSNormTestSuite : public ::testing::TestWithParam<RMSNormTestParam> {};
// Dispatches the runtime (input, output) dtype pair to the matching
// performTest<InputType, OutputType> instantiation for one {N, H} shape.
TEST_P(RMSNormTestSuite, TestRMSNorm) {
  using namespace transformer_engine;
  using namespace test;

  const auto &param = GetParam();
  const DType input_type = std::get<0>(param);
  const DType output_type = std::get<1>(param);
  const std::pair<size_t, size_t> dims = std::get<2>(param);
  const bool zero_centered_gamma = std::get<3>(param);

  const size_t rows = dims.first;
  const size_t hidden = dims.second;

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
      performTest<InputType, OutputType>(rows, hidden, zero_centered_gamma);););
}

// Cartesian product of input dtypes x output dtypes x {N, H} shapes x
// zero_centered_gamma. Test names look like "<in>X<out>X<N>X<H>X<0|1>".
// NOTE: keep the name lambda free of top-level commas; gtest splits this
// macro's variadic arguments on commas.
INSTANTIATE_TEST_SUITE_P(
    OperatorTest, RMSNormTestSuite,
    ::testing::Combine(
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16,
                          DType::kFloat8E4M3),
        ::testing::ValuesIn(test_cases),
        ::testing::Values(false, true)),
    [](const testing::TestParamInfo<RMSNormTestSuite::ParamType> &info) {
      std::string label = test::typeName(std::get<0>(info.param));
      label += "X" + test::typeName(std::get<1>(info.param));
      label += "X" + std::to_string(std::get<2>(info.param).first);
      label += "X" + std::to_string(std::get<2>(info.param).second);
      label += "X" + std::to_string(std::get<3>(info.param));
      return label;
    });