/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/recipe.h>

#include <cmath>
10
#include <limits>
11
#include <string>
12
13

#include "../common.h"
yuguo's avatar
yuguo committed
14
15
16
#ifdef __HIP_PLATFORM_AMD__
#include "../util/hip_runtime.h"
#else
17
#include "../util/cuda_runtime.h"
yuguo's avatar
yuguo committed
18
#endif
19
#include "../util/logging.h"
20
21
22
23
24
25
26
27
28
29
30

namespace transformer_engine {
namespace delayed_scaling_recipe {

namespace {

// Selects which amax value drives the scaling-factor update
enum class AmaxComputeAlgo {
  INVALID = 0,      // placeholder used before an algorithm name has been parsed
  MOST_RECENT = 1,  // use the newest amax only
  MAX = 2           // use the largest amax found in the history
};

// Look up the printable name of a dtype through its TypeInfo specialization.
// The type-switch macro returns from inside the matching case; the trailing
// return only exists so that every control path yields a value.
const char* dtype_name(DType dtype) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, Type,
                                     { return TypeInfo<Type>::name; });  // NOLINT(*)
  constexpr const char* unnamed = "";
  return unnamed;
}

// Largest finite magnitude representable in the given FP8 dtype.
// Raises an error (via NVTE_ERROR) for any dtype other than the two FP8 formats.
inline float fp8_dtype_max(DType dtype) {
  // The E4M3 limit depends on the build platform
#ifndef __HIP_PLATFORM_AMD__
  constexpr float e4m3_max = 448;
#else
  constexpr float e4m3_max = 240;
#endif
  constexpr float e5m2_max = 57344;

  if (dtype == DType::kFloat8E4M3) {
    return e4m3_max;
  }
  if (dtype == DType::kFloat8E5M2) {
    return e5m2_max;
  }
  NVTE_ERROR("Expected FP8 dtype, but got ", dtype_name(dtype));
  return 0;
}

52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
// Per-tensor arguments consumed by the bulk update kernel
struct AmaxParam {
  // Number of scaling factors of the tensor; the bulk kernel also uses it as
  // the element stride between consecutive history entries
  int num_scale{0};
  // Start of the tensor's amax history buffer
  float* amax_history{nullptr};
  // Start of the tensor's scaling factors (read and overwritten by the bulk kernel)
  float* scale{nullptr};
};

// dummy struct for kernel_bulk's other params
//
// In the visible code this type is only used through sizeof(), to subtract the
// space taken by kernel_bulk's remaining arguments when AMAX_PARAMS_LIMIT is
// computed. Its members mirror the types of those arguments, in order.
struct OtherParams {
  float* a;           // same type as kernel_bulk's amax_reduction_buffer
  size_t b;           // same type as kernel_bulk's amax_history_length
  AmaxComputeAlgo c;  // same type as kernel_bulk's amax_compute_algo
  float d;            // same type as kernel_bulk's scaled_max
};

// Byte budget assumed for the arguments of one kernel launch.
// NOTE(review): the larger value presumably reflects the bigger kernel-parameter
// space available from CUDA 12.1 onwards - confirm against the CUDA docs.
#if CUDART_VERSION >= 12010
constexpr size_t max_constant_memory_per_kernel = 32768;
#else
constexpr size_t max_constant_memory_per_kernel = 4096;
#endif

// How many AmaxParam entries fit in that budget once the space needed by the
// remaining kernel_bulk arguments (modelled by OtherParams) has been set aside
constexpr size_t AMAX_PARAMS_LIMIT =
    (max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam);

// Fixed-capacity batch of per-tensor parameters, passed by value to kernel_bulk.
// Holds at most AMAX_PARAMS_LIMIT entries; the block with index b reads param[b].
struct AmaxParams {
  AmaxParam param[AMAX_PARAMS_LIMIT];
};

81
82
83
84
85
86
87
88
89
90
91
92
namespace amax_and_scale_update_impl {

// Threads per CUDA block used by the kernels in this namespace
constexpr size_t bsize = 256;

/* CUDA kernel to update amax history and FP8 scaling factors
 *
 * Block dims: bsize x 1 x 1
 *
 * Grid dims: num_scales x 1 x 1
 *
 * Each block handles the scaling factor selected by blockIdx.x:
 *   1. The history of that scaling factor is rolled by one entry: entry i of
 *      the output receives entry i+1 of the input, the last entry receives
 *      the old entry 0 and entry 0 is reset to zero. History entry i of
 *      scaling factor s lives at index i * amax_history_stride + s.
 *   2. The amax is chosen according to amax_compute_algo (old entry 0 for
 *      MOST_RECENT, maximum over the rolled values for MAX, zero otherwise).
 *   3. Thread 0 writes the new scale: scaled_max / amax when amax is finite
 *      and positive, otherwise the previous scale taken from scale_ptr.
 *
 * The input and output history buffers may alias (in-place roll).
 */
__global__ void __launch_bounds__(bsize)
    kernel(const float* amax_history_ptr, const float* scale_ptr, float* updated_amax_history_ptr,
           float* updated_scale_ptr, size_t amax_history_length, size_t amax_history_stride,
           AmaxComputeAlgo amax_compute_algo, float scaled_max) {
  const size_t thread_idx = threadIdx.x;
  const size_t scale_idx = blockIdx.x;

  // Column of the history that belongs to this block's scaling factor
  const float* history_in = amax_history_ptr + scale_idx;
  float* history_out = updated_amax_history_ptr + scale_idx;

  // Newest amax; every thread reads it before the first barrier below, i.e.
  // before thread 0 may overwrite entry 0 during an in-place roll
  const float newest_amax = history_in[0];

  // Roll the history in chunks of bsize entries while tracking a per-thread maximum
  float amax = 0;
  for (size_t chunk_begin = 0; chunk_begin < amax_history_length; chunk_begin += bsize) {
    const size_t entry = chunk_begin + thread_idx;
    const bool entry_valid = entry < amax_history_length;

    float rolled = 0;
    if (entry_valid) {
      if (entry < amax_history_length - 1) {
        rolled = history_in[(entry + 1) * amax_history_stride];
      } else {
        rolled = newest_amax;
      }
      amax = fmaxf(amax, rolled);
    }

    // All reads of this chunk must finish before any write, since the output
    // may be the very same buffer
    __syncthreads();

    if (entry_valid) {
      if (entry > 0) {
        history_out[entry * amax_history_stride] = rolled;
      } else {
        history_out[entry * amax_history_stride] = 0;
      }
    }
  }

  // Pick the amax that determines the scaling factor
  if (amax_compute_algo == AmaxComputeAlgo::MOST_RECENT) {
    amax = newest_amax;
  } else if (amax_compute_algo == AmaxComputeAlgo::MAX) {
    // Tree reduction of the per-thread maxima; slot 0 ends up with the block maximum
    __shared__ float block_max[bsize];
    block_max[thread_idx] = amax;
    __syncthreads();
#pragma unroll
    for (size_t half = bsize / 2; half > 0; half /= 2) {
      if (thread_idx < half) {
        block_max[thread_idx] = fmaxf(block_max[thread_idx], block_max[thread_idx + half]);
      }
      __syncthreads();
    }
    amax = block_max[thread_idx];
  } else {
    amax = 0;
  }

  // Only thread 0 writes the scale; no barrier follows, so the others may leave
  if (thread_idx != 0) {
    return;
  }

  // Fall back to the previous scale when amax is zero, negative, inf or NaN
  const bool amax_usable = isfinite(amax) && amax > 0;
  float new_scale = amax_usable ? (scaled_max / amax) : scale_ptr[scale_idx];

  // A tiny amax can push the quotient to infinity in FP32. Clamp it to the
  // largest finite FP32 value: the tensor's amax then maps to something below
  // the FP8 maximum, which is the best that can be done.
  if (isinf(new_scale)) {
    new_scale = std::numeric_limits<float>::max();
  }
  updated_scale_ptr[scale_idx] = new_scale;
}

164
165
166
167
168
169
170
/* CUDA kernel to bulk-update amax history and FP8 scaling factors
 *
 * Block dims: bsize x 1 x 1
 *
 * Grid dims: num_tensors x 1 x 1
 *
 * Block b processes the tensor described by p.param[b] and loops over all of
 * its scaling factors. For each scaling factor the newest amax is taken from
 * amax_reduction_buffer when that pointer is non-null and the entry is
 * non-zero, otherwise from history entry 0. The history is rolled in place
 * (skipped when the newest amax is zero) and the scale is updated in place.
 * The entries of amax_reduction_buffer are laid out tensor after tensor, in
 * the order of p.param.
 */
__global__ void __launch_bounds__(bsize)
    kernel_bulk(float* amax_reduction_buffer, AmaxParams p, size_t amax_history_length,
                AmaxComputeAlgo amax_compute_algo, float scaled_max) {
  const size_t tensor_idx = blockIdx.x;
  const size_t thread_idx = threadIdx.x;
  const int scales_in_tensor = p.param[tensor_idx].num_scale;

  // Position of this tensor's first entry in the reduction buffer: the sum of
  // the scale counts of all tensors that precede it in this launch
  int buffer_base = 0;
  for (size_t prev = 0; prev < tensor_idx; ++prev) {
    buffer_base += p.param[prev].num_scale;
  }

  for (int scale_idx = 0; scale_idx < scales_in_tensor; ++scale_idx) {
    // Column of the history that belongs to the current scaling factor;
    // consecutive entries are num_scale elements apart
    float* history = p.param[tensor_idx].amax_history + scale_idx;
    const int& entry_stride = p.param[tensor_idx].num_scale;

    // Newest amax: prefer the reduced value when one is available and non-zero
    float newest_amax;
    if ((amax_reduction_buffer != nullptr) &&
        (amax_reduction_buffer[buffer_base + scale_idx] != 0.0f)) {
      newest_amax = amax_reduction_buffer[buffer_base + scale_idx];
    } else {
      newest_amax = history[0];
    }

    // Roll the history in place, tracking a per-thread maximum. The condition
    // is identical for all threads of the block, so the barrier inside is safe.
    float amax = 0;
    if (newest_amax != 0.0f) {
      for (size_t chunk_begin = 0; chunk_begin < amax_history_length; chunk_begin += bsize) {
        const size_t entry = chunk_begin + thread_idx;
        const bool entry_valid = entry < amax_history_length;

        float rolled = 0;
        if (entry_valid) {
          if (entry < amax_history_length - 1) {
            rolled = history[(entry + 1) * entry_stride];
          } else {
            rolled = newest_amax;
          }
          amax = fmaxf(amax, rolled);
        }

        // Finish all reads of this chunk before overwriting it
        __syncthreads();

        if (entry_valid) {
          if (entry > 0) {
            history[entry * entry_stride] = rolled;
          } else {
            history[entry * entry_stride] = 0;
          }
        }
      }
    }

    // Pick the amax that determines the scaling factor
    if (amax_compute_algo == AmaxComputeAlgo::MOST_RECENT) {
      amax = newest_amax;
    } else if (amax_compute_algo == AmaxComputeAlgo::MAX) {
      // Tree reduction of the per-thread maxima; slot 0 ends up with the block maximum
      __shared__ float block_max[bsize];
      block_max[thread_idx] = amax;
      __syncthreads();
#pragma unroll
      for (size_t half = bsize / 2; half > 0; half /= 2) {
        if (thread_idx < half) {
          block_max[thread_idx] = fmaxf(block_max[thread_idx], block_max[thread_idx + half]);
        }
        __syncthreads();
      }
      amax = block_max[thread_idx];
    } else {
      amax = 0;
    }

    // Update scale (thread 0 only)
    if (thread_idx == 0) {
      // Computing the scaling factor requires consideration of the following scenarios:
      // 1. amax == 0:
      //    No action is possible, keep the previous scale.
      // 2. 0 < amax < tiny_amax:
      //    The amax is so small that the scale becomes infinite in FP32.
      //    Set scale = FP32_max.
      // 3. tiny_amax <= amax < FP32_max:
      //    Set scale = scaled_max / amax.
      // 4. amax == inf or amax == nan:
      //    No action is possible, keep the previous scale.
      float* tensor_scales = p.param[tensor_idx].scale;
      const bool amax_usable = isfinite(amax) && amax > 0;
      float new_scale = amax_usable ? (scaled_max / amax) : tensor_scales[scale_idx];

      // Scenario 2: clamp an infinite quotient to the largest finite FP32
      // value. The tensor's amax then maps to something below the FP8
      // maximum, which is the best that can be done.
      if (isinf(new_scale)) {
        new_scale = std::numeric_limits<float>::max();
      }
      tensor_scales[scale_idx] = new_scale;
    }
  }
}

}  // namespace amax_and_scale_update_impl
264
265
266

}  // namespace

267
268
/* Roll the amax history and recompute the FP8 scaling factors of one tensor set.
 *
 * amax_history          FP32 tensor with shape (history length, number of scales)
 * scale                 FP32 tensor with one element per scale; supplies the
 *                       fallback value when no valid amax is available
 * updated_amax_history_ output, same shape and dtype as amax_history
 * updated_scale_        output, same element count and dtype as scale
 * amax_compute_algo     "max" or "most_recent"
 * fp8_dtype             FP8 dtype whose maximum is the scaling target
 * margin                the target is multiplied by 2^-margin
 * stream                CUDA stream on which the kernel is launched
 *
 * Any violated precondition is reported through NVTE_CHECK / NVTE_ERROR.
 * With zero scales the call validates its arguments and returns without
 * launching a kernel.
 */
void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale,
                           Tensor* updated_amax_history_, Tensor* updated_scale_,
                           const std::string& amax_compute_algo, DType fp8_dtype, float margin,
                           cudaStream_t stream) {
  // The outputs are dereferenced right away, so reject null pointers first
  NVTE_CHECK(updated_amax_history_ != nullptr, "Updated amax history tensor is null.");
  NVTE_CHECK(updated_scale_ != nullptr, "Updated scale tensor is null.");
  auto& updated_amax_history = *updated_amax_history_;
  auto& updated_scale = *updated_scale_;

  // Check tensors
  NVTE_CHECK(amax_history.data.shape.size() == 2, "Found ", amax_history.data.shape.size(),
             " dims");
  const size_t amax_history_length = amax_history.data.shape[0];
  const size_t num_scales = amax_history.data.shape[1];
  NVTE_CHECK(amax_history.data.dtype == DType::kFloat32, "Found ",
             dtype_name(amax_history.data.dtype), ".");
  NVTE_CHECK(scale.numel() == num_scales, "Expected ", num_scales, " elements, ", "but found ",
             scale.numel(), ".");
  NVTE_CHECK(scale.data.dtype == DType::kFloat32, "Found ", dtype_name(scale.data.dtype), ".");
  NVTE_CHECK(updated_amax_history.data.shape.size() == 2, "Found ",
             updated_amax_history.data.shape.size(), " dims.");
  NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length, "Expected ",
             amax_history_length, ", ", "but found ", updated_amax_history.data.shape[0]);
  NVTE_CHECK(updated_amax_history.data.shape[1] == num_scales, "Expected ", num_scales, ", ",
             "but found ", updated_amax_history.data.shape[1]);
  NVTE_CHECK(updated_amax_history.data.dtype == DType::kFloat32, "Got ",
             dtype_name(updated_amax_history.data.dtype), ".");
  NVTE_CHECK(updated_scale.numel() == num_scales, "Expected ", num_scales, " elements, ",
             "but found ", updated_scale.numel(), ".");
  NVTE_CHECK(updated_scale.data.dtype == DType::kFloat32, "Got ",
             dtype_name(updated_scale.data.dtype), ".");

  // amax value to use for updating scaling factor
  AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID;
  if (amax_compute_algo == "max") {
    amax_compute_algo_ = AmaxComputeAlgo::MAX;
  } else if (amax_compute_algo == "most_recent") {
    amax_compute_algo_ = AmaxComputeAlgo::MOST_RECENT;
  } else {
    NVTE_ERROR("Unsupported amax compute algorithm (", amax_compute_algo, ")");
  }

  // Expected maximum value after scale is applied
  const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin);

  // Nothing to update; a grid with zero blocks would be an invalid launch
  if (num_scales == 0) {
    return;
  }

  // The kernel unconditionally reads the first history entry of every scale
  NVTE_CHECK(amax_history_length > 0, "Expected a non-empty amax history.");

  // Launch CUDA kernel: one block per scaling factor
  constexpr size_t block_size = amax_and_scale_update_impl::bsize;
  const size_t grid_size = num_scales;
  amax_and_scale_update_impl::kernel<<<grid_size, block_size, 0, stream>>>(
      static_cast<const float*>(amax_history.data.dptr), static_cast<const float*>(scale.data.dptr),
      static_cast<float*>(updated_amax_history.data.dptr),
      static_cast<float*>(updated_scale.data.dptr), amax_history_length, num_scales,
      amax_compute_algo_, scaled_max);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

321
/* Bulk variant: roll the amax histories and recompute the scaling factors of
 * many tensors, in place.
 *
 * amax_reduction_buffer  its data pointer (may be null) is handed to the
 *                        kernel, which prefers its non-zero entries over the
 *                        newest history entry; entries are consumed tensor
 *                        after tensor, one per scaling factor
 * amax_histories         FP32 tensors with shape (history length, num scales);
 *                        all must share the history length of the first one
 * scales                 FP32 1-D tensors, one per history tensor, with as
 *                        many elements as the history has scales
 * amax_compute_algo      "max" or "most_recent"
 * fp8_dtype              FP8 dtype whose maximum is the scaling target
 * margin                 the target is multiplied by 2^-margin
 * stream                 CUDA stream on which the kernels are launched
 *
 * The tensors are processed in batches of at most AMAX_PARAMS_LIMIT, one
 * kernel launch per batch. Violated preconditions are reported through
 * NVTE_CHECK / NVTE_ERROR.
 */
void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer,
                                           std::vector<Tensor*> amax_histories,
                                           std::vector<Tensor*> scales,
                                           const std::string& amax_compute_algo, DType fp8_dtype,
                                           float margin, cudaStream_t stream) {
  // amax value to use for updating scaling factor
  AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID;
  if (amax_compute_algo == "max") {
    amax_compute_algo_ = AmaxComputeAlgo::MAX;
  } else if (amax_compute_algo == "most_recent") {
    amax_compute_algo_ = AmaxComputeAlgo::MOST_RECENT;
  } else {
    NVTE_ERROR("Unsupported amax compute algorithm (", amax_compute_algo, ")");
  }

  // Expected maximum value after scale is applied
  const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin);

  // Number of tensors in the bulk; scales are indexed in lockstep with the histories
  const size_t num_tensors = amax_histories.size();
  NVTE_CHECK(scales.size() == num_tensors, "Expected ", num_tensors, " scale tensors, ",
             "but found ", scales.size(), ".");

  // All histories must share the length of the first one
  size_t amax_history_length = 0;
  if (num_tensors > 0) {
    NVTE_CHECK(amax_histories[0] != nullptr, "Amax history tensor 0 is null.");
    // Validate the rank before indexing into the shape
    NVTE_CHECK(amax_histories[0]->data.shape.size() == 2, "Found ",
               amax_histories[0]->data.shape.size(), " dims");
    amax_history_length = amax_histories[0]->data.shape[0];
  }

  // amax parameters
  float* amax_buffer = static_cast<float*>(amax_reduction_buffer.data.dptr);
  AmaxParams p;
  size_t first_tensor = 0;
  size_t num_remaining_tensors = num_tensors;
  while (num_remaining_tensors > 0) {
    // Every launch takes a full batch except possibly the last one
    const size_t kernel_num_tensors =
        (num_remaining_tensors < AMAX_PARAMS_LIMIT) ? num_remaining_tensors : AMAX_PARAMS_LIMIT;
    size_t kernel_num_scales = 0;

    for (size_t pi = 0; pi < kernel_num_tensors; pi++) {
      const size_t i = first_tensor + pi;

      // Check tensors
      NVTE_CHECK(amax_histories[i] != nullptr, "Amax history tensor ", i, " is null.");
      NVTE_CHECK(scales[i] != nullptr, "Scale tensor ", i, " is null.");
      NVTE_CHECK(amax_histories[i]->data.dtype == DType::kFloat32, "Found ",
                 dtype_name(amax_histories[i]->data.dtype), ".");
      NVTE_CHECK(amax_histories[i]->data.shape.size() == 2, "Found ",
                 amax_histories[i]->data.shape.size(), " dims");
      const size_t num_scale = amax_histories[i]->data.shape[1];
      // AmaxParam stores the count as int; refuse values that would be truncated
      NVTE_CHECK(num_scale <= static_cast<size_t>(std::numeric_limits<int>::max()),
                 "Number of scales (", num_scale, ") does not fit in an int.");
      NVTE_CHECK(amax_histories[i]->numel() == amax_history_length * num_scale, "Expected ",
                 amax_history_length * num_scale, " elements, ", "but found ",
                 amax_histories[i]->numel(), ".");
      NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32, "Found ",
                 dtype_name(scales[i]->data.dtype), ".");
      NVTE_CHECK(scales[i]->data.shape.size() == 1, "Found ", scales[i]->data.shape.size(),
                 " dims");
      NVTE_CHECK(scales[i]->numel() == num_scale, "Expected ", num_scale, " elements, ", "Found ",
                 scales[i]->numel(), ".");

      // amax parameters
      kernel_num_scales += num_scale;
      p.param[pi].num_scale = static_cast<int>(num_scale);
      p.param[pi].amax_history = static_cast<float*>(amax_histories[i]->data.dptr);
      p.param[pi].scale = static_cast<float*>(scales[i]->data.dptr);
    }

    // Launch CUDA kernel: one block per tensor of this batch
    const size_t grid_size = kernel_num_tensors;
    const size_t block_size = amax_and_scale_update_impl::bsize;
    amax_and_scale_update_impl::kernel_bulk<<<grid_size, block_size, 0, stream>>>(
        amax_buffer, p, amax_history_length, amax_compute_algo_, scaled_max);
    NVTE_CHECK_CUDA(cudaGetLastError());

    // shift amax buffer pointer past the entries consumed by this batch
    if (amax_buffer != nullptr) {
      amax_buffer += kernel_num_scales;
    }

    // Subtract what was actually processed so the unsigned counter cannot wrap
    first_tensor += kernel_num_tensors;
    num_remaining_tensors -= kernel_num_tensors;
  }
}

398
399
400
}  // namespace delayed_scaling_recipe
}  // namespace transformer_engine

401
/* C API entry point: forwards to delayed_scaling_recipe::amax_and_scale_update
 * after converting the opaque tensor handles to Tensor pointers and the dtype
 * to DType.
 *
 * amax_compute_algo must be a non-null C string; it is converted to
 * std::string for the callee, which accepts "max" and "most_recent".
 */
void nvte_delayed_scaling_recipe_amax_and_scale_update(
    const NVTETensor amax_history, const NVTETensor scale, NVTETensor updated_amax_history,
    NVTETensor updated_scale, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin,
    cudaStream_t stream) {
  NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update);
  using namespace transformer_engine;

  // Building a std::string from a null pointer is undefined behaviour
  NVTE_CHECK(amax_compute_algo != nullptr, "amax_compute_algo must not be null.");

  delayed_scaling_recipe::amax_and_scale_update(
      *reinterpret_cast<const Tensor*>(amax_history), *reinterpret_cast<const Tensor*>(scale),
      reinterpret_cast<Tensor*>(updated_amax_history), reinterpret_cast<Tensor*>(updated_scale),
      amax_compute_algo, static_cast<DType>(fp8_dtype), margin, stream);
}
412
413

/* C API entry point: forwards to
 * delayed_scaling_recipe::amax_and_scale_update_after_reduction after
 * converting the opaque tensor handles to Tensor pointers and the dtype to DType.
 *
 * amax_histories and scales must contain the same number of handles, and
 * amax_compute_algo must be a non-null C string ("max" or "most_recent" are
 * the values accepted by the callee).
 */
void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
    const NVTETensor amax_reduction_buffer, std::vector<NVTETensor> amax_histories,
    std::vector<NVTETensor> scales, const char* amax_compute_algo, NVTEDType fp8_dtype,
    float margin, cudaStream_t stream) {
  NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction);
  using namespace transformer_engine;

  // Building a std::string from a null pointer is undefined behaviour
  NVTE_CHECK(amax_compute_algo != nullptr, "amax_compute_algo must not be null.");

  // Both vectors are indexed with the same counter below
  const size_t num_tensors = amax_histories.size();
  NVTE_CHECK(scales.size() == num_tensors, "Expected ", num_tensors, " scale tensors, ",
             "but found ", scales.size(), ".");

  std::vector<Tensor*> t_amax_histories, t_scales;
  t_amax_histories.reserve(num_tensors);
  t_scales.reserve(num_tensors);
  for (size_t i = 0; i < num_tensors; i++) {
    t_amax_histories.push_back(reinterpret_cast<Tensor*>(amax_histories[i]));
    t_scales.push_back(reinterpret_cast<Tensor*>(scales[i]));
  }

  delayed_scaling_recipe::amax_and_scale_update_after_reduction(
      *reinterpret_cast<const Tensor*>(amax_reduction_buffer), t_amax_histories, t_scales,
      amax_compute_algo, static_cast<DType>(fp8_dtype), margin, stream);
}