/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
#include "test_normalization.h"

using namespace transformer_engine;
using namespace test;

namespace {

template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
31
                 NormType norm_type, bool use_cudnn, const bool zero_centered_gamma_in_weight_dtype) {
32
33
34
35
  if (sizeof(InputType) < sizeof(OutputType)) {
    GTEST_SKIP() << "LN kernel does not support OutputType > InputType";
    return;
  }
36
37
38
39
40

  if (getDeviceComputeCapability() < blackwellComputeCapability && use_cudnn) {
    GTEST_SKIP() << "cuDNN normalizations not supported on pre-Blackwell GPUs yet!";
  }

41
42
43
44
45
46
47
48
49
50
51
  using WeightType = InputType;
  DType itype = TypeInfo<InputType>::dtype;
  DType wtype = TypeInfo<WeightType>::dtype;
  DType otype = TypeInfo<OutputType>::dtype;

  if ((itype == DType::kBFloat16 && otype == DType::kFloat16) ||
      (itype == DType::kFloat16 && otype == DType::kBFloat16)) {
    GTEST_SKIP() << "LN kernel does not support mixing Float16 and BFloat16";
    return;
  }

52
53
54
55
56
57
58
59
60
61
  Tensor input("input", { N, H }, itype);
  Tensor z("z", { N, H }, otype);
  Tensor gamma("gamma", { H }, wtype);
  Tensor beta("beta", { H }, wtype);
  Tensor mu("mu", { N }, DType::kFloat32);
  Tensor rsigma("rsigma", { N }, DType::kFloat32);
  Tensor dz("dz", { N, H }, wtype);
  Tensor dx("dx", { N, H }, itype);
  Tensor dgamma("dgamma", { H }, wtype);
  Tensor dbeta("dbeta", { H }, wtype);
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
  Tensor workspace_fwd, workspace_bwd;

  fillUniform(&input);
  fillUniform(&gamma);
  fillUniform(&beta);
  setRandomScale(&z);
  fillUniform(&dz);

  std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N * H);
  std::unique_ptr<float[]> ref_mu = std::make_unique<float[]>(N);
  std::unique_ptr<float[]> ref_rsigma = std::make_unique<float[]>(N);
  std::unique_ptr<InputType[]> ref_dx = std::make_unique<InputType[]>(N * H);
  std::unique_ptr<WeightType[]> ref_dgamma = std::make_unique<InputType[]>(H);
  std::unique_ptr<WeightType[]> ref_dbeta = std::make_unique<InputType[]>(H);

  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);

80
81
82
83
84
  if ((!use_cudnn || !zero_centered_gamma) && zero_centered_gamma_in_weight_dtype) {
    // Skip duplicate tests when zero_centered_gamma_in_weight_dtype is true and won't affect the implementation
    GTEST_SKIP() << "Zero-centered gamma in weight dtype is only supported with cuDNN backend";
  }

85
86
87
  if (use_cudnn){
    nvte_enable_cudnn_norm_fwd(true);
    nvte_enable_cudnn_norm_bwd(true);
88
89
90
91
92
93
94
95


    // Zero-centered gamma in weight dtype only supported by CuDNN backend currently
    if (zero_centered_gamma_in_weight_dtype) {
      nvte_enable_zero_centered_gamma_in_weight_dtype(true);
    } else {
      nvte_enable_zero_centered_gamma_in_weight_dtype(false);
    }
96
97
98
99
100
101
102
103
  }

  // Forward kernel
  float epsilon = 1e-5;
  if (norm_type == LayerNorm){
    nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
                       z.data(), mu.data(), rsigma.data(), workspace_fwd.data(),
                       prop.multiProcessorCount, zero_centered_gamma, 0);
104
    workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype());
105
106
107
108
109
110
111
112
113
    nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
                       z.data(), mu.data(), rsigma.data(), workspace_fwd.data(),
                       prop.multiProcessorCount, zero_centered_gamma, 0);

    nvte_layernorm_bwd(dz.data(), input.data(),
                       mu.data(), rsigma.data(), gamma.data(),
                       dx.data(), dgamma.data(), dbeta.data(),
                       workspace_bwd.data(),
                       prop.multiProcessorCount, zero_centered_gamma, 0);
114
    workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype());
115
116
117
118
119
120
121
122
123
    nvte_layernorm_bwd(dz.data(), input.data(),
                       mu.data(), rsigma.data(), gamma.data(),
                       dx.data(), dgamma.data(), dbeta.data(),
                       workspace_bwd.data(),
                       prop.multiProcessorCount, zero_centered_gamma, 0);
  } else {
    nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon,
                     z.data(), rsigma.data(), workspace_fwd.data(),
                     prop.multiProcessorCount, zero_centered_gamma, 0);
124
    workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype());
125
126
127
128
129
130
131
132
    nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon,
                     z.data(), rsigma.data(), workspace_fwd.data(),
                     prop.multiProcessorCount, zero_centered_gamma, 0);

    nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
                     dx.data(), dgamma.data(),
                     workspace_bwd.data(),
                     prop.multiProcessorCount, zero_centered_gamma, 0);
133
    workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype());
134
135
136
137
138
139
140
141
142
    nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
                     dx.data(), dgamma.data(),
                     workspace_bwd.data(),
                     prop.multiProcessorCount, zero_centered_gamma, 0);
  }

  if (use_cudnn){
    nvte_enable_cudnn_norm_fwd(false);
    nvte_enable_cudnn_norm_bwd(false);
143
144
145
146
147

    // Zero-centered gamma in weight dtype only supported by CuDNN backend currently
    if (zero_centered_gamma_in_weight_dtype) {
      nvte_enable_zero_centered_gamma_in_weight_dtype(false);
    }
148
149
150
151
152
153
154
  }

  // Reference implementations
  // use the GPU stats to tighten the tolerances
  mu.to_cpu();
  rsigma.to_cpu();
  float ref_amax;
155
  compute_ref_stats(norm_type, input.rowwise_cpu_dptr<InputType>(), ref_mu.get(),
156
157
                    ref_rsigma.get(), N, H, epsilon);
  float ref_scale = isFp8Type(otype) ? z.scale() : 1.f;
158
159
160
  compute_ref_output(norm_type, input.rowwise_cpu_dptr<InputType>(),
                     gamma.rowwise_cpu_dptr<WeightType>(),
                     beta.rowwise_cpu_dptr<WeightType>(),
161
                     ref_output.get(),
162
163
                     mu.rowwise_cpu_dptr<float>(),
                     rsigma.rowwise_cpu_dptr<float>(),
164
165
166
167
                     N, H,
                     &ref_amax,
                     ref_scale,
                     zero_centered_gamma,
168
169
                     use_cudnn,
                     zero_centered_gamma_in_weight_dtype);
170
171
172
173
  compute_ref_backward(norm_type, dz.rowwise_cpu_dptr<WeightType>(),
                       input.rowwise_cpu_dptr<InputType>(),
                       mu.rowwise_cpu_dptr<float>(), rsigma.rowwise_cpu_dptr<float>(),
                       gamma.rowwise_cpu_dptr<WeightType>(),
174
175
                       ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(),
                       N, H, zero_centered_gamma,
176
177
                       use_cudnn,
                       zero_centered_gamma_in_weight_dtype);
178
179
180
181
182
183
184
185
186

  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
  if (isFp8Type(otype)) {
    compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
    float ref_scale_inv = 1.f / z.scale();
187
    compareResults("scale_inv", z.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
188
189
190
191
  }

  auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
  rtol_stats = 5e-5;
192
193
  compareResults("mu", mu, ref_mu.get(), true, atol_stats, rtol_stats);
  compareResults("rsigma", rsigma, ref_rsigma.get(), true, atol_stats, rtol_stats);
194
195
196
197
198

  auto [atol, rtol] = getTolerances(otype);
  if (otype == DType::kFloat32) {
    atol = 5e-7;
  }
199
  compareResults("output", z, ref_output.get(), true, atol, rtol);
200
201
202

  double atol_bwd = 5e-4;
  double rtol_bwd = 5e-4;
203
204
205
  compareResults("dx", dx, ref_dx.get(), true, atol_bwd, rtol_bwd);
  compareResults("dgamma", dgamma, ref_dgamma.get(), true, atol_bwd, rtol_bwd);
  compareResults("dbeta", dbeta, ref_dbeta.get(), true, atol_bwd, rtol_bwd);
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
}

// (rows N, hidden size H) shapes exercised for every parameter combination:
// two small odd-sized shapes to hit kernel tail paths, two large ones.
std::vector<std::pair<size_t, size_t>> test_cases = {
  {71, 229},
  {29, 541},
  {768, 6144},
  {2048, 12288},
};

}  // namespace

// Parameterized fixture. Tuple layout:
//   <0> use_cudnn, <1> norm type, <2> input dtype, <3> output dtype,
//   <4> (N, H) shape, <5> zero_centered_gamma,
//   <6> zero_centered_gamma_in_weight_dtype.
class NormTestSuite : public ::testing::TestWithParam<std::tuple<bool,
                                                                 NormType,
                                                                 transformer_engine::DType,
                                                                 transformer_engine::DType,
                                                                 std::pair<size_t, size_t>,
                                                                 bool,
                                                                 bool>> {};

// Unpacks the parameter tuple and dispatches performTest over the
// runtime-selected input/output dtypes.
TEST_P(NormTestSuite, TestNorm) {
  using namespace transformer_engine;
  using namespace test;

  const bool use_cudnn = std::get<0>(GetParam());
  const NormType norm_type = std::get<1>(GetParam());
  const DType input_type = std::get<2>(GetParam());
  const DType output_type = std::get<3>(GetParam());
  const auto size = std::get<4>(GetParam());
  const bool zero_centered_gamma = std::get<5>(GetParam());
  // Fixed typo in the local name (was "gamm"); purely internal rename.
  const bool cudnn_zero_centered_gamma_in_weight_dtype = std::get<6>(GetParam());

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
      performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma,
                                         norm_type, use_cudnn,
                                         cudnn_zero_centered_gamma_in_weight_dtype);
    );
  );
}

INSTANTIATE_TEST_SUITE_P(
  OperatorTest,
  NormTestSuite,
  ::testing::Combine(
    ::testing::Values(true, false),                                  // use_cudnn
    ::testing::Values(NormType::LayerNorm, NormType::RMSNorm),
    ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
    ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
    ::testing::ValuesIn(test_cases),
    ::testing::Values(false, true),                                  // zero_centered_gamma
    ::testing::Values(false, true)),                                 // zero_centered_gamma_in_weight_dtype
  [](const testing::TestParamInfo<NormTestSuite::ParamType>& info) {
    // Build a readable, unique test name from the parameter tuple,
    // e.g. "TeLayerNorm_fp32Xfp32X71X229X0X0".
    const std::string backend = std::get<0>(info.param) ? "Cudnn" : "Te";
    std::string name =
      backend +
      normToString.at(std::get<1>(info.param)) + "_" +
      test::typeName(std::get<2>(info.param)) + "X" +
      test::typeName(std::get<3>(info.param)) + "X" +
      std::to_string(std::get<4>(info.param).first) + "X" +
      std::to_string(std::get<4>(info.param).second) + "X" +
      std::to_string(std::get<5>(info.param)) + "X" +
      std::to_string(std::get<6>(info.param));
    return name;
  });