/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_

#ifndef __HIP_PLATFORM_AMD__
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#endif
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>

#include <functional>
#include <map>
#include <stdexcept>
#include <tuple>
#include <typeindex>
#include <unordered_map>
#include <vector>

#include "../common.h"
#include "../cudnn_utils.h"
#include "../util/system.h"

namespace transformer_engine {

namespace normalization {

yuguo's avatar
yuguo committed
34
#ifndef __HIP_PLATFORM_AMD__
35
namespace fe = cudnn_frontend;
yuguo's avatar
yuguo committed
36
#endif
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160

/*! \brief Launch configuration and workspace layout for a TE normalization kernel.
 *
 *  Bundles the kernel parameters with the byte sizes of the sub-buffers that
 *  together form the kernel workspace: the general workspace, the multi-CTA
 *  barrier buffer and the weight-gradient partial buffer(s).
 */
template <typename KernelParamsType>
struct LaunchParams {
  size_t workspace_bytes = 0;
  size_t barrier_bytes = 0;
  size_t dgamma_part_bytes = 0;
  // Given defaults like the byte counters above, so a freshly constructed
  // LaunchParams never carries indeterminate values into a launch.
  int multiprocessorCount = 0;
  cudaStream_t stream = nullptr;

  KernelParamsType params;

  /*! \brief Total workspace size in bytes.
   *
   *  The partial-reduction buffer is counted twice when `_is_layernorm` is true
   *  and once otherwise. Presumably the second copy holds the dbeta partials,
   *  which RMSNorm does not need; BackwardKernelParams carries both dbeta_part and
   *  dgamma_part. TODO confirm.
   */
  size_t getTotalWorkspaceBytes(const bool _is_layernorm = true) const {
    const size_t num_part_buffers = _is_layernorm ? 2 : 1;
    return workspace_bytes + barrier_bytes + num_part_buffers * dgamma_part_bytes;
  }

  /*! \brief Round every workspace sub-buffer size up to a multiple of `alignment` bytes.
   *
   *  `alignment` must be non-zero (DIVUP is presumably a ceiling division by it).
   */
  void alignWorkspace(size_t alignment = 16) {
    workspace_bytes = DIVUP(workspace_bytes, alignment) * alignment;
    barrier_bytes = DIVUP(barrier_bytes, alignment) * alignment;
    dgamma_part_bytes = DIVUP(dgamma_part_bytes, alignment) * alignment;
  }
};

/*! \brief Parameters shared by the forward and backward TE normalization kernels.
 *
 *  Every member is initialized by the constructor. The initializer list follows
 *  declaration order, which is the order C++ actually initializes members in.
 */
struct KernelParamsBase {
  KernelParamsBase()
      : ctas_per_col(0),
        ctas_per_row(0),  // previously missing: left ctas_per_row indeterminate
        rows(0),
        cols(0),
        x(nullptr),
        mu(nullptr),
        rs(nullptr),
        gamma(nullptr),
        workspace(nullptr),
        barrier(nullptr),
        zero_centered_gamma(false) {}

  // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
  int ctas_per_col;
  // Size of CTA group.
  int ctas_per_row;

  // Input is interpreted as matrix. We normalize across columns.
  int rows;
  int cols;

  // Common data pointers.
  void* x;
  void* mu;
  void* rs;
  void* gamma;

  // Multi-CTA workspace in gmem.
  void* workspace;

  // Multi-CTA sync barriers in gmem.
  int* barrier;

  // Whether gamma is centered around 0
  bool zero_centered_gamma;
};

/*! \brief Kernel parameters for the normalization forward pass.
 *
 *  Every member is initialized by the constructor. Previously the scale and amax
 *  fields were left indeterminate, which is unsafe when fp8_out is false and
 *  those fields are never assigned.
 */
struct ForwardKernelParams : public KernelParamsBase {
  ForwardKernelParams()
      : KernelParamsBase(),
        z(nullptr),
        beta(nullptr),
        epsilon(0.f),
        scale(nullptr),
        scale_byte_size(0),
        scale_inv(nullptr),
        amax(nullptr),
        amax_byte_size(0),
        fp8_out(false) {}

  // Output of LN FWD.
  void* z;
  void* beta;
  float epsilon;

  // Scaling factor
  void* scale;
  int scale_byte_size;

  // Inverse of scaling factor
  void* scale_inv;

  // AMax output
  void* amax;
  int amax_byte_size;

  // Whether to compute scale and amax
  bool fp8_out;
};

/*! \brief Kernel parameters for the normalization backward pass.
 *
 *  All backward-specific pointers default to nullptr through in-class initializers.
 */
struct BackwardKernelParams : public KernelParamsBase {
  BackwardKernelParams() : KernelParamsBase() {}

  // Input: gradient with respect to the forward output.
  void* dz = nullptr;

  // Scratch space for the weight-gradient pre-reduction (partial results).
  void* dbeta_part = nullptr;
  void* dgamma_part = nullptr;

  // Output: data gradient.
  void* dx = nullptr;
  // Outputs: weight gradients.
  void* dbeta = nullptr;
  void* dgamma = nullptr;
};

// Implementation that executes a normalization: TE's own kernels or cuDNN.
enum class NVTE_Norm_Backend {
  Te = 0,
  Cudnn = 1,
};
// Pass a plan or kernel belongs to.
enum class NVTE_Norm_Stage {
  Forward = 0,
  Backward = 1,
};

// Lookup key for the TE kernel registry and the plan cache. TeNormalizationRegistry
// unpacks it as (general_key, batch_size, hidden_size, is_tuned).
using TupleKeyType = std::tuple<uint64_t, uint64_t, uint64_t, bool>;

/*! \brief Hash functor for TupleKeyType.
 *
 *  Combines the hashes of all four tuple entries using the boost::hash_combine
 *  recipe. The trailing `is_tuned` flag is also mixed in, so keys that differ only
 *  in that flag no longer land in the same bucket.
 */
struct TupleHash {
  size_t operator()(const TupleKeyType& t) const {
    // See: https://www.boost.org/doc/libs/1_55_0/doc/html/hash/reference.html#boost.hash_combine
    std::hash<uint64_t> hasher;
    size_t seed = 0;
    auto combine = [&seed](size_t value) {
      seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    };
    combine(hasher(std::get<0>(t)));
    combine(hasher(std::get<1>(t)));
    combine(hasher(std::get<2>(t)));
    combine(hasher(static_cast<uint64_t>(std::get<3>(t))));
    return seed;
  }
};

161
162
163
164
165
166
// Note: the default mode here should match with the default mode with QTensor
TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType,
                     NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype,
                     uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma,
                     bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING,
                     bool training = true);
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266

template <typename KernelParamsType>
class TeNormalizationRegistry {
 private:
  using Function = std::function<void(LaunchParams<KernelParamsType>&, const bool)>;
  std::unordered_map<TupleKeyType, Function, TupleHash> tuned_function_map;
  std::unordered_map<uint64_t, std::map<uint64_t, Function>> general_function_map;

  TeNormalizationRegistry() = default;

  static TeNormalizationRegistry& getInstance() {
    static TeNormalizationRegistry registry;
    return registry;
  }

 public:
  static int registerFunction(TupleKeyType key,
                              void (*func)(LaunchParams<KernelParamsType>&, const bool)) {
    auto [general_key, batch_size, hidden_size, is_tuned] = key;
    if (is_tuned)
      getInstance().tuned_function_map.emplace(key, Function(func));
    else
      getInstance().general_function_map[general_key].emplace(hidden_size, Function(func));
    return 0;
  }

  static Function getKernel(TupleKeyType key) {
    auto& instance = getInstance();
    auto [general_key, batch_size, hidden_size, is_tuned] = key;
    if (is_tuned) {
      auto it = instance.tuned_function_map.find(key);
      if (it != instance.tuned_function_map.end()) return it->second;
    }
    if (instance.general_function_map.count(general_key) == 0) {
      NVTE_ERROR("Unavailable kernel for this normalization config.");
    }
    auto& general_func_map = instance.general_function_map.at(general_key);
    auto func_iter = general_func_map.lower_bound(hidden_size);
    if (func_iter == general_func_map.end()) {
      return general_func_map.rbegin()->second;  // Hidden size is too big, need to use multi-CTA
    } else {
      return func_iter->second;
    }
  }

  TeNormalizationRegistry(const TeNormalizationRegistry&) = delete;
  TeNormalizationRegistry& operator=(const TeNormalizationRegistry&) = delete;
  TeNormalizationRegistry(TeNormalizationRegistry&&) = delete;
  TeNormalizationRegistry& operator=(TeNormalizationRegistry&&) = delete;
};

/*! \brief Abstract interface shared by the TE-kernel and cuDNN normalization plans.
 *
 *  One `execute` overload takes forward-stage operands (z, beta, eps, mean, rsigma).
 *  The other takes backward-stage operands (dz, dx, dbeta, dgamma).
 */
class NormalizationPlanBase {
 public:
  virtual ~NormalizationPlanBase() = default;

  // Workspace shape required by the plan. Presumably the caller allocates
  // `workspace_dptr` with this shape before calling execute(); TODO confirm.
  virtual std::vector<size_t> getWorkspaceShape() const = 0;

  // Forward-stage overload.
  virtual void execute(Tensor* z,
                       void* x_dptr,
                       void* gamma_dptr,
                       void* beta_dptr,
                       void* mean_dptr,
                       void* eps_dptr,
                       void* rsigma_dptr,
                       void* workspace_dptr,
                       cudaStream_t stream) = 0;

  // Backward-stage overload.
  virtual void execute(void* x_dptr,
                       void* gamma_dptr,
                       void* mean_dptr,
                       void* rsigma_dptr,
                       void* dx_dptr,
                       void* dz_dptr,
                       void* dbeta_dptr,
                       void* dgamma_dptr,
                       void* workspace_dptr,
                       cudaStream_t stream) = 0;

 private:
  // Backend-specific construction step, implemented by each concrete plan.
  virtual void _build() = 0;
};

/*! \brief Normalization plan that runs TE's own kernels.
 *
 *  Holds the launch parameters and a launcher whose signature matches the
 *  entries of TeNormalizationRegistry<KernelParamsType>.
 */
template <typename KernelParamsType>
class TeNormalizationPlan : public NormalizationPlanBase {
 public:
  TeNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype,
                      DType otype, DType ctype, const size_t batch_size, const size_t hidden_size,
                      const size_t sm_count, const bool zero_centered_gamma, const bool is_tuned);
  std::vector<size_t> getWorkspaceShape() const override;

  void execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, void* mean_dptr,
               void* eps_dptr, void* rsigma_dptr, void* workspace_dptr,
               cudaStream_t stream) override;

  void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr,
               void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
               cudaStream_t stream) override;

 private:
  void _set_workspace();
  // Overrides NormalizationPlanBase::_build. Marked `override` (as in
  // CudnnNormalizationPlan) so any signature drift from the base is a compile error.
  void _build() override;

  using KernelRegistry = TeNormalizationRegistry<KernelParamsType>;
  LaunchParams<KernelParamsType> _launch_params;
  std::function<void(LaunchParams<KernelParamsType>&, const bool)> _kernel;

  const bool _is_layernorm;
};

class CudnnNormalizationPlan : public NormalizationPlanBase {
 public:
  CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype,
                         DType itype, DType otype, DType ctype, const size_t batch_size,
                         const size_t hidden_size, const size_t sm_count,
267
268
                         const bool zero_centered_gamma, const NVTEScalingMode mode,
                         const bool training);
269
270
271
272
273
274
275
276
277
278
279
280
281

  std::vector<size_t> getWorkspaceShape() const override;

  void execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, void* mean_dptr,
               void* eps_dptr, void* rsigma_dptr, void* workspace_dptr,
               cudaStream_t stream) override;

  void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr,
               void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
               cudaStream_t stream) override;

 private:
  void _build() override;
yuguo's avatar
yuguo committed
282
283
  
#ifndef __HIP_PLATFORM_AMD__
284
  const bool _zero_centered, _fp8_out;
285
286
287
  int _ndim_scale_block;
  const NVTE_Norm_Stage _norm_stage;
  const NVTE_Norm_Type _norm_type;
288
  std::unique_ptr<char[]> _scalar_dptr;
289
  std::unique_ptr<float> _one_dptr = std::make_unique<float>(1.0f);
290
291
  // FWD
  std::shared_ptr<fe::graph::Tensor_attributes> _x, _gamma_zero, _scalar_offset, _gamma, _beta,
292
293
294
295
      _eps, _mean, _rsigma, _z, _z_scale, _one_for_div, _z_scale_inv, _amax, _z_fp8;
  // MX FWD
  std::shared_ptr<fe::graph::Tensor_attributes> _z_mx_row, _z_mx_col, _sf_row, _sf_col;
  const bool _training;
296
297
298
299
300
301
  // BWD
  std::shared_ptr<fe::graph::Tensor_attributes> _dz, _dx, _dgamma, _dbeta;

  fe::graph::Graph _graph;
  std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> _variant_pack;
  cudnnHandle_t _handle;
yuguo's avatar
yuguo committed
302
#endif
303
304
305
306
307
};

class NormalizationPlanRegistry {
 public:
  static NormalizationPlanRegistry& getInstance() {
308
    static thread_local NormalizationPlanRegistry instance;
309
310
311
    return instance;
  }

312
313
314
315
316
  NormalizationPlanBase* getNormalizationPlan(
      NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage,
      DType wtype, DType itype, DType otype, const size_t batch_size, const size_t hidden_size,
      const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned,
      const NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, const bool training = true);
317
318
319
320
321
322
323
324
325
326
327
328
329
330

 private:
  NormalizationPlanRegistry() {}
  NormalizationPlanRegistry(const NormalizationPlanRegistry&) = delete;
  NormalizationPlanRegistry& operator=(const NormalizationPlanRegistry&) = delete;

  std::unordered_map<TupleKeyType, std::unique_ptr<NormalizationPlanBase>, TupleHash>
      normalizationPlanMap;
};

using byte = uint8_t;
using int32 = int32_t;
using fp32 = float;
using fp16 = half;
yuguo's avatar
yuguo committed
331
using int8 = int8_t;
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;

template <typename T>
struct TypeToDType;

template <>
struct TypeToDType<fp32> {
  static constexpr DType value = DType::kFloat32;
};
template <>
struct TypeToDType<fp16> {
  static constexpr DType value = DType::kFloat16;
};
template <>
struct TypeToDType<bf16> {
  static constexpr DType value = DType::kBFloat16;
};
template <>
struct TypeToDType<fp8e4m3> {
  static constexpr DType value = DType::kFloat8E4M3;
};
template <>
yuguo's avatar
yuguo committed
356
357
358
359
struct TypeToDType<int8> {
  static constexpr DType value = DType::kInt8;
};
template <>
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
struct TypeToDType<fp8e5m2> {
  static constexpr DType value = DType::kFloat8E5M2;
};
template <>
struct TypeToDType<int32> {
  static constexpr DType value = DType::kInt32;
};
template <>
struct TypeToDType<byte> {
  static constexpr DType value = DType::kByte;
};

// Evaluates to the int 1 when the (unexpanded, stringized) argument is spelled
// exactly `tuned`, and 0 otherwise. The comparison runs at runtime via strcmp.
#define IS_TUNED(x) (strcmp(#x, "tuned") != 0 ? 0 : 1)

// Registers FUNC_NAME with TeNormalizationRegistry<NORM_STAGE##KernelParams>.
// The macro defines a namespace-scope `static int` whose initializer calls
// registerFunction, so registration happens during static initialization of the
// including translation unit. The variable name is token-pasted from all macro
// arguments except FUNC_NAME, so each combination must be unique per translation unit.
// TE kernels have no template for batch_size and zero_centered_gamma, thus zero out those
#define REGISTER_NORM_BASE(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE,                    \
                           CTYPE, FUNC_NAME)                                                                        \
  static int                                                                                                        \
      register_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE = \
          TeNormalizationRegistry<NORM_STAGE##KernelParams>::registerFunction(                                      \
              (get_key(NVTE_Norm_Backend::Te, NVTE_Norm_Type::NORM_TYPE,                                            \
                       NVTE_Norm_Stage::NORM_STAGE, (TypeToDType<WTYPE>::value),                                    \
                       (TypeToDType<ITYPE>::value), (TypeToDType<OTYPE>::value),                                    \
                       (TypeToDType<CTYPE>::value), 0, HIDDEN_SIZE, 0, IS_TUNED(LAUNCH_TYPE))),                     \
              FUNC_NAME)

// Alignment check.
// Returns true when every pointer in `ptrs` is a multiple of `Alignment` bytes.
// An empty argument list yields true, because an empty && fold is true.
template <size_t Alignment = 16, typename... Args>
bool is_ptr_aligned(const Args*... ptrs) {
  // A zero Alignment would make the modulo below undefined behaviour.
  static_assert(Alignment != 0, "is_ptr_aligned: Alignment must be non-zero");
  return ((reinterpret_cast<uintptr_t>(ptrs) % Alignment == 0) && ...);
}

// Backend-selection queries for the forward and backward normalization stages.
// They are defined outside this header. They presumably read a runtime/environment
// setting (this header includes ../util/system.h); TODO confirm.
auto use_cudnn_norm_fwd() -> bool;
auto use_cudnn_norm_bwd() -> bool;

}  // namespace normalization
}  // namespace transformer_engine

#endif