/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>

#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
#include "test_normalization.h"

using namespace transformer_engine;
using namespace test;

namespace {

/**
 * Run one forward + backward normalization of shape (N, H) on the GPU and
 * compare outputs, statistics, and gradients against the CPU references
 * from test_normalization.h.
 *
 * \param N                 Number of rows normalized independently.
 * \param H                 Row length (the normalization dimension).
 * \param zero_centered_gamma  If true, kernels treat the stored gamma as (1 + gamma).
 * \param norm_type         LayerNorm or RMSNorm.
 * \param use_cudnn         Route through the cuDNN backend instead of the TE kernels.
 * \param zero_centered_gamma_in_weight_dtype
 *                          cuDNN-only variant: apply the zero-centered-gamma
 *                          shift in the weight dtype.
 */
template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
                 NormType norm_type, bool use_cudnn, const bool zero_centered_gamma_in_weight_dtype) {
  if (sizeof(InputType) < sizeof(OutputType)) {
    GTEST_SKIP() << "LN kernel does not support OutputType > InputType";
    return;
  }

  if (getDeviceComputeCapability() < blackwellComputeCapability && use_cudnn) {
    GTEST_SKIP() << "cuDNN normalizations not supported on pre-Blackwell GPUs yet!";
  }

  // Skip duplicate parameter combinations: the weight-dtype variant of
  // zero-centered gamma only changes behavior on the cuDNN backend with
  // zero_centered_gamma enabled. Checked before any GPU allocation.
  if ((!use_cudnn || !zero_centered_gamma) && zero_centered_gamma_in_weight_dtype) {
    GTEST_SKIP() << "Zero-centered gamma in weight dtype is only supported with cuDNN backend";
  }

  using WeightType = InputType;
  DType itype = TypeInfo<InputType>::dtype;
  DType wtype = TypeInfo<WeightType>::dtype;
  DType otype = TypeInfo<OutputType>::dtype;

  if ((itype == DType::kBFloat16 && otype == DType::kFloat16) ||
      (itype == DType::kFloat16 && otype == DType::kBFloat16)) {
    GTEST_SKIP() << "LN kernel does not support mixing Float16 and BFloat16";
    return;
  }

  // Device tensors: forward inputs/outputs, per-row stats, and gradients.
  Tensor input("input", std::vector<size_t>{ N, H }, itype);
  Tensor z("z", std::vector<size_t>{ N, H }, otype);
  Tensor gamma("gamma", std::vector<size_t>{ H }, wtype);
  Tensor beta("beta", std::vector<size_t>{ H }, wtype);
  Tensor mu("mu", std::vector<size_t>{ N }, DType::kFloat32);
  Tensor rsigma("rsigma", std::vector<size_t>{ N }, DType::kFloat32);
  Tensor dz("dz", std::vector<size_t>{ N, H }, wtype);
  Tensor dx("dx", std::vector<size_t>{ N, H }, itype);
  Tensor dgamma("dgamma", std::vector<size_t>{ H }, wtype);
  Tensor dbeta("dbeta", std::vector<size_t>{ H }, wtype);
  Tensor workspace_fwd, workspace_bwd;

  fillUniform(&input);
  fillUniform(&gamma);
  fillUniform(&beta);
  setRandomScale(&z);
  fillUniform(&dz);

  // Host buffers that receive the CPU reference results.
  // NOTE: ref_dgamma/ref_dbeta were previously allocated as InputType[];
  // WeightType is used here to match their declared element type (the two
  // are currently aliases, so behavior is unchanged).
  std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N * H);
  std::unique_ptr<float[]> ref_mu = std::make_unique<float[]>(N);
  std::unique_ptr<float[]> ref_rsigma = std::make_unique<float[]>(N);
  std::unique_ptr<InputType[]> ref_dx = std::make_unique<InputType[]>(N * H);
  std::unique_ptr<WeightType[]> ref_dgamma = std::make_unique<WeightType[]>(H);
  std::unique_ptr<WeightType[]> ref_dbeta = std::make_unique<WeightType[]>(H);

  cudaDeviceProp prop;
  ASSERT_EQ(cudaGetDeviceProperties(&prop, 0), cudaSuccess);

  if (use_cudnn) {
    nvte_enable_cudnn_norm_fwd(true);
    nvte_enable_cudnn_norm_bwd(true);
    // Zero-centered gamma in weight dtype only supported by cuDNN backend currently.
    nvte_enable_zero_centered_gamma_in_weight_dtype(zero_centered_gamma_in_weight_dtype);
  }

  // Each kernel is invoked twice: the first call with an empty workspace
  // queries the required workspace shape/dtype (written into the workspace
  // tensor), which is then allocated before the real run.
  const float epsilon = 1e-5f;
  if (norm_type == LayerNorm) {
    nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
                       z.data(), mu.data(), rsigma.data(), workspace_fwd.data(),
                       prop.multiProcessorCount, zero_centered_gamma, 0);
    workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype());
    nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
                       z.data(), mu.data(), rsigma.data(), workspace_fwd.data(),
                       prop.multiProcessorCount, zero_centered_gamma, 0);

    nvte_layernorm_bwd(dz.data(), input.data(),
                       mu.data(), rsigma.data(), gamma.data(),
                       dx.data(), dgamma.data(), dbeta.data(),
                       workspace_bwd.data(),
                       prop.multiProcessorCount, zero_centered_gamma, 0);
    workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype());
    nvte_layernorm_bwd(dz.data(), input.data(),
                       mu.data(), rsigma.data(), gamma.data(),
                       dx.data(), dgamma.data(), dbeta.data(),
                       workspace_bwd.data(),
                       prop.multiProcessorCount, zero_centered_gamma, 0);
  } else {
    // RMSNorm has no shift (beta) and no mean (mu).
    nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon,
                     z.data(), rsigma.data(), workspace_fwd.data(),
                     prop.multiProcessorCount, zero_centered_gamma, 0);
    workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype());
    nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon,
                     z.data(), rsigma.data(), workspace_fwd.data(),
                     prop.multiProcessorCount, zero_centered_gamma, 0);

    nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
                     dx.data(), dgamma.data(),
                     workspace_bwd.data(),
                     prop.multiProcessorCount, zero_centered_gamma, 0);
    workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype());
    nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
                     dx.data(), dgamma.data(),
                     workspace_bwd.data(),
                     prop.multiProcessorCount, zero_centered_gamma, 0);
  }

  // Restore global backend state so later tests are unaffected.
  if (use_cudnn) {
    nvte_enable_cudnn_norm_fwd(false);
    nvte_enable_cudnn_norm_bwd(false);
    // Zero-centered gamma in weight dtype only supported by cuDNN backend currently.
    if (zero_centered_gamma_in_weight_dtype) {
      nvte_enable_zero_centered_gamma_in_weight_dtype(false);
    }
  }

  // Reference implementations.
  // Use the GPU stats (mu/rsigma) in the reference output/backward so the
  // comparison tolerances can stay tight.
  mu.to_cpu();
  rsigma.to_cpu();
  float ref_amax;
  compute_ref_stats(norm_type, input.rowwise_cpu_dptr<InputType>(), ref_mu.get(),
                    ref_rsigma.get(), N, H, epsilon);
  float ref_scale = isFp8Type(otype) ? z.scale() : 1.f;
  compute_ref_output(norm_type, input.rowwise_cpu_dptr<InputType>(),
                     gamma.rowwise_cpu_dptr<WeightType>(),
                     beta.rowwise_cpu_dptr<WeightType>(),
                     ref_output.get(),
                     mu.rowwise_cpu_dptr<float>(),
                     rsigma.rowwise_cpu_dptr<float>(),
                     N, H,
                     &ref_amax,
                     ref_scale,
                     zero_centered_gamma,
                     use_cudnn,
                     zero_centered_gamma_in_weight_dtype);
  compute_ref_backward(norm_type, dz.rowwise_cpu_dptr<WeightType>(),
                       input.rowwise_cpu_dptr<InputType>(),
                       mu.rowwise_cpu_dptr<float>(), rsigma.rowwise_cpu_dptr<float>(),
                       gamma.rowwise_cpu_dptr<WeightType>(),
                       ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(),
                       N, H, zero_centered_gamma,
                       use_cudnn,
                       zero_centered_gamma_in_weight_dtype);

  // Surface any asynchronous kernel failure before comparing results.
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
  if (isFp8Type(otype)) {
    compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
    float ref_scale_inv = 1.f / z.scale();
    compareResults("scale_inv", z.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
  }

  auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32);
  rtol_stats = 5e-5;
  compareResults("mu", mu, ref_mu.get(), true, atol_stats, rtol_stats);
  compareResults("rsigma", rsigma, ref_rsigma.get(), true, atol_stats, rtol_stats);

  auto [atol, rtol] = getTolerances(otype);
  if (otype == DType::kFloat32) {
    atol = 5e-7;
  }
  compareResults("output", z, ref_output.get(), true, atol, rtol);

  double atol_bwd = 5e-4;
  double rtol_bwd = 5e-4;
  compareResults("dx", dx, ref_dx.get(), true, atol_bwd, rtol_bwd);
  compareResults("dgamma", dgamma, ref_dgamma.get(), true, atol_bwd, rtol_bwd);
  compareResults("dbeta", dbeta, ref_dbeta.get(), true, atol_bwd, rtol_bwd);
}

// Problem sizes exercised by every parameter combination; each pair is
// passed to performTest as (N, H) = (number of rows, row length).
std::vector<std::pair<size_t, size_t>> test_cases = {
  {71, 229},
  {29, 541},
  {768, 6144},
  {2048, 12288},
};

}  // namespace

// Parameterized suite over:
//   <0> use_cudnn, <1> norm type, <2> input dtype, <3> output dtype,
//   <4> (N, H) shape, <5> zero_centered_gamma,
//   <6> zero_centered_gamma_in_weight_dtype
class NormTestSuite : public ::testing::TestWithParam<
    std::tuple<bool,
               NormType,
               transformer_engine::DType,
               transformer_engine::DType,
               std::pair<size_t, size_t>,
               bool,
               bool>> {};

// Unpack the parameter tuple and dispatch performTest over the concrete
// input/output C++ types via the runtime-dtype switch macros.
TEST_P(NormTestSuite, TestNorm) {
  using namespace transformer_engine;
  using namespace test;

  const bool use_cudnn = std::get<0>(GetParam());
  const NormType norm_type = std::get<1>(GetParam());
  const DType input_type = std::get<2>(GetParam());
  const DType output_type = std::get<3>(GetParam());
  const auto size = std::get<4>(GetParam());
  const bool zero_centered_gamma = std::get<5>(GetParam());
  // Typo fixed: was cudnn_zero_centered_gamm_in_weight_dtype.
  const bool zero_centered_gamma_in_weight_dtype = std::get<6>(GetParam());

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
      performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma,
                                         norm_type, use_cudnn,
                                         zero_centered_gamma_in_weight_dtype);
    );
  );
}

// Instantiate the full cross product of backends, norm types, dtypes, shapes,
// and gamma options. Test names are built as
//   <Backend><Norm>_<InType>X<OutType>X<N>X<H>X<zcg>X<zcg_in_weight_dtype>
INSTANTIATE_TEST_SUITE_P(
  OperatorTest,
  NormTestSuite,
  ::testing::Combine(
    ::testing::Values(true, false),                                    // use_cudnn
    ::testing::Values(NormType::LayerNorm, NormType::RMSNorm),
    ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
    ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
    ::testing::ValuesIn(test_cases),
    ::testing::Values(false, true),                                    // zero_centered_gamma
    ::testing::Values(false, true)),                                   // zero_centered_gamma_in_weight_dtype
  [](const testing::TestParamInfo<NormTestSuite::ParamType>& info) {
    // std::string (not const char*) so the + chain below never depends on
    // the return type of normToString.at().
    const std::string backend = std::get<0>(info.param) ? "Cudnn" : "Te";
    std::string name =
      backend +
      normToString.at(std::get<1>(info.param)) + "_" +
      test::typeName(std::get<2>(info.param)) + "X" +
      test::typeName(std::get<3>(info.param)) + "X" +
      std::to_string(std::get<4>(info.param).first) + "X" +
      std::to_string(std::get<4>(info.param).second) + "X" +
      std::to_string(std::get<5>(info.param)) + "X" +
      std::to_string(std::get<6>(info.param));
    return name;
  });