common.cpp 23.3 KB
Newer Older
1
/*************************************************************************
2
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
 *
 * See LICENSE for license information.
 ************************************************************************/

/* #include <transformer_engine/layer_norm.h> */

#include "common.h"

#include <bitset>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <numeric>

#include "transformer_engine/normalization.h"
18
#include "transformer_engine/transformer_engine.h"
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41

/*

Supported Type combinations:

input    compute   weights   output
=======================================
fp32     fp32      fp32      fp32
fp16     fp32      fp16      fp16
bf16     fp32      bf16      bf16
fp32     fp32      fp16      fp16
fp32     fp32      bf16      bf16
bf16     fp32      bf16      fp8

Remarks:
Output type = Weight type, except for FP8 output (see the last row)
Compute always in FP32

*/

namespace transformer_engine {
namespace normalization {

42
43
44
45
46
47
48
49
50
51
cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) {
  return training ? cudnn_frontend::NormFwdPhase_t::TRAINING
                  : cudnn_frontend::NormFwdPhase_t::INFERENCE;
}

/*! \brief Build the cache/registry key identifying a normalization plan.
 *
 *  All categorical attributes are packed into one 64-bit word; batch size, hidden size and
 *  the tuned flag are kept as separate tuple members.
 *
 *  Bit layout of the packed word (half-open ranges):
 *    [ 0, 5) itype      [ 5,10) otype      [10,15) ctype      [15,20) wtype
 *    [20,22) NormType   [22,24) NormStage  [24,26) NormBackend
 *    [26]    zero_centered_gamma           [27,37) scaling mode   [37] training
 *
 *  Dtype fields are 5 bits wide. A 3-bit field would let any DType enumerator >= 8 spill into
 *  the neighbouring field and alias a different configuration.
 */
TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType,
                     NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype,
                     uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma,
                     bool is_tuned, NVTEScalingMode mode, bool training) {
  // Widen before shifting so fields above bit 31 are not truncated.
  const auto field = [](auto value, unsigned shift) -> uint64_t {
    return static_cast<uint64_t>(value) << shift;
  };
  const uint64_t general_key = field(itype, 0) | field(otype, 5) | field(ctype, 10) |
                               field(wtype, 15) | field(NormType, 20) | field(NormStage, 22) |
                               field(NormBackend, 24) | field(zero_centered_gamma, 26) |
                               field(mode, 27) | field(training, 37);
  return std::make_tuple(general_key, batch_size, hidden_size, is_tuned);
}

/*! \brief Prepare a TE-kernel normalization plan.
 *
 *  Fills the static kernel parameters (shape, zero_centered_gamma, FP8 output flag for forward
 *  plans), looks the kernel up in KernelRegistry and runs _build().
 */
template <typename KernelParamsType>
TeNormalizationPlan<KernelParamsType>::TeNormalizationPlan(
    NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype,
    DType ctype, const size_t batch_size, const size_t hidden_size, const size_t sm_count,
    const bool zero_centered_gamma, const bool is_tuned)
    : _is_layernorm(NormType == NVTE_Norm_Type::LayerNorm) {
  _launch_params.multiprocessorCount = sm_count;

  auto& params = _launch_params.params;
  params.rows = batch_size;
  params.cols = hidden_size;
  params.zero_centered_gamma = zero_centered_gamma;
  if constexpr (std::is_same_v<KernelParamsType, ForwardKernelParams>) {
    // Only forward kernels carry an FP8-output flag.
    params.fp8_out = is_fp8_dtype(otype);
  }

  // TE kernels are not templated on batch size or zero_centered_gamma, so the registry key
  // uses 0 / false for those fields.
  constexpr uint64_t kAnyBatchSize = 0;
  constexpr bool kAnyZeroCenteredGamma = false;
  auto key = get_key(NVTE_Norm_Backend::Te, NormType, NormStage, wtype, itype, otype, ctype,
                     kAnyBatchSize, hidden_size, kAnyZeroCenteredGamma, is_tuned);
  _kernel = KernelRegistry::getKernel(key);

  this->_build();
}

/*! \brief Forward entry point for a TE forward plan.
 *
 *  Binds input, statistic and output pointers into the kernel parameters and launches the
 *  kernel on `stream`. `eps_dptr` is dereferenced on the host, so it must point to host
 *  memory holding a float. mean/beta are only bound for LayerNorm plans.
 */
template <>
void TeNormalizationPlan<ForwardKernelParams>::execute(Tensor* z, void* x_dptr, void* gamma_dptr,
                                                       void* beta_dptr, void* mean_dptr,
                                                       void* eps_dptr, void* rsigma_dptr,
                                                       void* workspace_dptr, cudaStream_t stream) {
  auto& params = _launch_params.params;

  // Inputs and per-row statistics.
  params.x = x_dptr;
  params.gamma = gamma_dptr;
  params.rs = rsigma_dptr;
  if (_is_layernorm) {
    params.mu = mean_dptr;
    params.beta = beta_dptr;
  }
  params.epsilon = *reinterpret_cast<float*>(eps_dptr);

  // Output buffer together with its amax/scale/scale_inv metadata.
  params.z = z->data.dptr;
  params.amax = z->amax.dptr;
  params.scale = z->scale.dptr;
  params.scale_inv = z->scale_inv.dptr;

  params.workspace = workspace_dptr;
  _launch_params.stream = stream;

  _set_workspace();
  _kernel(_launch_params, false);
}

/*! \brief Guard: a backward plan must never be driven through the forward entry point. */
template <>
void TeNormalizationPlan<BackwardKernelParams>::execute(Tensor* z, void* x_dptr, void* gamma_dptr,
                                                        void* beta_dptr, void* mean_dptr,
                                                        void* eps_dptr, void* rsigma_dptr,
                                                        void* workspace_dptr, cudaStream_t stream) {
  NVTE_ERROR("Backward normalization should not call the forward execute function!");
}

/*! \brief Finalize launch parameters once at plan-creation time.
 *
 *  The launcher is invoked with `true` here, whereas execute() passes `false`. This is
 *  presumably a configuration-only pass that fills in the workspace/barrier sizes, which are
 *  then aligned.
 */
template <typename KernelParamsType>
void TeNormalizationPlan<KernelParamsType>::_build() {
  constexpr bool kConfigurePass = true;
  _kernel(_launch_params, kConfigurePass);
  _launch_params.alignWorkspace();
}

/*! \brief Workspace requirement as a single flat size. The LayerNorm flag is forwarded because
 *  the total depends on the norm type.
 */
template <typename KernelParamsType>
std::vector<size_t> TeNormalizationPlan<KernelParamsType>::getWorkspaceShape() const {
  const size_t total_bytes = _launch_params.getTotalWorkspaceBytes(_is_layernorm);
  return std::vector<size_t>{total_bytes};
}

/*! \brief Carve the caller-provided workspace into the regions sized when the plan was built.
 *
 *  Layout (byte offsets from params.workspace):
 *    [ workspace_bytes | barrier_bytes | dgamma_part_bytes | dbeta part (LayerNorm only) ]
 *  The barrier region, if any, is zeroed asynchronously on the launch stream. The dgamma/dbeta
 *  partial regions are only assigned for backward plans.
 */
template <typename KernelParamsType>
void TeNormalizationPlan<KernelParamsType>::_set_workspace() {
  if (_launch_params.getTotalWorkspaceBytes() > 0) {
    auto workspace_dptr = reinterpret_cast<byte*>(_launch_params.params.workspace);

    if (_launch_params.barrier_bytes > 0) {
      _launch_params.params.barrier =
          reinterpret_cast<int*>(workspace_dptr + _launch_params.workspace_bytes);
      // Check the memset's status so a failure is reported instead of silently launching the
      // kernel with an unzeroed barrier.
      NVTE_CHECK_CUDA(cudaMemsetAsync(_launch_params.params.barrier, 0,
                                      _launch_params.barrier_bytes, _launch_params.stream));
    }
    if constexpr (std::is_same_v<KernelParamsType, BackwardKernelParams>) {
      _launch_params.params.dgamma_part =
          workspace_dptr + _launch_params.workspace_bytes + _launch_params.barrier_bytes;
      if (_is_layernorm) {
        _launch_params.params.dbeta_part =
            reinterpret_cast<byte*>(_launch_params.params.dgamma_part) +
            _launch_params.dgamma_part_bytes;
      }
    }
  }
}

/*! \brief Guard: a forward plan must never be driven through the backward entry point. */
template <>
void TeNormalizationPlan<ForwardKernelParams>::execute(void* x_dptr, void* gamma_dptr,
                                                       void* mean_dptr, void* rsigma_dptr,
                                                       void* dx_dptr, void* dz_dptr,
                                                       void* dbeta_dptr, void* dgamma_dptr,
                                                       void* workspace_dptr, cudaStream_t stream) {
  NVTE_ERROR("Forward normalization should not call the backward execute function!");
}

/*! \brief Backward entry point for a TE backward plan.
 *
 *  Binds the saved forward tensors, the incoming gradient and the gradient outputs, then
 *  launches the kernel on `stream`. mean/dbeta are only bound for LayerNorm plans.
 */
template <>
void TeNormalizationPlan<BackwardKernelParams>::execute(void* x_dptr, void* gamma_dptr,
                                                        void* mean_dptr, void* rsigma_dptr,
                                                        void* dx_dptr, void* dz_dptr,
                                                        void* dbeta_dptr, void* dgamma_dptr,
                                                        void* workspace_dptr, cudaStream_t stream) {
  auto& params = _launch_params.params;

  // Forward-pass tensors needed to form the gradients.
  params.x = x_dptr;
  params.gamma = gamma_dptr;
  params.rs = rsigma_dptr;

  // Incoming gradient and gradient outputs.
  params.dz = dz_dptr;
  params.dx = dx_dptr;
  params.dgamma = dgamma_dptr;

  if (_is_layernorm) {
    params.mu = mean_dptr;
    params.dbeta = dbeta_dptr;
  }

  params.workspace = workspace_dptr;
  _launch_params.stream = stream;

  _set_workspace();
  _kernel(_launch_params, false);
}

/*! \brief Build a cuDNN frontend graph for LayerNorm/RMSNorm forward or backward.
 *
 *  Forward graphs optionally quantize the output:
 *   - per-tensor scaling: z * scale, with amax and 1/scale also produced;
 *   - MXFP8: 32-element block quantization row-wise, plus column-wise when training.
 *  Backward graphs produce dx and dgamma, plus dbeta for LayerNorm.
 *  With zero_centered_gamma the effective gamma is gamma_zero + 1.
 */
CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage,
                                               DType wtype, DType itype, DType otype, DType ctype,
                                               const size_t batch_size, const size_t hidden_size,
                                               const size_t sm_count,
                                               const bool zero_centered_gamma,
                                               const NVTEScalingMode mode, bool training)
    : _fp8_out(is_fp8_dtype(otype)),
      _zero_centered(zero_centered_gamma),
      _training(training),
      _norm_stage(NormStage),
      _norm_type(NormType) {
  static_assert(CUDNN_FRONTEND_VERSION >= 10601,
                "CUDNN_FRONTEND_VERSION should be at least 1.6.1!");

  namespace fe = cudnn_frontend;

  // 0 => one scale for the whole tensor, 1 => MXFP8 1-D block scaling.
  const bool tensor_scaling = is_tensor_scaling(mode);
  if (!tensor_scaling) {
    NVTE_CHECK(mode == NVTE_MXFP8_1D_SCALING, "Unsupported scaling mode.");
  }
  _ndim_scale_block = tensor_scaling ? 0 : 1;

  // Host scalar 1 in the weight dtype, bound to the "one" tensor when gamma is zero-centered.
  _scalar_dptr = std::make_unique<char[]>(typeToSize(wtype));
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      wtype, cpp_dtype, *(reinterpret_cast<cpp_dtype*>(_scalar_dptr.get())) = (cpp_dtype)1.0f;);

  _handle = cudnnExecutionPlanManager::Instance().GetHandle();

  // Frontend dtypes, converted once in the same first-use order as the graph below.
  const auto io_dtype = get_cudnn_fe_dtype(itype);
  const auto compute_dtype = get_cudnn_fe_dtype(ctype);
  const auto weight_dtype = get_cudnn_fe_dtype(wtype);
  const auto out_dtype = get_cudnn_fe_dtype(otype);

  _graph.set_io_data_type(io_dtype)
      .set_intermediate_data_type(compute_dtype)
      .set_compute_data_type(compute_dtype);
  if (cudnnGetVersion() >= 90400) {
    _graph.set_sm_count(sm_count);
  }

  const auto batch_dim = static_cast<int32_t>(batch_size);
  const auto hidden_dim = static_cast<int32_t>(hidden_size);

  // Attribute factories for the tensor shapes that appear in this graph.
  const auto activation_attrs = [&](const char* name) {  // [batch, hidden], row-major
    return fe::graph::Tensor_attributes()
        .set_name(name)
        .set_dim({batch_dim, hidden_dim, 1, 1})
        .set_stride({hidden_dim, 1, hidden_dim, hidden_dim});
  };
  const auto per_channel_attrs = [&](const char* name) {  // [1, hidden], weight dtype
    return fe::graph::Tensor_attributes()
        .set_name(name)
        .set_dim({1, hidden_dim, 1, 1})
        .set_stride({hidden_dim, 1, hidden_dim, hidden_dim})
        .set_data_type(weight_dtype);
  };
  const auto per_row_attrs = [&](const char* name) {  // [batch, 1], compute dtype
    return fe::graph::Tensor_attributes()
        .set_name(name)
        .set_dim({batch_dim, 1, 1, 1})
        .set_stride({1, 1, 1, 1})
        .set_data_type(compute_dtype);
  };
  const auto scalar_attrs = [](const char* name, fe::DataType_t dtype) {  // single element
    return fe::graph::Tensor_attributes()
        .set_name(name)
        .set_dim({1, 1, 1, 1})
        .set_stride({1, 1, 1, 1})
        .set_data_type(dtype);
  };
  const auto pointwise_attrs = [&](fe::PointwiseMode_t op) {
    return fe::graph::Pointwise_attributes().set_mode(op).set_compute_data_type(compute_dtype);
  };

  // Inputs shared by both stages.
  _x = _graph.tensor(activation_attrs("X").set_data_type(io_dtype));
  _gamma_zero = _graph.tensor(per_channel_attrs("gamma_zero"));
  if (_zero_centered) {
    // gamma = gamma_zero + one, where "one" is passed by value at execute time.
    _scalar_offset = _graph.tensor(scalar_attrs("one", weight_dtype).set_is_pass_by_value(true));
    auto add_options = pointwise_attrs(fe::PointwiseMode_t::ADD);
    _gamma = _graph.pointwise(_gamma_zero, _scalar_offset, add_options);
    _gamma->set_output(false).set_data_type(weight_dtype);
  } else {
    _gamma = _gamma_zero;
  }

  if (_norm_stage == NVTE_Norm_Stage::Forward) {
    _eps = _graph.tensor(scalar_attrs("epsilon", compute_dtype).set_is_pass_by_value(true));

    if (_norm_type == NVTE_Norm_Type::LayerNorm) {
      _beta = _graph.tensor(per_channel_attrs("bias"));
      auto ln_options = fe::graph::Layernorm_attributes()
                            .set_forward_phase(get_cudnn_forward_phase(_training))
                            .set_epsilon(_eps)
                            .set_compute_data_type(compute_dtype);
      const auto outputs = _graph.layernorm(_x, _gamma, _beta, ln_options);
      _z = outputs[0];
      _mean = outputs[1];
      _rsigma = outputs[2];
      if (_training) _mean->set_output(true).set_data_type(compute_dtype);
    } else {
      auto rms_options = fe::graph::Rmsnorm_attributes()
                             .set_forward_phase(get_cudnn_forward_phase(_training))
                             .set_epsilon(_eps)
                             .set_compute_data_type(compute_dtype);
      const auto outputs = _graph.rmsnorm(_x, _gamma, rms_options);
      _z = outputs[0];
      _rsigma = outputs[1];
    }

    if (_training) _rsigma->set_output(true).set_data_type(compute_dtype);

    // With FP8 output, z stays an internal compute-precision tensor feeding the quantization
    // nodes below; otherwise z itself is the graph output in otype.
    _z->set_output(!_fp8_out).set_data_type(_fp8_out ? compute_dtype : out_dtype);

    if (_fp8_out) {
      switch (_ndim_scale_block) {
        case 0: {  // per-tensor scaling
          // z_fp8 = z * z_scale
          _z_scale = _graph.tensor(scalar_attrs("z_scale", compute_dtype));
          auto mul_options = pointwise_attrs(fe::PointwiseMode_t::MUL);
          _z_fp8 = _graph.pointwise(_z, _z_scale, mul_options);
          _z_fp8->set_output(true).set_data_type(out_dtype);

          // amax = AMAX reduction over z
          auto amax_options = fe::graph::Reduction_attributes()
                                  .set_mode(fe::ReductionMode_t::AMAX)
                                  .set_compute_data_type(compute_dtype);
          _amax = _graph.reduction(_z, amax_options);
          _amax->set_output(true).set_data_type(compute_dtype).set_dim({1, 1, 1, 1});

          // z_scale_inv = one_for_div / z_scale, with one_for_div passed by value
          _one_for_div = _graph.tensor(
              scalar_attrs("one_for_div", compute_dtype).set_is_pass_by_value(true));
          auto div_options = pointwise_attrs(fe::PointwiseMode_t::DIV);
          _z_scale_inv = _graph.pointwise(_one_for_div, _z_scale, div_options);
          _z_scale_inv->set_output(true).set_data_type(compute_dtype);
          break;
        }
        case 1: {  // MXFP8: 32-element blocks on axis 1, and on axis 0 when training
          auto z_2d = _graph.reshape(_z, fe::graph::Reshape_attributes());
          z_2d->set_dim({batch_dim, hidden_dim});

          const auto block_options = [](int axis) {
            return fe::graph::Block_scale_quantize_attributes()
                .set_block_size(32)
                .set_axis(axis)
                .set_transpose(false);
          };

          auto row_options = block_options(1);
          const auto row = _graph.block_scale_quantize(z_2d, row_options);
          _z_mx_row = row[0];
          _sf_row = row[1];
          _z_mx_row->set_output(true).set_data_type(out_dtype);
          _sf_row->set_output(true).set_data_type(fe::DataType_t::FP8_E8M0);  // TODO

          if (_training) {
            auto col_options = block_options(0);
            const auto col = _graph.block_scale_quantize(z_2d, col_options);
            _z_mx_col = col[0];
            _sf_col = col[1];
            _z_mx_col->set_output(true).set_data_type(out_dtype);
            _sf_col->set_output(true).set_data_type(fe::DataType_t::FP8_E8M0);
          }
          break;
        }
        default:
          NVTE_ERROR("Unsupported scaling mode.");
      }
    }
  } else {
    // Backward: incoming gradient plus statistics saved by the forward pass.
    _dz = _graph.tensor(activation_attrs("dz"));
    _rsigma = _graph.tensor(per_row_attrs("inv_var"));
    _mean = _graph.tensor(per_row_attrs("mean"));

    if (_norm_type == NVTE_Norm_Type::LayerNorm) {
      auto ln_bwd_options = fe::graph::Layernorm_backward_attributes()
                                .set_saved_mean_and_inv_variance(_mean, _rsigma)
                                .set_compute_data_type(compute_dtype);
      const auto grads = _graph.layernorm_backward(_dz, _x, _gamma, ln_bwd_options);
      _dx = grads[0];
      _dgamma = grads[1];
      _dbeta = grads[2];
      _dbeta->set_output(true).set_data_type(out_dtype);
    } else {
      auto rms_bwd_options =
          fe::graph::Rmsnorm_backward_attributes().has_dbias(false).set_compute_data_type(
              compute_dtype);
      const auto grads = _graph.rmsnorm_backward(_dz, _x, _gamma, _rsigma, rms_bwd_options);
      _dx = grads[0];
      _dgamma = grads[1];
      _dbeta = grads[2];
      if (_dbeta != nullptr) NVTE_ERROR("cuDNN rmsnorm dbias incorrectly returned.");
    }
    _dx->set_output(true).set_data_type(out_dtype);
    _dgamma->set_output(true).set_data_type(out_dtype);
  }

  this->_build();
}

void CudnnNormalizationPlan::_build() {
  NVTE_CHECK(_graph.validate().is_good());
  NVTE_CHECK(_graph.build_operation_graph(_handle).is_good());
  NVTE_CHECK(_graph
                 .create_execution_plans(
                     {cudnn_frontend::HeurMode_t::A, cudnn_frontend::HeurMode_t::FALLBACK})
                 .is_good());
  NVTE_CHECK(_graph.check_support(_handle).is_good());
  NVTE_CHECK(
      _graph.build_plans(_handle, cudnn_frontend::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good());
}

std::vector<size_t> CudnnNormalizationPlan::getWorkspaceShape() const {
  return {static_cast<size_t>(_graph.get_workspace_size())};
}

/*! \brief Run the forward cuDNN graph on `stream`.
 *
 *  Rebuilds the tensor-to-pointer bindings for this call:
 *   - rsigma and mean are bound only in training mode;
 *   - beta and mean exist only for LayerNorm;
 *   - output bindings depend on the output format (plain, per-tensor FP8, or MXFP8).
 */
void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr,
                                     void* mean_dptr, void* eps_dptr, void* rsigma_dptr,
                                     void* workspace_dptr, cudaStream_t stream) {
  _variant_pack = {{_x, x_dptr}, {_eps, eps_dptr}};
  const auto bind_ptr = [this](const auto& tensor, void* dptr) {
    _variant_pack.emplace(tensor, dptr);
  };

  const bool is_layernorm = _norm_type == NVTE_Norm_Type::LayerNorm;
  if (_training) bind_ptr(_rsigma, rsigma_dptr);
  if (is_layernorm) bind_ptr(_beta, beta_dptr);
  if (is_layernorm && _training) bind_ptr(_mean, mean_dptr);

  // Zero-centered gamma: stored weights feed gamma_zero and the host-side 1 feeds the offset.
  if (_zero_centered) {
    bind_ptr(_scalar_offset, reinterpret_cast<void*>(_scalar_dptr.get()));
    bind_ptr(_gamma_zero, gamma_dptr);
  } else {
    bind_ptr(_gamma, gamma_dptr);
  }

  const bool tensor_scaled_fp8 = _fp8_out && _ndim_scale_block == 0;
  const bool mxfp8 = _fp8_out && _ndim_scale_block == 1;
  if (tensor_scaled_fp8) {
    bind_ptr(_one_for_div, reinterpret_cast<void*>(_one_dptr.get()));
    bind_ptr(_z_scale, z->scale.dptr);
    bind_ptr(_z_scale_inv, z->scale_inv.dptr);
    bind_ptr(_amax, z->amax.dptr);
    bind_ptr(_z_fp8, z->data.dptr);
  } else if (mxfp8) {
    bind_ptr(_z_mx_row, z->data.dptr);
    bind_ptr(_sf_row, z->scale_inv.dptr);
    if (_training) {
      bind_ptr(_z_mx_col, z->columnwise_data.dptr);
      bind_ptr(_sf_col, z->columnwise_scale_inv.dptr);
    }
  } else {
    bind_ptr(_z, z->data.dptr);
  }

  // Execute the computation on the caller's stream.
  NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream));
  NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good());
}

void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr,
                                     void* rsigma_dptr, void* dx_dptr, void* dz_dptr,
                                     void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
                                     cudaStream_t stream) {
  // Binding data pointers to graph tensors
  _variant_pack = {
      {_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}};

  if (_zero_centered)
    _variant_pack.insert({{_scalar_offset, reinterpret_cast<void*>(this->_scalar_dptr.get())},
                          {_gamma_zero, gamma_dptr}});
  else
    _variant_pack.insert({{_gamma, gamma_dptr}});

  // layernorm should have valid mean_dptr and beta_dptr
  if (mean_dptr && dbeta_dptr) _variant_pack.insert({{_mean, mean_dptr}, {_dbeta, dbeta_dptr}});

  // Execute the computation
  NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream));
  NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good());
}

/*! \brief Return the cached plan for this configuration, building it on first request.
 *
 *  Compute type is fixed to FP32. The tuned TE kernel variant is requested only for aligned
 *  buffers whose batch size is a multiple of 4. Plans are stored in normalizationPlanMap
 *  under the key from get_key(). The returned pointer is owned by that map.
 */
NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan(
    NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype,
    DType itype, DType otype, const size_t batch_size, const size_t hidden_size,
    const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned,
    const NVTEScalingMode mode, const bool training) {
  const DType ctype = DType::kFloat32;
  const bool is_tuned = is_aligned && batch_size % 4 == 0;
  auto key = get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size,
                     hidden_size, zero_centered_gamma, is_tuned, mode, training);

  if (const auto cached = normalizationPlanMap.find(key); cached != normalizationPlanMap.end()) {
    return cached->second.get();
  }

  auto plan = [&]() -> std::unique_ptr<NormalizationPlanBase> {
    if (NormBackend == NVTE_Norm_Backend::Cudnn) {
      return std::make_unique<CudnnNormalizationPlan>(NormType, NormStage, wtype, itype, otype,
                                                      ctype, batch_size, hidden_size, sm_count,
                                                      zero_centered_gamma, mode, training);
    }
    if (NormStage == NVTE_Norm_Stage::Forward) {
      return std::make_unique<TeNormalizationPlan<ForwardKernelParams>>(
          NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count,
          zero_centered_gamma, is_tuned);
    }
    return std::make_unique<TeNormalizationPlan<BackwardKernelParams>>(
        NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count,
        zero_centered_gamma, is_tuned);
  }();

  // Use the iterator returned by insert instead of a second lookup.
  const auto inserted = normalizationPlanMap.insert({key, std::move(plan)});
  return inserted.first->second.get();
}

/*! \brief Mutable switch selecting cuDNN for forward normalization.
 *  Initialized once from NVTE_NORM_FWD_USE_CUDNN; callers may overwrite it through the
 *  returned reference.
 */
bool& _cudnn_norm_fwd_flag() {
  static bool enabled = transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN");
  return enabled;
}

/*! \brief Mutable switch selecting cuDNN for backward normalization.
 *  Initialized once from NVTE_NORM_BWD_USE_CUDNN; callers may overwrite it through the
 *  returned reference.
 */
bool& _cudnn_norm_bwd_flag() {
  static bool enabled = transformer_engine::getenv<bool>("NVTE_NORM_BWD_USE_CUDNN");
  return enabled;
}

/*! \brief Current value of the forward cuDNN switch. */
bool use_cudnn_norm_fwd() {
  return _cudnn_norm_fwd_flag();
}

/*! \brief Current value of the backward cuDNN switch. */
bool use_cudnn_norm_bwd() {
  return _cudnn_norm_bwd_flag();
}

}  //  namespace normalization
}  // namespace transformer_engine

/*! \brief C API: override the forward cuDNN-normalization switch
 *  (initially read from NVTE_NORM_FWD_USE_CUDNN).
 */
void nvte_enable_cudnn_norm_fwd(bool enable) {
  NVTE_API_CALL(nvte_enable_cudnn_norm_fwd);
  bool& use_cudnn = transformer_engine::normalization::_cudnn_norm_fwd_flag();
  use_cudnn = enable;
}

/*! \brief C API: override the backward cuDNN-normalization switch
 *  (initially read from NVTE_NORM_BWD_USE_CUDNN).
 */
void nvte_enable_cudnn_norm_bwd(bool enable) {
  NVTE_API_CALL(nvte_enable_cudnn_norm_bwd);
  bool& use_cudnn = transformer_engine::normalization::_cudnn_norm_bwd_flag();
  use_cudnn = enable;
}