test_normalization.h 6.57 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

 #pragma once

#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <type_traits>
#include <vector>

#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"

namespace test {
namespace {

// Normalization flavours exercised by the reference implementations below.
// Kept as an unscoped enum so the enumerators can be referred to unqualified.
enum NormType {
  LayerNorm = 0,  // subtract the row mean, then scale by 1/sigma, gamma and add beta
  RMSNorm = 1     // no mean subtraction and no beta: scale by 1/rms and gamma only
};

// Display name for each NormType value.
std::map<NormType, std::string> normToString = {
  {NormType::LayerNorm, "LayerNorm"},
  {NormType::RMSNorm,   "RmsNorm"},
};

/// Reference (host-side) per-row statistics for LayerNorm / RMSNorm.
///
/// `data` is treated as a row-major N x H matrix. For every row:
///  - LayerNorm: mu[row] = row mean; rsigma[row] = 1 / sqrt(mean((x - mu)^2) + epsilon).
///  - RMSNorm:   mu is not written; rsigma[row] = 1 / sqrt(mean(x^2) + epsilon).
/// Sums are accumulated in float.
template <typename InputType>
void compute_ref_stats(NormType norm_type,
                       const InputType *data, float *mu, float *rsigma,
                       const size_t N, const size_t H, const double epsilon){
  using compute_t = float;
  for (size_t row = 0; row < N; ++row) {
    const InputType *row_data = data + row * H;

    compute_t row_sum = 0;
    for (size_t col = 0; col < H; ++col) {
      row_sum += static_cast<compute_t>(row_data[col]);
    }

    // RMSNorm measures spread around zero instead of around the mean.
    compute_t center = 0;
    if (norm_type == LayerNorm) {
      mu[row] = row_sum / H;
      center = mu[row];
    }

    compute_t sq_dev_sum = 0;
    for (size_t col = 0; col < H; ++col) {
      const compute_t diff = static_cast<compute_t>(row_data[col]) - center;
      sq_dev_sum += diff * diff;
    }
    // epsilon is a double, so this addition is done in double before rsqrtf narrows it.
    rsigma[row] = rsqrtf((sq_dev_sum / H) + epsilon);
  }
}

template <typename InputType>
inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) {

  using compute_t = float;

  // Zero-centered gamma in weight dtype is only supported in CuDNN backend currently
  // Remove the use_cudnn check here when it is supported by both backends.
  const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;

70
71
  if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3> ||
                std::is_same_v<InputType, fp4e2m1>){
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
    compute_t g = static_cast<compute_t>(gamma);
    if (zero_centered_gamma) {
      g += static_cast<compute_t>(1.f);
    }
    return g;
  } else {
    if (zero_centered_gamma_in_weight_dtype){
      compute_t g = static_cast<compute_t>(0.f);
      InputType gi = gamma;
      if (zero_centered_gamma) {
        gi = gi + static_cast<InputType>(1.f);
      }
      g = static_cast<compute_t>(gi);
      return g;
    } else {
      compute_t g = static_cast<compute_t>(gamma);
      if (zero_centered_gamma) {
        g += static_cast<compute_t>(1.f);
      }
      return g;
    }
  }
}

/// Reference (host-side) forward pass for LayerNorm / RMSNorm.
///
/// For the row-major N x H input `data`, with g = compute_gamma(gamma[j], ...):
///   LayerNorm: tmp = (x - mu[i]) * rsigma[i] * g + beta[j]
///   RMSNorm:   tmp = x * rsigma[i] * g              (mu and beta are not read)
/// and output[i * H + j] = static_cast<OutputType>(tmp * scale).
///
/// If `amax` is non-null, it receives the maximum absolute output value:
///  - scale != 1 (treated as the fp8 output path): measured on the pre-scale float `tmp`;
///  - scale == 1: measured on `tmp` after rounding to OutputType.
/// For an empty input (N == 0 or H == 0), amax is 0.
template <typename InputType, typename OutputType>
void compute_ref_output(NormType norm_type,
                        const InputType *data, const InputType *gamma, const InputType *beta,
                        OutputType* output,
                        const float *mu, const float *rsigma,
                        const size_t N, const size_t H,
                        float *amax, float scale, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) {
  using compute_t = float;
  // Every candidate below is an absolute value (>= 0), so 0 is the identity for the max.
  // It also avoids initializing a float from an out-of-range double literal.
  compute_t current_max = 0.f;
  for (size_t i = 0; i < N; ++i) {
    for (size_t j = 0; j < H; ++j) {
      const compute_t current = static_cast<compute_t>(data[i * H + j]);
      const compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn,
                                        cudnn_zero_centered_gamma_in_weight_dtype);

      compute_t tmp;
      if (norm_type == LayerNorm) {
        tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
      } else {  // RMSNorm
        tmp = current * rsigma[i] * g;
      }

      // Write output (scaled only for fp8 paths)
      output[i * H + j] = static_cast<OutputType>(tmp * scale);

      // amax semantics:
      // - fp8_out (scale != 1): amax on pre-scale compute value 'tmp'
      // - non-fp8_out (scale == 1): amax on value converted to OutputType (e.g., bf16)
      if (scale != 1.f) {
        current_max = fmaxf(current_max, fabsf(tmp));
      } else {
        const OutputType out_t_val = static_cast<OutputType>(tmp);
        current_max = fmaxf(current_max, fabsf(static_cast<compute_t>(out_t_val)));
      }
    }
  }

  if (amax) {
    *amax = current_max;
  }
}


template <typename InputType, typename OutputType>
139
140
void compute_ref_backward(const NormType norm_type, const OutputType *output_grad,
                          const OutputType *add, const InputType *data,
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
                          const float *mu, const float *rsigma,
                          const InputType *gamma,
                          InputType *data_grad,
                          InputType *gamma_grad, InputType *beta_grad,
                          const size_t N, const size_t H,
                          const bool zero_centered_gamma, const bool use_cudnn,
                          const bool cudnn_zero_centered_gamma_in_weight_dtype) {
  using compute_t = float;
  std::vector<compute_t> dgamma(H, 0.f);
  std::vector<compute_t> dbeta(H, 0.f);

  for (size_t i = 0 ; i < N; ++i) {
    // Reductions
    auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.;
    compute_t mdy = 0, mdyy = 0;
    for (size_t j = 0; j < H; ++j) {
      const compute_t x = static_cast<compute_t>(data[i * H + j]);
      const compute_t y = (x - local_mu) * rsigma[i];
      compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype);
      const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
      const compute_t dy = g * dz;
      dgamma[j] += y * dz;
      if (norm_type == LayerNorm) {
        dbeta[j] += dz;
        mdy += dy;
      }
      mdyy += dy * y;
    }
    mdy /= H;
    mdyy /= H;

    // Input grads
    for (size_t j = 0; j < H; ++j) {
      const compute_t x = static_cast<compute_t>(data[i * H + j]);
      const compute_t y = (x - local_mu) * rsigma[i];
      compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype);
      const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
      const compute_t dy = g * dz;
179
180
      const compute_t a = static_cast<compute_t>(add[i * H + j]);
      const compute_t dx = a + rsigma[i] * (dy - mdyy * y - mdy);
181
182
183
184
185
186
187
188
189
190
191
      data_grad[i * H + j] = static_cast<InputType>(dx);
    }
  }

  // Weight grads
  for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast<InputType>(dgamma[j]);
  if (norm_type == LayerNorm) for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast<InputType>(dbeta[j]);
}

} // namespace
} // namespace test